correctly use jwt for state

This commit is contained in:
Markbeep
2025-03-17 08:50:39 +01:00
parent 00238c0ec4
commit 4781960da1

View File

@@ -5,14 +5,13 @@ import time
from typing import Annotated, Optional
from urllib.parse import urlencode, urljoin
import jwt
from aiohttp import ClientSession
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session, select
from app.internal.auth.config import LoginTypeEnum, auth_config
from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config
from app.internal.auth.authentication import (
DetailedUser,
RequiresLoginException,
@@ -20,6 +19,8 @@ from app.internal.auth.authentication import (
create_user,
get_authenticated_user,
)
from app.internal.auth.config import LoginTypeEnum, auth_config
from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config
from app.internal.models import GroupEnum, User
from app.util.connection import get_connection
from app.util.db import get_session
@@ -30,8 +31,6 @@ router = APIRouter(prefix="/auth")
logger = logging.getLogger(__name__)
STATE_RANDOM_BYTES = 16
@router.get("/login")
async def login(
@@ -79,19 +78,18 @@ async def login(
logger.info(f"Redirecting to OIDC login: {authorize_endpoint}")
logger.info(f"Redirect URI: {auth_redirect_uri}")
# Authelia requires the state to be at least 8 chars. We solely use it to keep track of the redirect URI.
# https://github.com/markbeep/AudioBookRequest/issues/62
random_str = base64.encodebytes(secrets.token_bytes(2 * STATE_RANDOM_BYTES)).decode(
"utf-8"
state = jwt.encode(
{"redirect_uri": redirect_uri},
auth_config.get_auth_secret(session),
algorithm="HS256",
)
longer_redirect_uri = redirect_uri + random_str[:STATE_RANDOM_BYTES]
params = {
"response_type": "code",
"client_id": client_id,
"redirect_uri": auth_redirect_uri,
"scope": scope,
"state": longer_redirect_uri,
"state": state,
}
return RedirectResponse(f"{authorize_endpoint}?" + urlencode(params))
@@ -238,7 +236,14 @@ async def login_oidc(
request.session["exp"] = expires
if state:
state = state[:-STATE_RANDOM_BYTES]
decoded = jwt.decode(
state,
auth_config.get_auth_secret(session),
algorithms=["HS256"],
)
redirect_uri = decoded.get("redirect_uri", "/")
else:
redirect_uri = "/"
# We can't redirect server side, because that results in an infinite loop.
# The session token is never correctly set causing any other endpoint to
@@ -250,7 +255,7 @@ async def login_oidc(
{
"request": request,
"hide_navbar": True,
"redirect_uri": state or "/",
"redirect_uri": redirect_uri,
},
)