diff --git a/app/util/auth.py b/app/internal/auth/auth.py similarity index 59% rename from app/util/auth.py rename to app/internal/auth/auth.py index dbed534..6b74859 100644 --- a/app/util/auth.py +++ b/app/internal/auth/auth.py @@ -8,12 +8,14 @@ import jwt from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBasic, OAuth2PasswordBearer +from fastapi.security import HTTPBasic, OAuth2PasswordBearer, OpenIdConnect from sqlmodel import Session, select +from app.internal.auth.session_middleware import middleware_linker from app.internal.models import GroupEnum, User from app.util.cache import StringConfigCache from app.util.db import get_session +from app.util.time import Minute, Second JWT_ALGORITHM = "HS256" @@ -21,6 +23,7 @@ JWT_ALGORITHM = "HS256" class LoginTypeEnum(str, Enum): basic = "basic" forms = "forms" + oidc = "oidc" none = "none" def is_basic(self): @@ -32,6 +35,9 @@ class LoginTypeEnum(str, Enum): def is_none(self): return self == LoginTypeEnum.none + def is_oidc(self): + return self == LoginTypeEnum.oidc + AuthConfigKey = Literal[ "login_type", @@ -53,6 +59,7 @@ class AuthConfig(StringConfigCache[AuthConfigKey]): def reset_auth_secret(self, session: Session): auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8") + middleware_linker.update_secret(auth_secret) self.set(session, "auth_secret", auth_secret) def get_auth_secret(self, session: Session) -> str: @@ -63,10 +70,11 @@ class AuthConfig(StringConfigCache[AuthConfigKey]): self.set(session, "auth_secret", auth_secret) return auth_secret - def get_access_token_expiry_minutes(self, session: Session): - return self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7) + def get_access_token_expiry_minutes(self, session: Session) -> Minute: + return Minute(self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7)) - def set_access_token_expiry_minutes(self, session: Session, expiry: int): + def set_access_token_expiry_minutes(self, session: Session, expiry: Minute): + middleware_linker.update_max_age(Second(expiry * 60)) self.set_int(session, "access_token_expiry_minutes", expiry) def get_min_password_length(self, session: Session) -> int: @@ -83,12 +91,6 @@ class DetailedUser(User): return self.login_type == LoginTypeEnum.forms -security = HTTPBasic() -ph = PasswordHasher() -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) -auth_config = AuthConfig() - - def raise_for_invalid_password( session: Session, password: str, @@ -154,91 +156,103 @@ def create_user( return User(username=username, password=password_hash, group=group, root=root) -def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted): - async def get_user( - request: Request, - session: Annotated[Session, Depends(get_session)], - ) -> DetailedUser: - login_type = auth_config.get_login_type(session) - - if login_type == LoginTypeEnum.forms: - user = await _get_forms_auth(request, session) - elif login_type == LoginTypeEnum.none: - user = await _get_none_auth(session) - else: - user = await _get_basic_auth(request, session) - - if not user.is_above(lowest_allowed_group): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden" - ) - - user = DetailedUser.model_validate(user, update={"login_type": login_type}) - - return user - - return get_user - - -async def _get_basic_auth( - request: Request, - session: Session, -) -> User: - invalid_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials", - headers={"WWW-Authenticate": "Basic"}, - ) - - credentials = await security(request) - - if not credentials: - raise invalid_exception - - user = authenticate_user(session, credentials.username, credentials.password) - if not user: - raise invalid_exception - - return user - - class RequiresLoginException(Exception): def __init__(self, detail: Optional[str] = None, **kwargs: object): super().__init__(**kwargs) self.detail = detail -async def _get_forms_auth( - request: Request, - session: Session, -) -> User: - # Authentication is either through Authorization header or cookie - token = await oauth2_scheme(request) - if not token: - token = request.cookies.get("audio_sess") - if not token: +class ABRAuth: + def __init__(self): + self.oidc_scheme: Optional[OpenIdConnect] = None + self.none_user: Optional[User] = None + + def __call__(self, lowest_allowed_group: GroupEnum): + return self._get_authenticated_user(lowest_allowed_group) + + def _get_authenticated_user(self, lowest_allowed_group: GroupEnum): + async def get_user( + request: Request, + session: Annotated[Session, Depends(get_session)], + ) -> DetailedUser: + login_type = auth_config.get_login_type(session) + + if login_type == LoginTypeEnum.forms: + user = await self._get_forms_auth(request, session) + elif login_type == LoginTypeEnum.none: + user = await self._get_none_auth(session) + elif login_type == LoginTypeEnum.oidc: + user = await self._get_oidc_auth(request, session) + else: + user = await self._get_basic_auth(request, session) + + if not user.is_above(lowest_allowed_group): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden" + ) + + user = DetailedUser.model_validate(user, update={"login_type": login_type}) + + return user + + return get_user + + async def _get_basic_auth( + self, + request: Request, + session: Session, + ) -> User: + invalid_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", + headers={"WWW-Authenticate": "Basic"}, + ) + + credentials = await security(request) + + if not credentials: + raise invalid_exception + + user = authenticate_user(session, credentials.username, credentials.password) + if not user: + raise invalid_exception + + return user + + async def _get_forms_auth( + self, + request: Request, + session: Session, + ) -> User: + username = request.session.get("sub") + if not username: raise RequiresLoginException() - try: - payload = jwt.decode( # pyright: ignore[reportUnknownMemberType] - token, auth_config.get_auth_secret(session), algorithms=[JWT_ALGORITHM] - ) - except jwt.InvalidTokenError: - raise RequiresLoginException("Token is expired/invalid") + user = session.get(User, username) + if not user: + raise RequiresLoginException("User does not exist") - username = payload.get("sub") - if username is None: - raise RequiresLoginException("Token is invalid") + return user - user = session.get(User, username) - if not user: - raise RequiresLoginException("User does not exist") + # TODO + async def _get_oidc_auth(self, request: Request, session: Session) -> User: ... - return user + async def _get_none_auth(self, session: Session) -> User: + """Treats every request as being root by returning the first admin user""" + if self.none_user: + return self.none_user + self.none_user = session.exec( + select(User).where(User.group == GroupEnum.admin).limit(1) + ).one() + return self.none_user -async def _get_none_auth(session: Session) -> User: - """Treats every request as being root by returning the first admin user""" - return session.exec( - select(User).where(User.group == GroupEnum.admin).limit(1) - ).one() +security = HTTPBasic() +ph = PasswordHasher() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) +auth_config = AuthConfig() +abr_authentication = ABRAuth() + + +def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted): + return abr_authentication(lowest_allowed_group) diff --git a/app/internal/auth/session_middleware.py b/app/internal/auth/session_middleware.py new file mode 100644 index 0000000..e709686 --- /dev/null +++ b/app/internal/auth/session_middleware.py @@ -0,0 +1,67 @@ +from starlette.middleware.sessions import SessionMiddleware +from starlette.types import ASGIApp, Receive, Scope, Send + +from app.util.time import Second + + +class DynamicSessionMiddleware: + def __init__( + self, + app: ASGIApp, + secret_key: str, + linker: "DynamicMiddlewareLinker", + max_age: Second = Second(60 * 60 * 24 * 14), + ): + self.app = app + self.secret_key = secret_key + self.expiry = max_age + self.session_middleware = SessionMiddleware( + app, + secret_key, + same_site="strict", + max_age=max_age, + ) + linker.add_middleware(self) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return await self.session_middleware(scope, receive, send) + + def update_secret(self, secret_key: str): + self.session_middleware = SessionMiddleware( + self.app, + secret_key, + same_site="strict", + max_age=self.expiry, + ) + + def update_max_age(self, max_age: Second): + self.session_middleware = SessionMiddleware( + self.app, + self.secret_key, + same_site="strict", + max_age=max_age, + ) + + +class DynamicMiddlewareLinker: + """ + Linker is passed in as an argument to the DynamicSessionMiddleware so + wherever FastAPI initializes the middleware, we can update + the options to take effect immediately instead of having to restart the server + """ + + middlewares: list[DynamicSessionMiddleware] = [] + + def add_middleware(self, middleware: DynamicSessionMiddleware): + self.middlewares.append(middleware) + + def update_secret(self, secret_key: str): + for middleware in self.middlewares: + middleware.update_secret(secret_key) + + def update_max_age(self, expiry: Second): + for middleware in self.middlewares: + middleware.update_max_age(expiry) + + +middleware_linker = DynamicMiddlewareLinker() diff --git a/app/main.py b/app/main.py index 236bdc6..6b73e04 100644 --- a/app/main.py +++ b/app/main.py @@ -2,27 +2,41 @@ from typing import Any from urllib.parse import quote_plus from fastapi import FastAPI, HTTPException, Request, status +from fastapi.middleware import Middleware +from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import RedirectResponse from sqlalchemy import func from sqlmodel import select +from app.internal.auth.auth import RequiresLoginException, auth_config +from app.internal.auth.session_middleware import ( + DynamicSessionMiddleware, + middleware_linker, +) from app.internal.env_settings import Settings from app.internal.models import User -from app.routers import root, search, settings, wishlist -from app.util.auth import RequiresLoginException +from app.routers import auth, login, root, search, settings, wishlist from app.util.db import open_session +with open_session() as session: + auth_secret = auth_config.get_auth_secret(session) + app = FastAPI( title="AudioBookRequest", debug=Settings().app.debug, openapi_url="/openapi.json" if Settings().app.openapi_enabled else None, + middleware=[ + Middleware(DynamicSessionMiddleware, auth_secret, middleware_linker), + Middleware(GZipMiddleware), + ], ) - +app.include_router(auth.router) +app.include_router(login.router) app.include_router(root.router) app.include_router(search.router) -app.include_router(wishlist.router) app.include_router(settings.router) +app.include_router(wishlist.router) user_exists = False diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000..c358ceb --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,43 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Request, Response, status +from fastapi.security import OAuth2PasswordRequestForm +from sqlmodel import Session + +from app.internal.auth.auth import ( + DetailedUser, + authenticate_user, + get_authenticated_user, +) +from app.util.db import get_session +from app.util.templates import templates + +router = APIRouter(prefix="/auth") + + +@router.post("/logout") +def logout( + request: Request, user: Annotated[DetailedUser, Depends(get_authenticated_user())] +): + request.session["sub"] = "" + return Response( + status_code=status.HTTP_204_NO_CONTENT, headers={"HX-Redirect": "/login"} + ) + + +@router.post("/token") +def login_access_token( + request: Request, + session: Annotated[Session, Depends(get_session)], + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], +): + user = authenticate_user(session, form_data.username, form_data.password) + if not user: + return templates.TemplateResponse( + "login.html", + {"request": request, "hide_navbar": True, "error": "Invalid login"}, + block_name="error_toast", + ) + + request.session["sub"] = form_data.username + return Response(status_code=status.HTTP_200_OK, headers={"HX-Redirect": "/"}) diff --git a/app/routers/login.py b/app/routers/login.py new file mode 100644 index 0000000..2253efc --- /dev/null +++ b/app/routers/login.py @@ -0,0 +1,39 @@ +from typing import Annotated, Optional + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from sqlmodel import Session + +from app.internal.auth.auth import ( + LoginTypeEnum, + RequiresLoginException, + auth_config, + get_authenticated_user, +) +from app.util.db import get_session +from app.util.templates import templates + +router = APIRouter(prefix="/login") + + +@router.get("") +async def login( + request: Request, + session: Annotated[Session, Depends(get_session)], + error: Optional[str] = None, +): + login_type = auth_config.get(session, "login_type") + if login_type != LoginTypeEnum.forms: + return RedirectResponse("/") + + try: + await get_authenticated_user()(request, session) + # already logged in + return RedirectResponse("/") + except (HTTPException, RequiresLoginException): + pass + + return templates.TemplateResponse( + "login.html", + {"request": request, "hide_navbar": True, "error": error}, + ) diff --git a/app/routers/root.py b/app/routers/root.py index e5a5475..774930b 100644 --- a/app/routers/root.py +++ b/app/routers/root.py @@ -1,23 +1,18 @@ -from datetime import timedelta -from typing import Annotated, Optional +from typing import Annotated -from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status +from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response from fastapi.responses import FileResponse, RedirectResponse -from fastapi.security import OAuth2PasswordRequestForm from sqlmodel import Session -from app.internal.models import GroupEnum -from app.util.auth import ( +from app.internal.auth.auth import ( DetailedUser, LoginTypeEnum, - RequiresLoginException, auth_config, - authenticate_user, - create_access_token, create_user, get_authenticated_user, raise_for_invalid_password, ) +from app.internal.models import GroupEnum from app.util.db import get_session from app.util.templates import templates @@ -117,68 +112,3 @@ def create_init( session.commit() return Response(status_code=201, headers={"HX-Redirect": "/"}) - - -@router.get("/login") -async def login( - request: Request, - session: Annotated[Session, Depends(get_session)], - error: Optional[str] = None, -): - login_type = auth_config.get(session, "login_type") - if login_type != LoginTypeEnum.forms: - return RedirectResponse("/") - - try: - await get_authenticated_user()(request, session) - # already logged in - return RedirectResponse("/") - except (HTTPException, RequiresLoginException): - pass - - return templates.TemplateResponse( - "login.html", - {"request": request, "hide_navbar": True, "error": error}, - ) - - -@router.post("/auth/logout") -def logout(user: Annotated[DetailedUser, Depends(get_authenticated_user())]): - return Response( - status_code=status.HTTP_204_NO_CONTENT, - headers={ - "Set-Cookie": "audio_sess=; Path=/; SameSite=Strict; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT", - "HX-Redirect": "/login", - }, - ) - - -@router.post("/auth/token") -def login_access_token( - request: Request, - session: Annotated[Session, Depends(get_session)], - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], -): - user = authenticate_user(session, form_data.username, form_data.password) - if not user: - return templates.TemplateResponse( - "login.html", - {"request": request, "hide_navbar": True, "error": "Invalid login"}, - block_name="error_toast", - ) - - access_token_expires_minues = auth_config.get_access_token_expiry_minutes(session) - access_token_exires = timedelta(minutes=access_token_expires_minues) - access_token = create_access_token( - auth_config.get_auth_secret(session), - {"sub": form_data.username}, - access_token_exires, - ) - - return Response( - status_code=status.HTTP_200_OK, - headers={ - "HX-Redirect": "/", - "Set-Cookie": f"audio_sess={access_token}; Path=/; SameSite=Strict; HttpOnly; ", - }, - ) diff --git a/app/routers/search.py b/app/routers/search.py index da309db..639a253 100644 --- a/app/routers/search.py +++ b/app/routers/search.py @@ -36,7 +36,7 @@ from app.internal.prowlarr.prowlarr import prowlarr_config from app.internal.query import query_sources from app.internal.ranking.quality import quality_config from app.routers.wishlist import get_wishlist_books -from app.util.auth import DetailedUser, get_authenticated_user +from app.internal.auth.auth import DetailedUser, get_authenticated_user from app.util.connection import get_connection from app.util.db import get_session, open_session from app.util.templates import template_response diff --git a/app/routers/settings.py b/app/routers/settings.py index f1a9e82..8eba821 100644 --- a/app/routers/settings.py +++ b/app/routers/settings.py @@ -11,7 +11,7 @@ from app.internal.prowlarr.indexer_categories import indexer_categories from app.internal.notifications import send_notification from app.internal.prowlarr.prowlarr import flush_prowlarr_cache, prowlarr_config from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config -from app.util.auth import ( +from app.internal.auth.auth import ( DetailedUser, LoginTypeEnum, auth_config, @@ -23,6 +23,7 @@ from app.util.auth import ( from app.util.connection import get_connection from app.util.db import get_session from app.util.templates import template_response +from app.util.time import Minute router = APIRouter(prefix="/settings") @@ -690,7 +691,7 @@ def update_security( old = auth_config.get_login_type(session) auth_config.set_login_type(session, login_type) - auth_config.set_access_token_expiry_minutes(session, access_token_expiry) + auth_config.set_access_token_expiry_minutes(session, Minute(access_token_expiry)) auth_config.set_min_password_length(session, min_password_length) return template_response( "settings_page/security.html", diff --git a/app/routers/wishlist.py b/app/routers/wishlist.py index d060396..0c32968 100644 --- a/app/routers/wishlist.py +++ b/app/routers/wishlist.py @@ -28,7 +28,7 @@ from app.internal.prowlarr.prowlarr import ( ) from app.internal.query import query_sources from app.internal.ranking.quality import quality_config -from app.util.auth import DetailedUser, get_authenticated_user +from app.internal.auth.auth import DetailedUser, get_authenticated_user from app.util.connection import get_connection from app.util.db import get_session, open_session from app.util.templates import template_response diff --git a/app/util/templates.py b/app/util/templates.py index 6fbe467..17fb47b 100644 --- a/app/util/templates.py +++ b/app/util/templates.py @@ -5,7 +5,7 @@ from fastapi import Request, Response from jinja2_fragments.fastapi import Jinja2Blocks from starlette.background import BackgroundTask -from app.util.auth import DetailedUser +from app.internal.auth.auth import DetailedUser templates = Jinja2Blocks(directory="templates") templates.env.filters["quote_plus"] = lambda u: quote_plus(u) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportUnknownArgumentType] diff --git a/app/util/time.py b/app/util/time.py new file mode 100644 index 0000000..4b63012 --- /dev/null +++ b/app/util/time.py @@ -0,0 +1,4 @@ +from typing import NewType + +Second = NewType("Second", int) +Minute = NewType("Minute", int) diff --git a/requirements.txt b/requirements.txt index ca4b8b8..09a306d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ httpcore==1.0.7 httptools==0.6.4 httpx==0.28.1 idna==3.10 +itsdangerous==2.2.0 Jinja2==3.1.6 jinja2_fragments==1.8.0 Mako==1.3.9 diff --git a/templates/settings_page/security.html b/templates/settings_page/security.html index 0c265f0..7d31494 100644 --- a/templates/settings_page/security.html +++ b/templates/settings_page/security.html @@ -33,6 +33,9 @@ +