mirror of
https://github.com/markbeep/AudioBookRequest.git
synced 2026-04-28 04:21:30 -05:00
access userinfo endpoint to support oidc without id_tokens
This commit is contained in:
@@ -10,8 +10,6 @@ from app.internal.auth.config import LoginTypeEnum, auth_config
|
||||
from app.internal.models import GroupEnum, User
|
||||
from app.util.db import get_session
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
class DetailedUser(User):
|
||||
login_type: LoginTypeEnum
|
||||
|
||||
@@ -14,6 +14,7 @@ oidcConfigKey = Literal[
|
||||
"oidc_group_claim",
|
||||
"oidc_endpoint",
|
||||
"oidc_token_endpoint",
|
||||
"oidc_userinfo_endpoint",
|
||||
"oidc_authorize_endpoint",
|
||||
]
|
||||
|
||||
@@ -39,6 +40,7 @@ class oidcConfig(StringConfigCache[oidcConfigKey]):
|
||||
session, "oidc_authorize_endpoint", data["authorization_endpoint"]
|
||||
)
|
||||
self.set(session, "oidc_token_endpoint", data["token_endpoint"])
|
||||
self.set(session, "oidc_userinfo_endpoint", data["userinfo_endpoint"])
|
||||
|
||||
async def validate(
|
||||
self, session: Session, client_session: ClientSession
|
||||
|
||||
+13
-19
@@ -7,7 +7,6 @@ from aiohttp import ClientSession
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
import jwt
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.internal.auth.config import LoginTypeEnum, auth_config
|
||||
@@ -128,6 +127,7 @@ async def login_oidc(
|
||||
state: Optional[str] = None,
|
||||
):
|
||||
token_endpoint = oidc_config.get(session, "oidc_token_endpoint")
|
||||
userinfo_endpoint = oidc_config.get(session, "oidc_userinfo_endpoint")
|
||||
client_id = oidc_config.get(session, "oidc_client_id")
|
||||
client_secret = oidc_config.get(session, "oidc_client_secret")
|
||||
username_claim = oidc_config.get(session, "oidc_username_claim")
|
||||
@@ -135,6 +135,8 @@ async def login_oidc(
|
||||
|
||||
if not token_endpoint:
|
||||
raise InvalidOIDCConfiguration("Missing OIDC endpoint")
|
||||
if not userinfo_endpoint:
|
||||
raise InvalidOIDCConfiguration("Missing OIDC userinfo endpoint")
|
||||
if not client_id:
|
||||
raise InvalidOIDCConfiguration("Missing OIDC client ID")
|
||||
if not client_secret:
|
||||
@@ -160,29 +162,21 @@ async def login_oidc(
|
||||
) as response:
|
||||
body = await response.json()
|
||||
|
||||
id_token = body.get("id_token")
|
||||
if not id_token:
|
||||
access_token = body.get("access_token")
|
||||
if not access_token:
|
||||
return Response(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
try:
|
||||
# TODO: Verify signature
|
||||
decoded = jwt.decode( # pyright: ignore[reportUnknownMemberType]
|
||||
id_token,
|
||||
options={"verify_signature": False},
|
||||
require=[
|
||||
username_claim,
|
||||
group_claim,
|
||||
], # TODO: 'require' has no effect if verify_signature is False
|
||||
)
|
||||
except jwt.InvalidTokenError as e:
|
||||
print(f"Invalid id_token: {e}")
|
||||
return Response(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
async with client_session.get(
|
||||
userinfo_endpoint,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
) as response:
|
||||
userinfo = await response.json()
|
||||
|
||||
username = decoded.get(username_claim)
|
||||
username = userinfo.get(username_claim)
|
||||
if not username:
|
||||
raise InvalidOIDCConfiguration("Missing username claim")
|
||||
|
||||
groups: list[str] | str = decoded.get(group_claim, [])
|
||||
groups: list[str] | str = userinfo.get(group_claim, [])
|
||||
if isinstance(groups, str):
|
||||
groups = groups.split(" ")
|
||||
|
||||
@@ -211,7 +205,7 @@ async def login_oidc(
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
request.session["sub"] = decoded[username_claim]
|
||||
request.session["sub"] = username
|
||||
|
||||
# We can't redirect server side, because that results in an infinite loop.
|
||||
# The session token is never correctly set causing any other endpoint to
|
||||
|
||||
@@ -475,7 +475,6 @@ def remove_indexer_flag(
|
||||
DetailedUser, Depends(get_authenticated_user(GroupEnum.admin))
|
||||
],
|
||||
):
|
||||
# TODO: very bad concurrency here
|
||||
flags = quality_config.get_indexer_flags(session)
|
||||
flags = [f for f in flags if f.flag != flag]
|
||||
quality_config.set_indexer_flags(session, flags)
|
||||
|
||||
Reference in New Issue
Block a user