import csv
import io
import datetime
import secrets
import logging
import os
from typing import Optional

# Load .env before any module-level os.getenv() calls in auth/omada/database
from dotenv import load_dotenv
load_dotenv()

from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session

from app.database import init_db, get_db, RebootLog
from app.auth import (
    SESSION_SECRET_KEY,
    build_login_url,
    exchange_code,
    get_current_user,
)
from app.omada import omada_client

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI(title="Salus by Stranto", docs_url=None, redoc_url=None)

# Signed cookie sessions; max_age is in seconds (86400 = one day).
# NOTE(review): https_only=False means the session cookie is not marked
# Secure -- confirm this is intended outside local development.
app.add_middleware(
    SessionMiddleware,
    secret_key=SESSION_SECRET_KEY,
    max_age=86400,
    same_site="lax",
    https_only=False,
)

BASE_DIR = os.path.dirname(__file__)
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")


@app.on_event("startup")
async def startup():
    """Initialise the database when the application starts."""
    init_db()
    logger.info("Database initialized")


@app.get("/health")
async def health():
    """Liveness probe; always reports ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.get("/debug/sites")
async def debug_sites(request: Request):
    """Temporary: lists all accessible Omada sites.

    Requires a logged-in session, like the other data endpoints, because the
    response contains site names and keys.

    Returns:
        ``{"sites": [{"name": ..., "key": ...}, ...]}`` on success, or
        ``{"error": "<message>"}`` if the controller call fails.

    Raises:
        HTTPException: 401 when there is no authenticated user.
    """
    if not get_current_user(request):
        raise HTTPException(status_code=401, detail="Not authenticated")
    try:
        sites = await omada_client.get_all_sites()
        return {"sites": [{"name": k, "key": v} for k, v in sites.items()]}
    except Exception as exc:
        # Debug helper: report the failure in the body instead of a 500.
        logger.error("Failed to list sites: %s", exc)
        return {"error": str(exc)}


# ---------------------------------------------------------------------------
# Auth helpers
# ---------------------------------------------------------------------------

def _redirect_login():
    """Return a 302 redirect to the login page."""
    return RedirectResponse("/auth/login", status_code=302)


def _get_user_or_redirect(request: Request):
    """Return whatever ``get_current_user(request)`` returns.

    Despite its name this helper never builds a redirect itself: callers get
    the user object when one is logged in and a falsy value otherwise, and
    must issue the redirect (see ``_redirect_login``) on their own.
    """
    return get_current_user(request)


def _csrf_token(request: Request) -> str:
    """Return the session's CSRF token, creating one on first use."""
    if "csrf_token" not in request.session:
        request.session["csrf_token"] = secrets.token_hex(32)
    return request.session["csrf_token"]


def _verify_csrf(request: Request, token: str):
    """Check a submitted CSRF token against the one stored in the session.

    Args:
        request: Current request; its session holds the expected token.
        token: Token supplied by the client.

    Raises:
        HTTPException: 403 when the session has no token, or the submitted
            value is missing, not a string, or does not match.
    """
    expected = request.session.get("csrf_token")
    if not expected or not isinstance(expected, str) or not isinstance(token, str):
        raise HTTPException(status_code=403, detail="CSRF token invalid")
    # Compare as bytes: compare_digest() raises TypeError for str arguments
    # containing non-ASCII characters, which would surface as a 500.
    if not secrets.compare_digest(expected.encode("utf-8"), token.encode("utf-8")):
        raise HTTPException(status_code=403, detail="CSRF token invalid")


# When AUTH_DISABLED=true the login routes sign in DEV_USER without contacting
# the identity provider.
AUTH_DISABLED = os.getenv("AUTH_DISABLED", "false").strip().lower() == "true"
DEV_USER = {"username": "dev-user", "email": "dev@localhost", "name": "Dev User", "sub": "dev"}


# ---------------------------------------------------------------------------
# Auth routes
# ---------------------------------------------------------------------------

@app.get("/auth/login", response_class=HTMLResponse)
async def login(request: Request):
    """Render the login page, or redirect home if already signed in.

    With ``AUTH_DISABLED`` set, the dev user is stored in the session and the
    client is redirected straight to ``/``.
    """
    if AUTH_DISABLED:
        request.session["user"] = DEV_USER
        return RedirectResponse("/", status_code=302)
    if get_current_user(request):
        return RedirectResponse("/", status_code=302)
    return templates.TemplateResponse("login.html", {"request": request})


@app.get("/auth/login/start")
async def login_start(request: Request):
    """Begin the login flow by redirecting to the URL from ``build_login_url``.

    With ``AUTH_DISABLED`` set, signs in the dev user and redirects to ``/``.
    """
    if AUTH_DISABLED:
        request.session["user"] = DEV_USER
        return RedirectResponse("/", status_code=302)
    url = await build_login_url(request)
    return RedirectResponse(url, status_code=302)


@app.get("/auth/callback")
async def callback(request: Request):
    """Finish the login flow via ``exchange_code`` and redirect to ``/``.

    NOTE(review): presumably ``exchange_code`` stores the user in the session
    and raises on failure -- confirm in ``app.auth``.
    """
    await exchange_code(request)
    return RedirectResponse("/", status_code=302)


@app.get("/auth/logout")
async def logout(request: Request):
    """Clear the session and redirect to the login page."""
    request.session.clear()
    return RedirectResponse("/auth/login", status_code=302)


# ---------------------------------------------------------------------------
# Main pages
# ---------------------------------------------------------------------------

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the access-point overview for the signed-in user.

    If fetching the AP list fails, the page is still rendered with an empty
    list and the error text passed to the template.
    """
    user = get_current_user(request)
    if not user:
        return _redirect_login()

    error: Optional[str] = None
    aps: list = []
    try:
        aps = await omada_client.get_aps()
    except Exception as exc:
        logger.error("Failed to fetch APs: %s", exc)
        error = str(exc)

    csrf = _csrf_token(request)
    return templates.TemplateResponse("index.html", {
        "request": request,
        "user": user,
        "aps": aps,
        "error": error,
        "csrf_token": csrf,
    })


@app.get("/audit", response_class=HTMLResponse)
async def audit_page(
    request: Request,
    db: Session = Depends(get_db),
    username: str = "",
    ap_name: str = "",
):
    """Render the reboot audit log, newest first, with optional filters.

    Args:
        request: Current request (session carries the user).
        db: Database session.
        username: Case-insensitive substring filter on the username.
        ap_name: Case-insensitive substring filter on the AP name.
    """
    user = get_current_user(request)
    if not user:
        return _redirect_login()

    query = db.query(RebootLog).order_by(RebootLog.timestamp.desc())

    # autoescape=True makes "%" and "_" in the search text match literally
    # instead of acting as LIKE wildcards.
    username_filter = username.strip()
    if username_filter:
        query = query.filter(RebootLog.username.icontains(username_filter, autoescape=True))
    ap_filter = ap_name.strip()
    if ap_filter:
        query = query.filter(RebootLog.ap_name.icontains(ap_filter, autoescape=True))

    logs = query.all()
    return templates.TemplateResponse("audit.html", {
        "request": request,
        "user": user,
        "logs": logs,
        "filter_username": username,
        "filter_ap": ap_name,
    })


# Leading characters that make spreadsheet applications treat a cell as a formula.
_CSV_FORMULA_PREFIXES = ("=", "+", "-", "@", "\t", "\r")


def _csv_safe(value) -> str:
    """Return *value* as text that a spreadsheet will not evaluate as a formula.

    ``None`` becomes an empty string. Text starting with a formula trigger
    character is prefixed with a single quote so it is shown literally.
    """
    text = "" if value is None else str(value)
    if text.startswith(_CSV_FORMULA_PREFIXES):
        return "'" + text
    return text


@app.get("/audit/export")
async def export_csv(request: Request, db: Session = Depends(get_db)):
    """Download the complete reboot audit log as a CSV attachment.

    The file is UTF-8 with a BOM ("utf-8-sig"). Text cells are passed through
    ``_csv_safe`` because AP names and error messages originate from request
    data and controller responses.
    """
    user = get_current_user(request)
    if not user:
        return _redirect_login()

    logs = db.query(RebootLog).order_by(RebootLog.timestamp.desc()).all()

    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["Timestamp (UTC)", "Username", "Email", "AP Name", "MAC", "IP", "Result", "Error"])
    for log in logs:
        writer.writerow([
            log.timestamp.isoformat(timespec="seconds"),
            _csv_safe(log.username),
            _csv_safe(log.user_email),
            _csv_safe(log.ap_name),
            _csv_safe(log.ap_mac),
            _csv_safe(log.ap_ip),
            _csv_safe(log.result),
            _csv_safe(log.error_message),
        ])

    filename = f"audit_log_{datetime.date.today().isoformat()}.csv"
    return StreamingResponse(
        io.BytesIO(buf.getvalue().encode("utf-8-sig")),
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )


# ---------------------------------------------------------------------------
# API endpoints
# ---------------------------------------------------------------------------
async def _read_json_object(request: Request) -> dict:
    """Parse the request body as a JSON object, or fail with HTTP 400."""
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="JSON object expected")
    return body


def _csrf_from_body(body: dict) -> str:
    """Return the submitted CSRF token, or "" when it is unusable.

    Non-string and non-ASCII values are mapped to "" so that verification
    fails with 403 instead of raising inside the digest comparison.
    """
    token = body.get("csrf_token", "")
    if not isinstance(token, str) or not token.isascii():
        return ""
    return token


def _as_text(value) -> str:
    """Coerce a JSON value to text; ``None`` becomes an empty string."""
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _utc_now_naive() -> datetime.datetime:
    """Current UTC time as a naive datetime (same value ``utcnow()`` produced)."""
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


async def _reboot_and_log(db, user, *, mac, ap_name, ap_ip, site_key, label):
    """Send one reboot command and stage its audit row on *db*.

    The row is added to the session but not committed; the caller commits.

    Returns:
        ``(result, error_msg)`` where result is ``"success"`` or ``"error"``
        and error_msg is ``None`` on success.
    """
    result = "success"
    error_msg: Optional[str] = None
    try:
        await omada_client.reboot_ap(mac, site_key)
        logger.info("%s sent for AP %s (%s) by %s", label, ap_name, mac, user["username"])
    except Exception as exc:
        # Any controller failure is recorded in the audit log, not propagated.
        result = "error"
        error_msg = str(exc)
        logger.error("%s failed for AP %s (%s): %s", label, ap_name, mac, exc)

    db.add(RebootLog(
        timestamp=_utc_now_naive(),
        username=user["username"],
        user_email=user["email"],
        ap_name=ap_name,
        ap_mac=mac,
        ap_ip=ap_ip,
        result=result,
        error_message=error_msg,
    ))
    return result, error_msg


@app.post("/api/reboot")
async def api_reboot(request: Request, db: Session = Depends(get_db)):
    """Reboot a single access point and record the attempt in the audit log.

    Expects a JSON object with ``csrf_token``, ``mac`` and optionally
    ``name``, ``ip`` and ``site_key``.

    Returns:
        ``{"status": "ok", "mac": <mac>}`` on success.

    Raises:
        HTTPException: 401 if not authenticated; 400 for a malformed body or
            missing MAC; 403 for a bad CSRF token; 502 if the reboot call
            fails (the failure is still logged to the audit table).
    """
    user = get_current_user(request)
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    body = await _read_json_object(request)
    _verify_csrf(request, _csrf_from_body(body))

    raw_mac = body.get("mac", "")
    mac: str = raw_mac.strip() if isinstance(raw_mac, str) else ""
    if not mac:
        raise HTTPException(status_code=400, detail="MAC address required")

    # Coerce to text up front so a malformed field cannot make the audit
    # insert fail after the reboot has already been sent.
    ap_name: str = _as_text(body.get("name"))
    ap_ip: str = _as_text(body.get("ip"))
    site_key: str = _as_text(body.get("site_key"))

    result, error_msg = await _reboot_and_log(
        db, user, mac=mac, ap_name=ap_name, ap_ip=ap_ip, site_key=site_key, label="Reboot",
    )
    db.commit()

    if result == "error":
        raise HTTPException(status_code=502, detail=error_msg)
    return {"status": "ok", "mac": mac}


@app.post("/api/reboot-bulk")
async def api_reboot_bulk(request: Request, db: Session = Depends(get_db)):
    """Reboot several access points and record every attempt in the audit log.

    Expects a JSON object with ``csrf_token`` and ``aps``, a non-empty list of
    objects with ``mac`` and optionally ``name``, ``ip`` and ``site_key``.

    Returns:
        ``{"results": [{"mac", "name", "result", "error"}, ...]}`` with one
        entry per requested AP, in request order. Individual failures are
        reported per entry and do not abort the batch.

    Raises:
        HTTPException: 401 if not authenticated; 400 for a malformed body or
            an empty/invalid ``aps`` list; 403 for a bad CSRF token.
    """
    user = get_current_user(request)
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    body = await _read_json_object(request)
    _verify_csrf(request, _csrf_from_body(body))

    aps = body.get("aps", [])
    if not aps or not isinstance(aps, list):
        raise HTTPException(status_code=400, detail="No APs specified")

    results = []
    for ap in aps:
        if not isinstance(ap, dict):
            results.append({"mac": "", "name": "", "result": "error", "error": "Invalid AP entry"})
            continue

        raw_mac = ap.get("mac", "")
        mac = raw_mac.strip() if isinstance(raw_mac, str) else ""
        ap_name = _as_text(ap.get("name"))
        if not mac:
            # Same rule as /api/reboot: never call the controller without a MAC.
            results.append({"mac": mac, "name": ap_name, "result": "error", "error": "MAC address required"})
            continue

        result, error_msg = await _reboot_and_log(
            db,
            user,
            mac=mac,
            ap_name=ap_name,
            ap_ip=_as_text(ap.get("ip")),
            site_key=_as_text(ap.get("site_key")),
            label="Bulk reboot",
        )
        results.append({"mac": mac, "name": ap_name, "result": result, "error": error_msg})

    db.commit()
    return {"results": results}