diff --git a/app/routers/auth.py b/app/routers/auth.py index b446010..1fbc633 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -5,14 +5,13 @@ import time from typing import Annotated, Optional from urllib.parse import urlencode, urljoin +import jwt from aiohttp import ClientSession from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from sqlmodel import Session, select -from app.internal.auth.config import LoginTypeEnum, auth_config -from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config from app.internal.auth.authentication import ( DetailedUser, RequiresLoginException, @@ -20,6 +19,8 @@ from app.internal.auth.authentication import ( create_user, get_authenticated_user, ) +from app.internal.auth.config import LoginTypeEnum, auth_config +from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config from app.internal.models import GroupEnum, User from app.util.connection import get_connection from app.util.db import get_session @@ -30,8 +31,6 @@ router = APIRouter(prefix="/auth") logger = logging.getLogger(__name__) -STATE_RANDOM_BYTES = 16 - @router.get("/login") async def login( @@ -79,19 +78,18 @@ async def login( logger.info(f"Redirecting to OIDC login: {authorize_endpoint}") logger.info(f"Redirect URI: {auth_redirect_uri}") - # Authelia requires the state to be at least 8 chars. We solely use it to keep track of the redirect URI. - # https://github.com/markbeep/AudioBookRequest/issues/62 - random_str = base64.encodebytes(secrets.token_bytes(2 * STATE_RANDOM_BYTES)).decode( - "utf-8" + state = jwt.encode( + {"redirect_uri": redirect_uri}, + auth_config.get_auth_secret(session), + algorithm="HS256", ) - longer_redirect_uri = redirect_uri + random_str[:STATE_RANDOM_BYTES] params = { "response_type": "code", "client_id": client_id, "redirect_uri": auth_redirect_uri, "scope": scope, - "state": longer_redirect_uri, + "state": state, } return RedirectResponse(f"{authorize_endpoint}?" + urlencode(params)) @@ -238,7 +236,14 @@ async def login_oidc( request.session["exp"] = expires if state: - state = state[:-STATE_RANDOM_BYTES] + decoded = jwt.decode( + state, + auth_config.get_auth_secret(session), + algorithms=["HS256"], + ) + redirect_uri = decoded.get("redirect_uri", "/") + else: + redirect_uri = "/" # We can't redirect server side, because that results in an infinite loop. # The session token is never correctly set causing any other endpoint to @@ -250,7 +255,7 @@ async def login_oidc( { "request": request, "hide_navbar": True, - "redirect_uri": state or "/", + "redirect_uri": redirect_uri, }, )