diff --git a/app/util/auth.py b/app/internal/auth/auth.py
similarity index 59%
rename from app/util/auth.py
rename to app/internal/auth/auth.py
index dbed534..6b74859 100644
--- a/app/util/auth.py
+++ b/app/internal/auth/auth.py
@@ -8,12 +8,14 @@ import jwt
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from fastapi import Depends, HTTPException, Request, status
-from fastapi.security import HTTPBasic, OAuth2PasswordBearer
+from fastapi.security import HTTPBasic, OAuth2PasswordBearer, OpenIdConnect
from sqlmodel import Session, select
+from app.internal.auth.session_middleware import middleware_linker
from app.internal.models import GroupEnum, User
from app.util.cache import StringConfigCache
from app.util.db import get_session
+from app.util.time import Minute, Second
JWT_ALGORITHM = "HS256"
@@ -21,6 +23,7 @@ JWT_ALGORITHM = "HS256"
class LoginTypeEnum(str, Enum):
basic = "basic"
forms = "forms"
+ oidc = "oidc"
none = "none"
def is_basic(self):
@@ -32,6 +35,9 @@ class LoginTypeEnum(str, Enum):
def is_none(self):
return self == LoginTypeEnum.none
+ def is_oidc(self):
+ return self == LoginTypeEnum.oidc
+
AuthConfigKey = Literal[
"login_type",
@@ -53,6 +59,7 @@ class AuthConfig(StringConfigCache[AuthConfigKey]):
def reset_auth_secret(self, session: Session):
auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8")
+ middleware_linker.update_secret(auth_secret)
self.set(session, "auth_secret", auth_secret)
def get_auth_secret(self, session: Session) -> str:
@@ -63,10 +70,11 @@ class AuthConfig(StringConfigCache[AuthConfigKey]):
self.set(session, "auth_secret", auth_secret)
return auth_secret
- def get_access_token_expiry_minutes(self, session: Session):
- return self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7)
+ def get_access_token_expiry_minutes(self, session: Session) -> Minute:
+ return Minute(self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7))
- def set_access_token_expiry_minutes(self, session: Session, expiry: int):
+ def set_access_token_expiry_minutes(self, session: Session, expiry: Minute):
+ middleware_linker.update_max_age(Second(expiry * 60))
self.set_int(session, "access_token_expiry_minutes", expiry)
def get_min_password_length(self, session: Session) -> int:
@@ -83,12 +91,6 @@ class DetailedUser(User):
return self.login_type == LoginTypeEnum.forms
-security = HTTPBasic()
-ph = PasswordHasher()
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False)
-auth_config = AuthConfig()
-
-
def raise_for_invalid_password(
session: Session,
password: str,
@@ -154,91 +156,103 @@ def create_user(
return User(username=username, password=password_hash, group=group, root=root)
-def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted):
- async def get_user(
- request: Request,
- session: Annotated[Session, Depends(get_session)],
- ) -> DetailedUser:
- login_type = auth_config.get_login_type(session)
-
- if login_type == LoginTypeEnum.forms:
- user = await _get_forms_auth(request, session)
- elif login_type == LoginTypeEnum.none:
- user = await _get_none_auth(session)
- else:
- user = await _get_basic_auth(request, session)
-
- if not user.is_above(lowest_allowed_group):
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden"
- )
-
- user = DetailedUser.model_validate(user, update={"login_type": login_type})
-
- return user
-
- return get_user
-
-
-async def _get_basic_auth(
- request: Request,
- session: Session,
-) -> User:
- invalid_exception = HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Invalid credentials",
- headers={"WWW-Authenticate": "Basic"},
- )
-
- credentials = await security(request)
-
- if not credentials:
- raise invalid_exception
-
- user = authenticate_user(session, credentials.username, credentials.password)
- if not user:
- raise invalid_exception
-
- return user
-
-
class RequiresLoginException(Exception):
def __init__(self, detail: Optional[str] = None, **kwargs: object):
super().__init__(**kwargs)
self.detail = detail
-async def _get_forms_auth(
- request: Request,
- session: Session,
-) -> User:
- # Authentication is either through Authorization header or cookie
- token = await oauth2_scheme(request)
- if not token:
- token = request.cookies.get("audio_sess")
- if not token:
+class ABRAuth:
+ def __init__(self):
+ self.oidc_scheme: Optional[OpenIdConnect] = None
+ self.none_user: Optional[User] = None
+
+ def __call__(self, lowest_allowed_group: GroupEnum):
+ return self._get_authenticated_user(lowest_allowed_group)
+
+ def _get_authenticated_user(self, lowest_allowed_group: GroupEnum):
+ async def get_user(
+ request: Request,
+ session: Annotated[Session, Depends(get_session)],
+ ) -> DetailedUser:
+ login_type = auth_config.get_login_type(session)
+
+ if login_type == LoginTypeEnum.forms:
+ user = await self._get_forms_auth(request, session)
+ elif login_type == LoginTypeEnum.none:
+ user = await self._get_none_auth(session)
+ elif login_type == LoginTypeEnum.oidc:
+ user = await self._get_oidc_auth(request, session)
+ else:
+ user = await self._get_basic_auth(request, session)
+
+ if not user.is_above(lowest_allowed_group):
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden"
+ )
+
+ user = DetailedUser.model_validate(user, update={"login_type": login_type})
+
+ return user
+
+ return get_user
+
+ async def _get_basic_auth(
+ self,
+ request: Request,
+ session: Session,
+ ) -> User:
+ invalid_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid credentials",
+ headers={"WWW-Authenticate": "Basic"},
+ )
+
+ credentials = await security(request)
+
+ if not credentials:
+ raise invalid_exception
+
+ user = authenticate_user(session, credentials.username, credentials.password)
+ if not user:
+ raise invalid_exception
+
+ return user
+
+ async def _get_forms_auth(
+ self,
+ request: Request,
+ session: Session,
+ ) -> User:
+ username = request.session.get("sub")
+ if not username:
raise RequiresLoginException()
- try:
- payload = jwt.decode( # pyright: ignore[reportUnknownMemberType]
- token, auth_config.get_auth_secret(session), algorithms=[JWT_ALGORITHM]
- )
- except jwt.InvalidTokenError:
- raise RequiresLoginException("Token is expired/invalid")
+ user = session.get(User, username)
+ if not user:
+ raise RequiresLoginException("User does not exist")
- username = payload.get("sub")
- if username is None:
- raise RequiresLoginException("Token is invalid")
+ return user
- user = session.get(User, username)
- if not user:
- raise RequiresLoginException("User does not exist")
+ # TODO
+ async def _get_oidc_auth(self, request: Request, session: Session) -> User: ...
- return user
+ async def _get_none_auth(self, session: Session) -> User:
+ """Treats every request as being root by returning the first admin user"""
+ if self.none_user:
+ return self.none_user
+ self.none_user = session.exec(
+ select(User).where(User.group == GroupEnum.admin).limit(1)
+ ).one()
+ return self.none_user
-async def _get_none_auth(session: Session) -> User:
- """Treats every request as being root by returning the first admin user"""
- return session.exec(
- select(User).where(User.group == GroupEnum.admin).limit(1)
- ).one()
+security = HTTPBasic()
+ph = PasswordHasher()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False)
+auth_config = AuthConfig()
+abr_authentication = ABRAuth()
+
+
+def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted):
+ return abr_authentication(lowest_allowed_group)
diff --git a/app/internal/auth/session_middleware.py b/app/internal/auth/session_middleware.py
new file mode 100644
index 0000000..e709686
--- /dev/null
+++ b/app/internal/auth/session_middleware.py
@@ -0,0 +1,67 @@
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+from app.util.time import Second
+
+
+class DynamicSessionMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ secret_key: str,
+ linker: "DynamicMiddlewareLinker",
+ max_age: Second = Second(60 * 60 * 24 * 14),
+ ):
+ self.app = app
+ self.secret_key = secret_key
+ self.expiry = max_age
+ self.session_middleware = SessionMiddleware(
+ app,
+ secret_key,
+ same_site="strict",
+ max_age=max_age,
+ )
+ linker.add_middleware(self)
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ return await self.session_middleware(scope, receive, send)
+
+ def update_secret(self, secret_key: str):
+ self.session_middleware = SessionMiddleware(
+ self.app,
+ secret_key,
+ same_site="strict",
+ max_age=self.expiry,
+ )
+
+ def update_max_age(self, max_age: Second):
+ self.session_middleware = SessionMiddleware(
+ self.app,
+ self.secret_key,
+ same_site="strict",
+ max_age=max_age,
+ )
+
+
+class DynamicMiddlewareLinker:
+ """
+ Linker is passed in as an argument to the DynamicSessionMiddleware so
+ wherever FastAPI initializes the middleware, we can update
+ the options to take effect immediately instead of having to restart the server
+ """
+
+ middlewares: list[DynamicSessionMiddleware] = []
+
+ def add_middleware(self, middleware: DynamicSessionMiddleware):
+ self.middlewares.append(middleware)
+
+ def update_secret(self, secret_key: str):
+ for middleware in self.middlewares:
+ middleware.update_secret(secret_key)
+
+ def update_max_age(self, expiry: Second):
+ for middleware in self.middlewares:
+ middleware.update_max_age(expiry)
+
+
+middleware_linker = DynamicMiddlewareLinker()
diff --git a/app/main.py b/app/main.py
index 236bdc6..6b73e04 100644
--- a/app/main.py
+++ b/app/main.py
@@ -2,27 +2,41 @@ from typing import Any
from urllib.parse import quote_plus
from fastapi import FastAPI, HTTPException, Request, status
+from fastapi.middleware import Middleware
+from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import RedirectResponse
from sqlalchemy import func
from sqlmodel import select
+from app.internal.auth.auth import RequiresLoginException, auth_config
+from app.internal.auth.session_middleware import (
+ DynamicSessionMiddleware,
+ middleware_linker,
+)
from app.internal.env_settings import Settings
from app.internal.models import User
-from app.routers import root, search, settings, wishlist
-from app.util.auth import RequiresLoginException
+from app.routers import auth, login, root, search, settings, wishlist
from app.util.db import open_session
+with open_session() as session:
+ auth_secret = auth_config.get_auth_secret(session)
+
app = FastAPI(
title="AudioBookRequest",
debug=Settings().app.debug,
openapi_url="/openapi.json" if Settings().app.openapi_enabled else None,
+ middleware=[
+ Middleware(DynamicSessionMiddleware, auth_secret, middleware_linker),
+ Middleware(GZipMiddleware),
+ ],
)
-
+app.include_router(auth.router)
+app.include_router(login.router)
app.include_router(root.router)
app.include_router(search.router)
-app.include_router(wishlist.router)
app.include_router(settings.router)
+app.include_router(wishlist.router)
user_exists = False
diff --git a/app/routers/auth.py b/app/routers/auth.py
new file mode 100644
index 0000000..c358ceb
--- /dev/null
+++ b/app/routers/auth.py
@@ -0,0 +1,43 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Request, Response, status
+from fastapi.security import OAuth2PasswordRequestForm
+from sqlmodel import Session
+
+from app.internal.auth.auth import (
+ DetailedUser,
+ authenticate_user,
+ get_authenticated_user,
+)
+from app.util.db import get_session
+from app.util.templates import templates
+
+router = APIRouter(prefix="/auth")
+
+
+@router.post("/logout")
+def logout(
+ request: Request, user: Annotated[DetailedUser, Depends(get_authenticated_user())]
+):
+ request.session["sub"] = ""
+ return Response(
+ status_code=status.HTTP_204_NO_CONTENT, headers={"HX-Redirect": "/login"}
+ )
+
+
+@router.post("/token")
+def login_access_token(
+ request: Request,
+ session: Annotated[Session, Depends(get_session)],
+ form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
+):
+ user = authenticate_user(session, form_data.username, form_data.password)
+ if not user:
+ return templates.TemplateResponse(
+ "login.html",
+ {"request": request, "hide_navbar": True, "error": "Invalid login"},
+ block_name="error_toast",
+ )
+
+ request.session["sub"] = form_data.username
+ return Response(status_code=status.HTTP_200_OK, headers={"HX-Redirect": "/"})
diff --git a/app/routers/login.py b/app/routers/login.py
new file mode 100644
index 0000000..2253efc
--- /dev/null
+++ b/app/routers/login.py
@@ -0,0 +1,39 @@
+from typing import Annotated, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import RedirectResponse
+from sqlmodel import Session
+
+from app.internal.auth.auth import (
+ LoginTypeEnum,
+ RequiresLoginException,
+ auth_config,
+ get_authenticated_user,
+)
+from app.util.db import get_session
+from app.util.templates import templates
+
+router = APIRouter(prefix="/login")
+
+
+@router.get("")
+async def login(
+ request: Request,
+ session: Annotated[Session, Depends(get_session)],
+ error: Optional[str] = None,
+):
+ login_type = auth_config.get(session, "login_type")
+ if login_type != LoginTypeEnum.forms:
+ return RedirectResponse("/")
+
+ try:
+ await get_authenticated_user()(request, session)
+ # already logged in
+ return RedirectResponse("/")
+ except (HTTPException, RequiresLoginException):
+ pass
+
+ return templates.TemplateResponse(
+ "login.html",
+ {"request": request, "hide_navbar": True, "error": error},
+ )
diff --git a/app/routers/root.py b/app/routers/root.py
index e5a5475..774930b 100644
--- a/app/routers/root.py
+++ b/app/routers/root.py
@@ -1,23 +1,18 @@
-from datetime import timedelta
-from typing import Annotated, Optional
+from typing import Annotated
-from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status
+from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
from fastapi.responses import FileResponse, RedirectResponse
-from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
-from app.internal.models import GroupEnum
-from app.util.auth import (
+from app.internal.auth.auth import (
DetailedUser,
LoginTypeEnum,
- RequiresLoginException,
auth_config,
- authenticate_user,
- create_access_token,
create_user,
get_authenticated_user,
raise_for_invalid_password,
)
+from app.internal.models import GroupEnum
from app.util.db import get_session
from app.util.templates import templates
@@ -117,68 +112,3 @@ def create_init(
session.commit()
return Response(status_code=201, headers={"HX-Redirect": "/"})
-
-
-@router.get("/login")
-async def login(
- request: Request,
- session: Annotated[Session, Depends(get_session)],
- error: Optional[str] = None,
-):
- login_type = auth_config.get(session, "login_type")
- if login_type != LoginTypeEnum.forms:
- return RedirectResponse("/")
-
- try:
- await get_authenticated_user()(request, session)
- # already logged in
- return RedirectResponse("/")
- except (HTTPException, RequiresLoginException):
- pass
-
- return templates.TemplateResponse(
- "login.html",
- {"request": request, "hide_navbar": True, "error": error},
- )
-
-
-@router.post("/auth/logout")
-def logout(user: Annotated[DetailedUser, Depends(get_authenticated_user())]):
- return Response(
- status_code=status.HTTP_204_NO_CONTENT,
- headers={
- "Set-Cookie": "audio_sess=; Path=/; SameSite=Strict; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT",
- "HX-Redirect": "/login",
- },
- )
-
-
-@router.post("/auth/token")
-def login_access_token(
- request: Request,
- session: Annotated[Session, Depends(get_session)],
- form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
-):
- user = authenticate_user(session, form_data.username, form_data.password)
- if not user:
- return templates.TemplateResponse(
- "login.html",
- {"request": request, "hide_navbar": True, "error": "Invalid login"},
- block_name="error_toast",
- )
-
- access_token_expires_minues = auth_config.get_access_token_expiry_minutes(session)
- access_token_exires = timedelta(minutes=access_token_expires_minues)
- access_token = create_access_token(
- auth_config.get_auth_secret(session),
- {"sub": form_data.username},
- access_token_exires,
- )
-
- return Response(
- status_code=status.HTTP_200_OK,
- headers={
- "HX-Redirect": "/",
- "Set-Cookie": f"audio_sess={access_token}; Path=/; SameSite=Strict; HttpOnly; ",
- },
- )
diff --git a/app/routers/search.py b/app/routers/search.py
index da309db..639a253 100644
--- a/app/routers/search.py
+++ b/app/routers/search.py
@@ -36,7 +36,7 @@ from app.internal.prowlarr.prowlarr import prowlarr_config
from app.internal.query import query_sources
from app.internal.ranking.quality import quality_config
from app.routers.wishlist import get_wishlist_books
-from app.util.auth import DetailedUser, get_authenticated_user
+from app.internal.auth.auth import DetailedUser, get_authenticated_user
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.templates import template_response
diff --git a/app/routers/settings.py b/app/routers/settings.py
index f1a9e82..8eba821 100644
--- a/app/routers/settings.py
+++ b/app/routers/settings.py
@@ -11,7 +11,7 @@ from app.internal.prowlarr.indexer_categories import indexer_categories
from app.internal.notifications import send_notification
from app.internal.prowlarr.prowlarr import flush_prowlarr_cache, prowlarr_config
from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config
-from app.util.auth import (
+from app.internal.auth.auth import (
DetailedUser,
LoginTypeEnum,
auth_config,
@@ -23,6 +23,7 @@ from app.util.auth import (
from app.util.connection import get_connection
from app.util.db import get_session
from app.util.templates import template_response
+from app.util.time import Minute
router = APIRouter(prefix="/settings")
@@ -690,7 +691,7 @@ def update_security(
old = auth_config.get_login_type(session)
auth_config.set_login_type(session, login_type)
- auth_config.set_access_token_expiry_minutes(session, access_token_expiry)
+ auth_config.set_access_token_expiry_minutes(session, Minute(access_token_expiry))
auth_config.set_min_password_length(session, min_password_length)
return template_response(
"settings_page/security.html",
diff --git a/app/routers/wishlist.py b/app/routers/wishlist.py
index d060396..0c32968 100644
--- a/app/routers/wishlist.py
+++ b/app/routers/wishlist.py
@@ -28,7 +28,7 @@ from app.internal.prowlarr.prowlarr import (
)
from app.internal.query import query_sources
from app.internal.ranking.quality import quality_config
-from app.util.auth import DetailedUser, get_authenticated_user
+from app.internal.auth.auth import DetailedUser, get_authenticated_user
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.templates import template_response
diff --git a/app/util/templates.py b/app/util/templates.py
index 6fbe467..17fb47b 100644
--- a/app/util/templates.py
+++ b/app/util/templates.py
@@ -5,7 +5,7 @@ from fastapi import Request, Response
from jinja2_fragments.fastapi import Jinja2Blocks
from starlette.background import BackgroundTask
-from app.util.auth import DetailedUser
+from app.internal.auth.auth import DetailedUser
templates = Jinja2Blocks(directory="templates")
templates.env.filters["quote_plus"] = lambda u: quote_plus(u) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportUnknownArgumentType]
diff --git a/app/util/time.py b/app/util/time.py
new file mode 100644
index 0000000..4b63012
--- /dev/null
+++ b/app/util/time.py
@@ -0,0 +1,4 @@
+from typing import NewType
+
+Second = NewType("Second", int)
+Minute = NewType("Minute", int)
diff --git a/requirements.txt b/requirements.txt
index ca4b8b8..09a306d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,6 +25,7 @@ httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
idna==3.10
+itsdangerous==2.2.0
Jinja2==3.1.6
jinja2_fragments==1.8.0
Mako==1.3.9
diff --git a/templates/settings_page/security.html b/templates/settings_page/security.html
index 0c265f0..7d31494 100644
--- a/templates/settings_page/security.html
+++ b/templates/settings_page/security.html
@@ -33,6 +33,9 @@
+