refactor authentication and add session middlewares

This commit is contained in:
Markbeep
2025-03-12 16:32:51 +01:00
parent 524eb71b82
commit 08f7de3be5
13 changed files with 284 additions and 168 deletions
+99 -85
View File
@@ -8,12 +8,14 @@ import jwt
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, OAuth2PasswordBearer
from fastapi.security import HTTPBasic, OAuth2PasswordBearer, OpenIdConnect
from sqlmodel import Session, select
from app.internal.auth.session_middleware import middleware_linker
from app.internal.models import GroupEnum, User
from app.util.cache import StringConfigCache
from app.util.db import get_session
from app.util.time import Minute, Second
JWT_ALGORITHM = "HS256"
@@ -21,6 +23,7 @@ JWT_ALGORITHM = "HS256"
class LoginTypeEnum(str, Enum):
basic = "basic"
forms = "forms"
oidc = "oidc"
none = "none"
def is_basic(self):
@@ -32,6 +35,9 @@ class LoginTypeEnum(str, Enum):
def is_none(self):
return self == LoginTypeEnum.none
def is_oidc(self):
return self == LoginTypeEnum.oidc
AuthConfigKey = Literal[
"login_type",
@@ -53,6 +59,7 @@ class AuthConfig(StringConfigCache[AuthConfigKey]):
def reset_auth_secret(self, session: Session):
auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8")
middleware_linker.update_secret(auth_secret)
self.set(session, "auth_secret", auth_secret)
def get_auth_secret(self, session: Session) -> str:
@@ -63,10 +70,11 @@ class AuthConfig(StringConfigCache[AuthConfigKey]):
self.set(session, "auth_secret", auth_secret)
return auth_secret
def get_access_token_expiry_minutes(self, session: Session):
return self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7)
def get_access_token_expiry_minutes(self, session: Session) -> Minute:
return Minute(self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7))
def set_access_token_expiry_minutes(self, session: Session, expiry: int):
def set_access_token_expiry_minutes(self, session: Session, expiry: Minute):
middleware_linker.update_max_age(Second(expiry * 60))
self.set_int(session, "access_token_expiry_minutes", expiry)
def get_min_password_length(self, session: Session) -> int:
@@ -83,12 +91,6 @@ class DetailedUser(User):
return self.login_type == LoginTypeEnum.forms
security = HTTPBasic()
ph = PasswordHasher()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False)
auth_config = AuthConfig()
def raise_for_invalid_password(
session: Session,
password: str,
@@ -154,91 +156,103 @@ def create_user(
return User(username=username, password=password_hash, group=group, root=root)
def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted):
async def get_user(
request: Request,
session: Annotated[Session, Depends(get_session)],
) -> DetailedUser:
login_type = auth_config.get_login_type(session)
if login_type == LoginTypeEnum.forms:
user = await _get_forms_auth(request, session)
elif login_type == LoginTypeEnum.none:
user = await _get_none_auth(session)
else:
user = await _get_basic_auth(request, session)
if not user.is_above(lowest_allowed_group):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden"
)
user = DetailedUser.model_validate(user, update={"login_type": login_type})
return user
return get_user
async def _get_basic_auth(
request: Request,
session: Session,
) -> User:
invalid_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
credentials = await security(request)
if not credentials:
raise invalid_exception
user = authenticate_user(session, credentials.username, credentials.password)
if not user:
raise invalid_exception
return user
class RequiresLoginException(Exception):
def __init__(self, detail: Optional[str] = None, **kwargs: object):
super().__init__(**kwargs)
self.detail = detail
async def _get_forms_auth(
request: Request,
session: Session,
) -> User:
# Authentication is either through Authorization header or cookie
token = await oauth2_scheme(request)
if not token:
token = request.cookies.get("audio_sess")
if not token:
class ABRAuth:
def __init__(self):
self.oidc_scheme: Optional[OpenIdConnect] = None
self.none_user: Optional[User] = None
def __call__(self, lowest_allowed_group: GroupEnum):
return self._get_authenticated_user(lowest_allowed_group)
def _get_authenticated_user(self, lowest_allowed_group: GroupEnum):
async def get_user(
request: Request,
session: Annotated[Session, Depends(get_session)],
) -> DetailedUser:
login_type = auth_config.get_login_type(session)
if login_type == LoginTypeEnum.forms:
user = await self._get_forms_auth(request, session)
elif login_type == LoginTypeEnum.none:
user = await self._get_none_auth(session)
elif login_type == LoginTypeEnum.oidc:
user = await self._get_oidc_auth(request, session)
else:
user = await self._get_basic_auth(request, session)
if not user.is_above(lowest_allowed_group):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden"
)
user = DetailedUser.model_validate(user, update={"login_type": login_type})
return user
return get_user
async def _get_basic_auth(
self,
request: Request,
session: Session,
) -> User:
invalid_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
credentials = await security(request)
if not credentials:
raise invalid_exception
user = authenticate_user(session, credentials.username, credentials.password)
if not user:
raise invalid_exception
return user
async def _get_forms_auth(
self,
request: Request,
session: Session,
) -> User:
username = request.session.get("sub")
if not username:
raise RequiresLoginException()
try:
payload = jwt.decode( # pyright: ignore[reportUnknownMemberType]
token, auth_config.get_auth_secret(session), algorithms=[JWT_ALGORITHM]
)
except jwt.InvalidTokenError:
raise RequiresLoginException("Token is expired/invalid")
user = session.get(User, username)
if not user:
raise RequiresLoginException("User does not exist")
username = payload.get("sub")
if username is None:
raise RequiresLoginException("Token is invalid")
return user
user = session.get(User, username)
if not user:
raise RequiresLoginException("User does not exist")
# TODO
async def _get_oidc_auth(self, request: Request, session: Session) -> User: ...
return user
async def _get_none_auth(self, session: Session) -> User:
"""Treats every request as being root by returning the first admin user"""
if self.none_user:
return self.none_user
self.none_user = session.exec(
select(User).where(User.group == GroupEnum.admin).limit(1)
).one()
return self.none_user
async def _get_none_auth(session: Session) -> User:
"""Treats every request as being root by returning the first admin user"""
return session.exec(
select(User).where(User.group == GroupEnum.admin).limit(1)
).one()
security = HTTPBasic()
ph = PasswordHasher()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False)
auth_config = AuthConfig()
abr_authentication = ABRAuth()
def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted):
return abr_authentication(lowest_allowed_group)
+67
View File
@@ -0,0 +1,67 @@
from starlette.middleware.sessions import SessionMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send
from app.util.time import Second
class DynamicSessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: str,
linker: "DynamicMiddlewareLinker",
max_age: Second = Second(60 * 60 * 24 * 14),
):
self.app = app
self.secret_key = secret_key
self.expiry = max_age
self.session_middleware = SessionMiddleware(
app,
secret_key,
same_site="strict",
max_age=max_age,
)
linker.add_middleware(self)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.session_middleware(scope, receive, send)
def update_secret(self, secret_key: str):
self.session_middleware = SessionMiddleware(
self.app,
secret_key,
same_site="strict",
max_age=self.expiry,
)
def update_max_age(self, max_age: Second):
self.session_middleware = SessionMiddleware(
self.app,
self.secret_key,
same_site="strict",
max_age=max_age,
)
class DynamicMiddlewareLinker:
"""
Linker is passed in as an argument to the DynamicSessionMiddleware so
wherever FastAPI initializes the middleware, we can update
the options to take effect immediately instead of having to restart the server
"""
middlewares: list[DynamicSessionMiddleware] = []
def add_middleware(self, middleware: DynamicSessionMiddleware):
self.middlewares.append(middleware)
def update_secret(self, secret_key: str):
for middleware in self.middlewares:
middleware.update_secret(secret_key)
def update_max_age(self, expiry: Second):
for middleware in self.middlewares:
middleware.update_max_age(expiry)
middleware_linker = DynamicMiddlewareLinker()
+18 -4
View File
@@ -2,27 +2,41 @@ from typing import Any
from urllib.parse import quote_plus
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import RedirectResponse
from sqlalchemy import func
from sqlmodel import select
from app.internal.auth.auth import RequiresLoginException, auth_config
from app.internal.auth.session_middleware import (
DynamicSessionMiddleware,
middleware_linker,
)
from app.internal.env_settings import Settings
from app.internal.models import User
from app.routers import root, search, settings, wishlist
from app.util.auth import RequiresLoginException
from app.routers import auth, login, root, search, settings, wishlist
from app.util.db import open_session
with open_session() as session:
auth_secret = auth_config.get_auth_secret(session)
app = FastAPI(
title="AudioBookRequest",
debug=Settings().app.debug,
openapi_url="/openapi.json" if Settings().app.openapi_enabled else None,
middleware=[
Middleware(DynamicSessionMiddleware, auth_secret, middleware_linker),
Middleware(GZipMiddleware),
],
)
app.include_router(auth.router)
app.include_router(login.router)
app.include_router(root.router)
app.include_router(search.router)
app.include_router(wishlist.router)
app.include_router(settings.router)
app.include_router(wishlist.router)
user_exists = False
+43
View File
@@ -0,0 +1,43 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from app.internal.auth.auth import (
DetailedUser,
authenticate_user,
get_authenticated_user,
)
from app.util.db import get_session
from app.util.templates import templates
router = APIRouter(prefix="/auth")
@router.post("/logout")
def logout(
request: Request, user: Annotated[DetailedUser, Depends(get_authenticated_user())]
):
request.session["sub"] = ""
return Response(
status_code=status.HTTP_204_NO_CONTENT, headers={"HX-Redirect": "/login"}
)
@router.post("/token")
def login_access_token(
request: Request,
session: Annotated[Session, Depends(get_session)],
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
):
user = authenticate_user(session, form_data.username, form_data.password)
if not user:
return templates.TemplateResponse(
"login.html",
{"request": request, "hide_navbar": True, "error": "Invalid login"},
block_name="error_toast",
)
request.session["sub"] = form_data.username
return Response(status_code=status.HTTP_200_OK, headers={"HX-Redirect": "/"})
+39
View File
@@ -0,0 +1,39 @@
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
from sqlmodel import Session
from app.internal.auth.auth import (
LoginTypeEnum,
RequiresLoginException,
auth_config,
get_authenticated_user,
)
from app.util.db import get_session
from app.util.templates import templates
router = APIRouter(prefix="/login")
@router.get("")
async def login(
request: Request,
session: Annotated[Session, Depends(get_session)],
error: Optional[str] = None,
):
login_type = auth_config.get(session, "login_type")
if login_type != LoginTypeEnum.forms:
return RedirectResponse("/")
try:
await get_authenticated_user()(request, session)
# already logged in
return RedirectResponse("/")
except (HTTPException, RequiresLoginException):
pass
return templates.TemplateResponse(
"login.html",
{"request": request, "hide_navbar": True, "error": error},
)
+4 -74
View File
@@ -1,23 +1,18 @@
from datetime import timedelta
from typing import Annotated, Optional
from typing import Annotated
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from app.internal.models import GroupEnum
from app.util.auth import (
from app.internal.auth.auth import (
DetailedUser,
LoginTypeEnum,
RequiresLoginException,
auth_config,
authenticate_user,
create_access_token,
create_user,
get_authenticated_user,
raise_for_invalid_password,
)
from app.internal.models import GroupEnum
from app.util.db import get_session
from app.util.templates import templates
@@ -117,68 +112,3 @@ def create_init(
session.commit()
return Response(status_code=201, headers={"HX-Redirect": "/"})
@router.get("/login")
async def login(
request: Request,
session: Annotated[Session, Depends(get_session)],
error: Optional[str] = None,
):
login_type = auth_config.get(session, "login_type")
if login_type != LoginTypeEnum.forms:
return RedirectResponse("/")
try:
await get_authenticated_user()(request, session)
# already logged in
return RedirectResponse("/")
except (HTTPException, RequiresLoginException):
pass
return templates.TemplateResponse(
"login.html",
{"request": request, "hide_navbar": True, "error": error},
)
@router.post("/auth/logout")
def logout(user: Annotated[DetailedUser, Depends(get_authenticated_user())]):
return Response(
status_code=status.HTTP_204_NO_CONTENT,
headers={
"Set-Cookie": "audio_sess=; Path=/; SameSite=Strict; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT",
"HX-Redirect": "/login",
},
)
@router.post("/auth/token")
def login_access_token(
request: Request,
session: Annotated[Session, Depends(get_session)],
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
):
user = authenticate_user(session, form_data.username, form_data.password)
if not user:
return templates.TemplateResponse(
"login.html",
{"request": request, "hide_navbar": True, "error": "Invalid login"},
block_name="error_toast",
)
access_token_expires_minues = auth_config.get_access_token_expiry_minutes(session)
access_token_exires = timedelta(minutes=access_token_expires_minues)
access_token = create_access_token(
auth_config.get_auth_secret(session),
{"sub": form_data.username},
access_token_exires,
)
return Response(
status_code=status.HTTP_200_OK,
headers={
"HX-Redirect": "/",
"Set-Cookie": f"audio_sess={access_token}; Path=/; SameSite=Strict; HttpOnly; ",
},
)
+1 -1
View File
@@ -36,7 +36,7 @@ from app.internal.prowlarr.prowlarr import prowlarr_config
from app.internal.query import query_sources
from app.internal.ranking.quality import quality_config
from app.routers.wishlist import get_wishlist_books
from app.util.auth import DetailedUser, get_authenticated_user
from app.internal.auth.auth import DetailedUser, get_authenticated_user
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.templates import template_response
+3 -2
View File
@@ -11,7 +11,7 @@ from app.internal.prowlarr.indexer_categories import indexer_categories
from app.internal.notifications import send_notification
from app.internal.prowlarr.prowlarr import flush_prowlarr_cache, prowlarr_config
from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config
from app.util.auth import (
from app.internal.auth.auth import (
DetailedUser,
LoginTypeEnum,
auth_config,
@@ -23,6 +23,7 @@ from app.util.auth import (
from app.util.connection import get_connection
from app.util.db import get_session
from app.util.templates import template_response
from app.util.time import Minute
router = APIRouter(prefix="/settings")
@@ -690,7 +691,7 @@ def update_security(
old = auth_config.get_login_type(session)
auth_config.set_login_type(session, login_type)
auth_config.set_access_token_expiry_minutes(session, access_token_expiry)
auth_config.set_access_token_expiry_minutes(session, Minute(access_token_expiry))
auth_config.set_min_password_length(session, min_password_length)
return template_response(
"settings_page/security.html",
+1 -1
View File
@@ -28,7 +28,7 @@ from app.internal.prowlarr.prowlarr import (
)
from app.internal.query import query_sources
from app.internal.ranking.quality import quality_config
from app.util.auth import DetailedUser, get_authenticated_user
from app.internal.auth.auth import DetailedUser, get_authenticated_user
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.templates import template_response
+1 -1
View File
@@ -5,7 +5,7 @@ from fastapi import Request, Response
from jinja2_fragments.fastapi import Jinja2Blocks
from starlette.background import BackgroundTask
from app.util.auth import DetailedUser
from app.internal.auth.auth import DetailedUser
templates = Jinja2Blocks(directory="templates")
templates.env.filters["quote_plus"] = lambda u: quote_plus(u) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportUnknownArgumentType]
+4
View File
@@ -0,0 +1,4 @@
from typing import NewType
Second = NewType("Second", int)
Minute = NewType("Minute", int)
+1
View File
@@ -25,6 +25,7 @@ httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
idna==3.10
itsdangerous==2.2.0
Jinja2==3.1.6
jinja2_fragments==1.8.0
Mako==1.3.9
+3
View File
@@ -33,6 +33,9 @@
<option value="forms" {% if login_type.is_forms() %}selected{% endif %}>
Forms Login
</option>
<option value="oidc" {% if login_type.is_oidc() %}selected{% endif %}>
OpenID Connect
</option>
<option value="none" {% if login_type.is_none() %}selected{% endif %}>
None (Insecure)
</option>