diff --git a/.env.local b/.env.local index 86d17bd..e2fcf29 100644 --- a/.env.local +++ b/.env.local @@ -1,3 +1,4 @@ ABR_APP__CONFIG_DIR=config # Path to the config directory. Default: /config ABR_APP__DEBUG=true # Default: false ABR_APP__OPENAPI_ENABLED=true # Default: false +ABR_APP__LOG_LEVEL=DEBUG diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2e7d43b..73827aa 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -70,13 +70,18 @@ jobs: MINOR=$(echo $VERSION | cut -d. -f2) PATCH=$(echo $VERSION | cut -d. -f3) echo "::set-output name=tags::${{ secrets.DOCKER_HUB_USERNAME }}/audiobookrequest:$VERSION,${{ secrets.DOCKER_HUB_USERNAME }}/audiobookrequest:$MAJOR.$MINOR,${{ secrets.DOCKER_HUB_USERNAME }}/audiobookrequest:$MAJOR,${{ secrets.DOCKER_HUB_USERNAME }}/audiobookrequest:latest" + echo "::set-output name=version::$VERSION" else echo "::set-output name=tags::${{ secrets.DOCKER_HUB_USERNAME }}/audiobookrequest:nightly" + github_sha_hash=${{ github.sha }} + echo "::set-output name=version::nightly:${github_sha_hash:0:7}" fi - name: Build and push uses: docker/build-push-action@v5 with: - platforms: linux/amd64 + platforms: linux/amd64,linux/arm64 push: true tags: ${{ steps.vars.outputs.tags }} + build-args: | + VERSION=${{ steps.vars.outputs.version }} diff --git a/Dockerfile b/Dockerfile index 86347d9..fbfc7b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,19 +1,24 @@ -# Install daisyui FROM node:23-alpine3.20 WORKDIR /app +# Install daisyui COPY package.json package.json COPY package-lock.json package-lock.json RUN npm install # Setup python -FROM python:3.11-alpine - +FROM python:3.11-alpine AS linux-amd64 WORKDIR /app - RUN apk add --no-cache curl gcompat build-base RUN curl https://github.com/tailwindlabs/tailwindcss/releases/download/v4.0.6/tailwindcss-linux-x64-musl -L -o /bin/tailwindcss + +FROM python:3.11-alpine AS linux-arm64 +WORKDIR /app +RUN apk add --no-cache curl gcompat build-base +RUN curl https://github.com/tailwindlabs/tailwindcss/releases/download/v4.0.6/tailwindcss-linux-arm64-musl -L -o /bin/tailwindcss + +FROM ${TARGETOS}-${TARGETARCH}${TARGETVARIANT} RUN chmod +x /bin/tailwindcss COPY requirements.txt requirements.txt @@ -32,6 +37,8 @@ COPY app/ app/ RUN /bin/tailwindcss -i styles/globals.css -o static/globals.css -m ENV ABR_APP__PORT=8000 +ARG VERSION +ENV ABR_APP__VERSION=$VERSION CMD alembic upgrade heads && fastapi run --port $ABR_APP__PORT diff --git a/README.md b/README.md index b4989f0..dc0b57f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +![GitHub Release](https://img.shields.io/github/v/release/markbeep/AudioBookRequest) + +[![Discord](https://dcbadge.limes.pink/api/server/https://discord.gg/SsFRXWMg7s)](https://discord.gg/SsFRXWMg7s) + ![Header](/media/AudioBookRequestIcon.png) Your tool for handling audiobook requests on a Plex/Audiobookshelf/Jellyfin instance. @@ -13,12 +17,15 @@ If you've heard of Overseer, Ombi, or Jellyseer; this is in the similar vein, Notifications` and add the URL. 5. Configure the remaining settings. **The event variables are case sensitive**. +### OpenID Connect + +OIDC allows you to use an external authentication service (Authentik, Keycloak, etc.) for user and group authentication. It can be configured in `Settings>Security`. The following six settings are required to successfully set up oidc. Ensure you use the correct values. Incorrect values or changing values on your authentication server in the future can cause lead to locking you out of the service. In those cases head to [`Getting "locked" out`](#getting-locked-out). + +- `well-known` configuration endpoint: This is located at `/realms/{realm-name}/.well-known/openid-configuration` for keycloak or `/application/o/{issuer}/.well-known/openid-configuration` for authentik. +- username claim: The claim that should be used for usernames. The username has to be unique. **NOTE:** Any user logging in with the username of the root admin account will be root admin, no matter what group they're assigned. +- group claim: This is the claim that contains the group of each user. It should either be a string or a list of strings with one of the following case-insensitive values: `untrusted`, `trusted`, or `admin`. Any user without any groups is assigned the `untrusted` role. +- scope: The scopes required to get all the necessary information. The scope `openid` is almost **always** required. You need to add all required scopes to that the username and group claim is available. +- client id +- client secret + +In your auth server settings, make sure you allow for redirecting to `/auth/oidc`. The oidc-login flow will redirect you there after you log in. Additionally, the access token expiry time from the authentication server will be used if provided. This might be fairly low by default. + +Applying settings does not directly invalidate your current session. To test OIDC-settings, press the "log out" button to invalidate your current session. + +#### Getting locked out + +In the case of an OIDC misconfiguration, i.e. changing a setting like your client secret on your auth server, can cause you to be locked out. In these cases, you can head to `/login?backup=1`, where you can log in using your root admin credentials allowing you to correctly configure any settings. + ## Alternative Deployments Docker image is located on [dockerhub](https://hub.docker.com/r/markbeep/audiobookrequest). @@ -78,9 +104,7 @@ services: web: image: markbeep/audiobookrequest:1 ports: - - "8000:8765" - environment: - ABR_APP__PORT: 8765 + - "8000:8000" volumes: - ./config:/config ``` @@ -111,12 +135,9 @@ spec: volumeMounts: - mountPath: /config name: abr-config - env: - - name: ABR_APP__PORT - value: "8765" ports: - name: http-request - containerPort: 8765 + containerPort: 8000 volumes: - name: abr-config hostPath: @@ -131,6 +152,7 @@ spec: | `ABR_APP__DEBUG` | If to enable debug mode. Not recommended for production. | false | | `ABR_APP__OPENAPI_ENABLED` | If set to `true`, enables an OpenAPI specs page on `/docs`. | false | | `ABR_APP__CONFIG_DIR` | The directory path where persistant data and configuration is stored. If ran using Docker or Kubernetes, this is the location a volume should be mounted to. | /config | +| `ABR_APP__LOG_LEVEL` | One of `DEBUG`, `INFO`, `WARN`, `ERROR`. | INFO | | `ABR_DB__SQLITE_PATH` | If relative, path and name of the sqlite database in relation to `ABR_APP__CONFIG_DIR`. If absolute (path starts with `/`), the config dir is ignored and only the absolute path is used. | db.sqlite | --- @@ -194,3 +216,11 @@ browser-sync http://localhost:8000 --files templates/** --files app/** ``` **NOTE**: Website has to be visited at http://localhost:3000 instead. + +## Docker Compose + +The docker compose can also be used to run the app locally: + +```bash +docker compose up --build +``` diff --git a/alembic/versions/873737d287d3_manual_requests_downloaded_flag.py b/alembic/versions/873737d287d3_manual_requests_downloaded_flag.py new file mode 100644 index 0000000..0a5a589 --- /dev/null +++ b/alembic/versions/873737d287d3_manual_requests_downloaded_flag.py @@ -0,0 +1,42 @@ +"""manual requests downloaded flag + +Revision ID: 873737d287d3 +Revises: 76d7ccb8a116 +Create Date: 2025-03-16 09:39:19.684439 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "873737d287d3" +down_revision: Union[str, None] = "76d7ccb8a116" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("manualbookrequest", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "downloaded", + sa.Boolean(), + nullable=False, + server_default="false", + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("manualbookrequest", schema=None) as batch_op: + batch_op.drop_column("downloaded") + + # ### end Alembic commands ### diff --git a/app/internal/auth/authentication.py b/app/internal/auth/authentication.py new file mode 100644 index 0000000..f339081 --- /dev/null +++ b/app/internal/auth/authentication.py @@ -0,0 +1,179 @@ +from math import inf +import time +from typing import Annotated, Optional + +from argon2 import PasswordHasher +from argon2.exceptions import VerifyMismatchError +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPBasic, OAuth2PasswordBearer, OpenIdConnect +from sqlmodel import Session, select + +from app.internal.auth.config import LoginTypeEnum, auth_config +from app.internal.models import GroupEnum, User +from app.util.db import get_session + + +class DetailedUser(User): + login_type: LoginTypeEnum + + def can_logout(self): + return self.login_type in [LoginTypeEnum.forms, LoginTypeEnum.oidc] + + +def raise_for_invalid_password( + session: Session, + password: str, + confirm_password: Optional[str] = None, + ignore_confirm: bool = False, +): + if not ignore_confirm and password != confirm_password: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Passwords must be equal", + ) + + min_password_length = auth_config.get_min_password_length(session) + if not len(password) >= min_password_length: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Password must be at least {min_password_length} characters long", + ) + + +def is_correct_password(user: User, password: str) -> bool: + try: + return ph.verify(user.password, password) + except VerifyMismatchError: + return False + + +def authenticate_user(session: Session, username: str, password: str) -> Optional[User]: + user = session.get(User, username) + if not user: + return None + + try: + ph.verify(user.password, password) + except VerifyMismatchError: + return None + + if ph.check_needs_rehash(user.password): + user.password = ph.hash(password) + session.add(user) + session.commit() + + return user + + +def create_user( + username: str, + password: str, + group: GroupEnum = GroupEnum.untrusted, + root: bool = False, +) -> User: + password_hash = ph.hash(password) + return User(username=username, password=password_hash, group=group, root=root) + + +class RequiresLoginException(Exception): + def __init__(self, detail: Optional[str] = None, **kwargs: object): + super().__init__(**kwargs) + self.detail = detail + + +class ABRAuth: + def __init__(self): + self.oidc_scheme: Optional[OpenIdConnect] = None + self.none_user: Optional[User] = None + + def get_authenticated_user(self, lowest_allowed_group: GroupEnum): + async def get_user( + request: Request, + session: Annotated[Session, Depends(get_session)], + ) -> DetailedUser: + login_type = auth_config.get_login_type(session) + + if login_type == LoginTypeEnum.forms: + user = await self._get_session_auth(request, session) + elif login_type == LoginTypeEnum.none: + user = await self._get_none_auth(session) + elif login_type == LoginTypeEnum.oidc: + user = await self._get_oidc_auth(request, session) + else: + user = await self._get_basic_auth(request, session) + + if not user.is_above(lowest_allowed_group): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden" + ) + + user = DetailedUser.model_validate(user, update={"login_type": login_type}) + + return user + + return get_user + + async def _get_basic_auth( + self, + request: Request, + session: Session, + ) -> User: + invalid_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", + headers={"WWW-Authenticate": "Basic"}, + ) + + credentials = await security(request) + if not credentials: + raise invalid_exception + + user = authenticate_user(session, credentials.username, credentials.password) + if not user: + raise invalid_exception + + return user + + async def _get_session_auth( + self, + request: Request, + session: Session, + ) -> User: + # It's enough to get the username from the signed session cookie + username = request.session.get("sub") + if not username: + raise RequiresLoginException() + + user = session.get(User, username) + if not user: + raise RequiresLoginException("User does not exist") + + return user + + async def _get_oidc_auth( + self, + request: Request, + session: Session, + ) -> User: + if request.session.get("exp", inf) < time.time(): + raise RequiresLoginException() + return await self._get_session_auth(request, session) + + async def _get_none_auth(self, session: Session) -> User: + """Treats every request as being root by returning the first admin user""" + if self.none_user: + return self.none_user + self.none_user = session.exec( + select(User).where(User.group == GroupEnum.admin).limit(1) + ).one() + return self.none_user + + +security = HTTPBasic() +ph = PasswordHasher() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) +abr_authentication = ABRAuth() + + +def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted): + return abr_authentication.get_authenticated_user(lowest_allowed_group) diff --git a/app/internal/auth/config.py b/app/internal/auth/config.py new file mode 100644 index 0000000..96e77c7 --- /dev/null +++ b/app/internal/auth/config.py @@ -0,0 +1,77 @@ +import base64 +import secrets +from enum import Enum +from typing import Literal + +from sqlmodel import Session + +from app.internal.auth.session_middleware import middleware_linker +from app.util.cache import StringConfigCache +from app.util.time import Minute, Second + + +class LoginTypeEnum(str, Enum): + basic = "basic" + forms = "forms" + oidc = "oidc" + none = "none" + + def is_basic(self): + return self == LoginTypeEnum.basic + + def is_forms(self): + return self == LoginTypeEnum.forms + + def is_none(self): + return self == LoginTypeEnum.none + + def is_oidc(self): + return self == LoginTypeEnum.oidc + + +AuthConfigKey = Literal[ + "login_type", + "access_token_expiry_minutes", + "auth_secret", + "min_password_length", +] + + +class AuthConfig(StringConfigCache[AuthConfigKey]): + def get_login_type(self, session: Session) -> LoginTypeEnum: + login_type = self.get(session, "login_type") + if login_type: + return LoginTypeEnum(login_type) + return LoginTypeEnum.basic + + def set_login_type(self, session: Session, login_Type: LoginTypeEnum): + self.set(session, "login_type", login_Type.value) + + def reset_auth_secret(self, session: Session): + auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8") + middleware_linker.update_secret(auth_secret) + self.set(session, "auth_secret", auth_secret) + + def get_auth_secret(self, session: Session) -> str: + auth_secret = self.get(session, "auth_secret") + if auth_secret: + return auth_secret + auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8") + self.set(session, "auth_secret", auth_secret) + return auth_secret + + def get_access_token_expiry_minutes(self, session: Session) -> Minute: + return Minute(self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7)) + + def set_access_token_expiry_minutes(self, session: Session, expiry: Minute): + middleware_linker.update_max_age(Second(expiry * 60)) + self.set_int(session, "access_token_expiry_minutes", expiry) + + def get_min_password_length(self, session: Session) -> int: + return self.get_int(session, "min_password_length", 1) + + def set_min_password_length(self, session: Session, min_password_length: int): + self.set_int(session, "min_password_length", min_password_length) + + +auth_config = AuthConfig() diff --git a/app/internal/auth/oidc_config.py b/app/internal/auth/oidc_config.py new file mode 100644 index 0000000..a5d847e --- /dev/null +++ b/app/internal/auth/oidc_config.py @@ -0,0 +1,90 @@ +from typing import Literal, Optional + +from aiohttp import ClientSession +from sqlmodel import Session + +from app.util.cache import StringConfigCache + + +oidcConfigKey = Literal[ + "oidc_client_id", + "oidc_client_secret", + "oidc_scope", + "oidc_username_claim", + "oidc_group_claim", + "oidc_endpoint", + "oidc_token_endpoint", + "oidc_userinfo_endpoint", + "oidc_authorize_endpoint", + "oidc_redirect_https", + "oidc_logout_url", +] + + +class InvalidOIDCConfiguration(Exception): + def __init__(self, detail: Optional[str] = None, **kwargs: object): + super().__init__(**kwargs) + self.detail = detail + + +class oidcConfig(StringConfigCache[oidcConfigKey]): + async def set_endpoint( + self, + session: Session, + client_session: ClientSession, + endpoint: str, + ): + self.set(session, "oidc_endpoint", endpoint) + async with client_session.get(endpoint) as response: + if response.status == 200: + data = await response.json() + self.set( + session, "oidc_authorize_endpoint", data["authorization_endpoint"] + ) + self.set(session, "oidc_token_endpoint", data["token_endpoint"]) + self.set(session, "oidc_userinfo_endpoint", data["userinfo_endpoint"]) + if "end_session_endpoint" in data and not self.get( + session, "oidc_logout_url" + ): + self.set(session, "oidc_logout_url", data["end_session_endpoint"]) + + def get_redirect_https(self, session: Session) -> bool: + if self.get(session, "oidc_redirect_https"): + return True + return False + + async def validate( + self, session: Session, client_session: ClientSession + ) -> Optional[str]: + """ + Returns None if valid, the error message otherwise + """ + endpoint = self.get(session, "oidc_endpoint") + if not endpoint: + return "Missing OIDC endpoint" + async with client_session.get(endpoint) as response: + if not response.ok: + return "Failed to fetch OIDC configuration" + data = await response.json() + + config_scope = self.get(session, "oidc_scope", "").split(" ") + provider_scope = data.get("scopes_supported") + if not provider_scope or not all( + scope in provider_scope for scope in config_scope + ): + return "Scopes are not all supported by the provider" + + provider_claims = data.get("claims_supported") + if not provider_claims: + return "Provider does not support or list claims" + + username_claim = self.get(session, "oidc_username_claim") + if not username_claim or username_claim not in provider_claims: + return "Username claim is not supported by the provider" + + group_claim = self.get(session, "oidc_group_claim") + if group_claim and group_claim not in provider_claims: + return "Group claim is not supported by the provider" + + +oidc_config = oidcConfig() diff --git a/app/internal/auth/session_middleware.py b/app/internal/auth/session_middleware.py new file mode 100644 index 0000000..dd54d09 --- /dev/null +++ b/app/internal/auth/session_middleware.py @@ -0,0 +1,73 @@ +from starlette.middleware.sessions import SessionMiddleware +from starlette.types import ASGIApp, Receive, Scope, Send + +from app.util.time import Second + + +class DynamicSessionMiddleware: + """ + A wrapper around the Starlette SessionMiddleware with the ability to + change options during run-time + https://www.starlette.io/middleware/#sessionmiddleware + """ + + def __init__( + self, + app: ASGIApp, + secret_key: str, + linker: "DynamicMiddlewareLinker", + max_age: Second = Second(60 * 60 * 24 * 14), + ): + self.app = app + self.secret_key = secret_key + self.expiry = max_age + self.session_middleware = SessionMiddleware( + app, + secret_key, + same_site="strict", + max_age=max_age, + ) + linker.add_middleware(self) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return await self.session_middleware(scope, receive, send) + + def update_secret(self, secret_key: str): + self.session_middleware = SessionMiddleware( + self.app, + secret_key, + same_site="strict", + max_age=self.expiry, + ) + + def update_max_age(self, max_age: Second): + self.session_middleware = SessionMiddleware( + self.app, + self.secret_key, + same_site="strict", + max_age=max_age, + ) + + +class DynamicMiddlewareLinker: + """ + Linker is passed in as an argument to the DynamicSessionMiddleware so + wherever FastAPI initializes the middleware, we can update + the options to take effect immediately instead of having to restart the server + """ + + middlewares: list[DynamicSessionMiddleware] = [] + + def add_middleware(self, middleware: DynamicSessionMiddleware): + self.middlewares.append(middleware) + + def update_secret(self, secret_key: str): + for middleware in self.middlewares: + middleware.update_secret(secret_key) + + def update_max_age(self, expiry: Second): + for middleware in self.middlewares: + middleware.update_max_age(expiry) + + +middleware_linker = DynamicMiddlewareLinker() diff --git a/app/internal/env_settings.py b/app/internal/env_settings.py index 8ddfd6c..0c6b8a9 100644 --- a/app/internal/env_settings.py +++ b/app/internal/env_settings.py @@ -13,6 +13,8 @@ class ApplicationSettings(BaseModel): openapi_enabled: bool = False config_dir: str = "/config" port: int = 8000 + version: str = "local" + log_level: str = "INFO" class Settings(BaseSettings): @@ -20,7 +22,7 @@ class Settings(BaseSettings): env_prefix="ABR_", env_nested_delimiter="__", nested_model_default_partial_update=True, - env_file=".env.local", + env_file=(".env.local", ".env"), ) db: DBSettings = DBSettings() diff --git a/app/internal/models.py b/app/internal/models.py index d7bde62..4e2044e 100644 --- a/app/internal/models.py +++ b/app/internal/models.py @@ -25,6 +25,12 @@ class User(BaseModel, table=True): sa_column_kwargs={"server_default": "untrusted"}, ) root: bool = False + + # TODO: Add last_login + # last_login: datetime = Field( + # default_factory=datetime.now, sa_column_kwargs={"server_default": "now()"} + # ) + """ untrusted: Requests need to be manually reviewed trusted: Requests are automatically downloaded if possible @@ -72,9 +78,13 @@ class BookSearchResult(BaseBook): class BookWishlistResult(BaseBook): - amount_requested: int = 0 + requested_by: list[str] = [] download_error: Optional[str] = None + @property + def amount_requested(self): + return len(self.requested_by) + class BookRequest(BaseBook, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) @@ -117,6 +127,7 @@ class ManualBookRequest(BaseModel, table=True): nullable=False, ), ) + downloaded: bool = False class Config: # pyright: ignore[reportIncompatibleVariableOverride] arbitrary_types_allowed = True @@ -131,7 +142,7 @@ class BaseSource(BaseModel): narrators: list[str] = Field(default_factory=list, sa_column=Column(JSON)) size: int # in bytes publish_date: datetime - info_url: str + info_url: Optional[str] indexer_flags: list[str] download_url: Optional[str] = None magnet_url: Optional[str] = None @@ -169,14 +180,6 @@ class Config(BaseModel, table=True): value: str -# TODO: add logs -class Log(BaseModel): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - user_username: str - message: str - timestamp: datetime = Field(default_factory=datetime.now) - - class EventEnum(str, Enum): on_new_request = "onNewRequest" on_successful_download = "onSuccessfulDownload" diff --git a/app/internal/notifications.py b/app/internal/notifications.py index b2bdc71..781b129 100644 --- a/app/internal/notifications.py +++ b/app/internal/notifications.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from aiohttp import ClientSession @@ -6,6 +7,8 @@ from sqlmodel import Session, select from app.internal.models import BookRequest, EventEnum, ManualBookRequest, Notification from app.util.db import open_session +logger = logging.getLogger(__name__) + def replace_variables( title_template: str, @@ -137,5 +140,5 @@ async def send_manual_notification( response.raise_for_status() return await response.json() except Exception as e: - print("Failed to send notification", e) + logger.error("Failed to send notification", e) return None diff --git a/app/internal/prowlarr/prowlarr.py b/app/internal/prowlarr/prowlarr.py index a9fe05f..706dc11 100644 --- a/app/internal/prowlarr/prowlarr.py +++ b/app/internal/prowlarr/prowlarr.py @@ -1,8 +1,9 @@ import json import logging from datetime import datetime +import posixpath from typing import Any, Literal, Optional -from urllib.parse import urlencode, urljoin +from urllib.parse import urlencode from aiohttp import ClientResponse, ClientSession from sqlmodel import Session @@ -96,7 +97,7 @@ async def start_download( api_key = prowlarr_config.get_api_key(session) assert base_url is not None and api_key is not None - url = urljoin(base_url, "/api/v1/search") + url = posixpath.join(base_url, "api/v1/search") logger.debug("Starting download for %s", guid) async with client_session.post( url, @@ -104,7 +105,6 @@ async def start_download( headers={"X-Api-Key": api_key}, ) as response: if not response.ok: - print(response) logger.error("Failed to start download for %s: %s", guid, response) await send_all_notifications( EventEnum.on_failed_download, @@ -157,7 +157,7 @@ async def query_prowlarr( if indexer_ids is not None: params["indexerIds"] = indexer_ids - url = urljoin(base_url, f"/api/v1/search?{urlencode(params, doseq=True)}") + url = posixpath.join(base_url, f"api/v1/search?{urlencode(params, doseq=True)}") logger.info("Querying prowlarr: %s", url) @@ -169,44 +169,53 @@ async def query_prowlarr( sources: list[ProwlarrSource] = [] for result in search_results: - if result["protocol"] not in ["torrent", "usenet"]: - print("Skipping source with unknown protocol", result["protocol"]) - continue - if result["protocol"] == "torrent": - sources.append( - TorrentSource( - protocol="torrent", - guid=result["guid"], - indexer_id=result["indexerId"], - indexer=result["indexer"], - title=result["title"], - seeders=result.get("seeders", 0), - leechers=result.get("leechers", 0), - size=result.get("size", 0), - info_url=result["infoUrl"], - indexer_flags=[x.lower() for x in result.get("indexerFlags", [])], - download_url=result.get("downloadUrl"), - magnet_url=result.get("magnetUrl"), - publish_date=datetime.fromisoformat(result["publishDate"]), + try: + if result["protocol"] not in ["torrent", "usenet"]: + logger.info( + "Skipping source with unknown protocol %s", result["protocol"] ) - ) - else: - sources.append( - UsenetSource( - protocol="usenet", - guid=result["guid"], - indexer_id=result["indexerId"], - indexer=result["indexer"], - title=result["title"], - grabs=result.get("grabs"), - size=result.get("size", 0), - info_url=result["infoUrl"], - indexer_flags=[x.lower() for x in result.get("indexerFlags", [])], - download_url=result.get("downloadUrl"), - magnet_url=result.get("magnetUrl"), - publish_date=datetime.fromisoformat(result["publishDate"]), + continue + if result["protocol"] == "torrent": + sources.append( + TorrentSource( + protocol="torrent", + guid=result["guid"], + indexer_id=result["indexerId"], + indexer=result["indexer"], + title=result["title"], + seeders=result.get("seeders", 0), + leechers=result.get("leechers", 0), + size=result.get("size", 0), + info_url=result.get("infoUrl"), + indexer_flags=[ + x.lower() for x in result.get("indexerFlags", []) + ], + download_url=result.get("downloadUrl"), + magnet_url=result.get("magnetUrl"), + publish_date=datetime.fromisoformat(result["publishDate"]), + ) ) - ) + else: + sources.append( + UsenetSource( + protocol="usenet", + guid=result["guid"], + indexer_id=result["indexerId"], + indexer=result["indexer"], + title=result["title"], + grabs=result.get("grabs"), + size=result.get("size", 0), + info_url=result.get("infoUrl"), + indexer_flags=[ + x.lower() for x in result.get("indexerFlags", []) + ], + download_url=result.get("downloadUrl"), + magnet_url=result.get("magnetUrl"), + publish_date=datetime.fromisoformat(result["publishDate"]), + ) + ) + except KeyError as e: + logger.error("Failed to parse source: %s. KeyError: %s", result, e) prowlarr_source_cache.set(sources, query) diff --git a/app/main.py b/app/main.py index 236bdc6..8135655 100644 --- a/app/main.py +++ b/app/main.py @@ -1,28 +1,56 @@ +import logging +from pathlib import Path from typing import Any -from urllib.parse import quote_plus +from urllib.parse import quote_plus, urlencode from fastapi import FastAPI, HTTPException, Request, status +from fastapi.middleware import Middleware +from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import RedirectResponse from sqlalchemy import func from sqlmodel import select +from app.internal.auth.authentication import RequiresLoginException, auth_config +from app.internal.auth.oidc_config import InvalidOIDCConfiguration +from app.internal.auth.session_middleware import ( + DynamicSessionMiddleware, + middleware_linker, +) from app.internal.env_settings import Settings from app.internal.models import User -from app.routers import root, search, settings, wishlist -from app.util.auth import RequiresLoginException +from app.routers import auth, root, search, settings, wishlist from app.util.db import open_session +from app.util.templates import templates +from app.util.toast import ToastException + +logger = logging.getLogger(__name__) +logging.getLogger("uvicorn").handlers.clear() +file_handler = logging.FileHandler(Settings().app.config_dir / Path("abr.log")) +stream_handler = logging.StreamHandler() +logging.basicConfig( + level=Settings().app.log_level, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + handlers=[file_handler, stream_handler], +) + +with open_session() as session: + auth_secret = auth_config.get_auth_secret(session) app = FastAPI( title="AudioBookRequest", debug=Settings().app.debug, openapi_url="/openapi.json" if Settings().app.openapi_enabled else None, + middleware=[ + Middleware(DynamicSessionMiddleware, auth_secret, middleware_linker), + Middleware(GZipMiddleware), + ], ) - +app.include_router(auth.router) app.include_router(root.router) app.include_router(search.router) -app.include_router(wishlist.router) app.include_router(settings.router) +app.include_router(wishlist.router) user_exists = False @@ -30,9 +58,13 @@ user_exists = False @app.exception_handler(RequiresLoginException) async def redirect_to_login(request: Request, exc: RequiresLoginException): if request.method == "GET": + params: dict[str, str] = {} if exc.detail: - return RedirectResponse(f"/login?error={quote_plus(exc.detail)}") - return RedirectResponse("/login") + params["error"] = exc.detail + path = request.url.path + if path != "/" and not path.startswith("/login"): + params["redirect_uri"] = path + return RedirectResponse("/login?" + urlencode(params)) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -41,6 +73,32 @@ async def redirect_to_login(request: Request, exc: RequiresLoginException): ) +@app.exception_handler(InvalidOIDCConfiguration) +async def redirect_to_invalid_oidc(request: Request, exc: InvalidOIDCConfiguration): + path = "/auth/invalid-oidc" + if exc.detail: + path += f"?error={quote_plus(exc.detail)}" + return RedirectResponse(path) + + +@app.exception_handler(ToastException) +async def raise_toast(request: Request, exc: ToastException): + context: dict[str, Request | str] = {"request": request} + if exc.type == "error": + context["toast_error"] = exc.message + elif exc.type == "success": + context["toast_success"] = exc.message + elif exc.type == "info": + context["toast_info"] = exc.message + + return templates.TemplateResponse( + "base.html", + context, + block_name="toast_block", + headers={"HX-Retarget": "#toast-block"}, + ) + + @app.middleware("http") async def redirect_to_init(request: Request, call_next: Any): """ diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000..368f3ea --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,279 @@ +import base64 +import logging +import secrets +import time +from typing import Annotated, Optional +from urllib.parse import urlencode, urljoin + +import jwt +from aiohttp import ClientSession +from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status +from fastapi.responses import RedirectResponse +from fastapi.security import OAuth2PasswordRequestForm +from sqlmodel import Session, select + +from app.internal.auth.authentication import ( + DetailedUser, + RequiresLoginException, + authenticate_user, + create_user, + get_authenticated_user, +) +from app.internal.auth.config import LoginTypeEnum, auth_config +from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config +from app.internal.models import GroupEnum, User +from app.util.connection import get_connection +from app.util.db import get_session +from app.util.templates import templates +from app.util.toast import ToastException + +router = APIRouter(prefix="/auth") + +logger = logging.getLogger(__name__) + + +@router.get("/login") +async def login( + request: Request, + session: Annotated[Session, Depends(get_session)], + redirect_uri: str = "/", + backup: bool = False, +): + login_type = auth_config.get(session, "login_type") + if login_type in [LoginTypeEnum.basic, LoginTypeEnum.none]: + return RedirectResponse(redirect_uri) + if login_type != LoginTypeEnum.oidc and backup: + backup = False + + try: + await get_authenticated_user()(request, session) + # already logged in + return RedirectResponse(redirect_uri) + except (HTTPException, RequiresLoginException): + pass + + if login_type != LoginTypeEnum.oidc or backup: + return templates.TemplateResponse( + "login.html", + { + "request": request, + "hide_navbar": True, + "redirect_uri": redirect_uri, + "backup": backup, + }, + ) + + authorize_endpoint = oidc_config.get(session, "oidc_authorize_endpoint") + client_id = oidc_config.get(session, "oidc_client_id") + scope = oidc_config.get(session, "oidc_scope") or "openid" + if not authorize_endpoint: + raise InvalidOIDCConfiguration("Missing OIDC endpoint") + if not client_id: + raise InvalidOIDCConfiguration("Missing OIDC client ID") + + auth_redirect_uri = urljoin(str(request.url), "/auth/oidc") + if oidc_config.get_redirect_https(session): + auth_redirect_uri = auth_redirect_uri.replace("http:", "https:") + + logger.info(f"Redirecting to OIDC login: {authorize_endpoint}") + logger.info(f"Redirect URI: {auth_redirect_uri}") + + state = jwt.encode( # pyright: ignore[reportUnknownMemberType] + {"redirect_uri": redirect_uri}, + auth_config.get_auth_secret(session), + algorithm="HS256", + ) + + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": auth_redirect_uri, + "scope": scope, + "state": state, + } + return RedirectResponse(f"{authorize_endpoint}?" + urlencode(params)) + + +@router.post("/logout") +async def logout( + request: Request, + user: Annotated[DetailedUser, Depends(get_authenticated_user())], + session: Annotated[Session, Depends(get_session)], +): + request.session["sub"] = "" + + login_type = auth_config.get_login_type(session) + if login_type == LoginTypeEnum.oidc: + logout_url = oidc_config.get(session, "oidc_logout_url") + if logout_url: + return Response( + status_code=status.HTTP_204_NO_CONTENT, + headers={"HX-Redirect": logout_url}, + ) + return Response( + status_code=status.HTTP_204_NO_CONTENT, headers={"HX-Redirect": "/login"} + ) + + +@router.post("/token") +def login_access_token( + request: Request, + session: Annotated[Session, Depends(get_session)], + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + redirect_uri: str = Form("/"), +): + user = authenticate_user(session, form_data.username, form_data.password) + if not user: + raise ToastException("Invalid login", "error") + + # only admins can use the backup forms login + login_type = auth_config.get_login_type(session) + if login_type == LoginTypeEnum.oidc and not user.root: + raise ToastException("Not root admin", "error") + + request.session["sub"] = form_data.username + return Response( + status_code=status.HTTP_200_OK, headers={"HX-Redirect": redirect_uri} + ) + + +@router.get("/oidc") +async def login_oidc( + request: Request, + session: Annotated[Session, Depends(get_session)], + client_session: Annotated[ClientSession, Depends(get_connection)], + code: str, + state: Optional[str] = None, +): + token_endpoint = oidc_config.get(session, "oidc_token_endpoint") + userinfo_endpoint = oidc_config.get(session, "oidc_userinfo_endpoint") + client_id = oidc_config.get(session, "oidc_client_id") + client_secret = oidc_config.get(session, "oidc_client_secret") + username_claim = oidc_config.get(session, "oidc_username_claim") + group_claim = oidc_config.get(session, "oidc_group_claim") + + if not token_endpoint: + raise InvalidOIDCConfiguration("Missing OIDC endpoint") + if not userinfo_endpoint: + raise InvalidOIDCConfiguration("Missing OIDC userinfo endpoint") + if not client_id: + raise InvalidOIDCConfiguration("Missing OIDC client ID") + if not client_secret: + raise InvalidOIDCConfiguration("Missing OIDC client secret") + if not username_claim: + raise InvalidOIDCConfiguration("Missing OIDC username claim") + + auth_redirect_uri = urljoin(str(request.url), "/auth/oidc") + if oidc_config.get_redirect_https(session): + auth_redirect_uri = auth_redirect_uri.replace("http:", "https:") + + data = { + "grant_type": "authorization_code", + "code": code, + "client_id": client_id, + "client_secret": client_secret, + "redirect_uri": auth_redirect_uri, + } + async with client_session.post( + token_endpoint, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) as response: + body = await response.json() + + access_token: Optional[str] = body.get("access_token") + if not access_token: + return Response(status_code=status.HTTP_401_UNAUTHORIZED) + + async with client_session.get( + userinfo_endpoint, + headers={"Authorization": f"Bearer {access_token}"}, + ) as response: + userinfo = await response.json() + + username = userinfo.get(username_claim) + if not username: + raise InvalidOIDCConfiguration("Missing username claim") + + if group_claim: + groups: list[str] | str = userinfo.get(group_claim, []) + if isinstance(groups, str): + groups = groups.split(" ") + else: + groups = [] + + user = session.exec(select(User).where(User.username == username)).first() + if not user: + user = create_user( + username=username, + # assign a random password to users created via OIDC + password=base64.encodebytes(secrets.token_bytes(64)).decode("utf-8"), + ) + + # Don't overwrite the group if the user is root admin + if not user.root: + for group in groups: + if group.lower() == "admin": + user.group = GroupEnum.admin + break + elif group.lower() == "trusted": + user.group = GroupEnum.trusted + break + elif group.lower() == "untrusted": + user.group = GroupEnum.untrusted + break + + session.add(user) + session.commit() + + expires_in: int = body.get( + "expires_in", + auth_config.get_access_token_expiry_minutes(session) * 60, + ) + expires = int(time.time() + expires_in) + + request.session["sub"] = username + request.session["exp"] = expires + + if state: + decoded = jwt.decode( # pyright: ignore[reportUnknownMemberType] + state, + auth_config.get_auth_secret(session), + algorithms=["HS256"], + ) + redirect_uri = decoded.get("redirect_uri", "/") + else: + redirect_uri = "/" + + # We can't redirect server side, because that results in an infinite loop. + # The session token is never correctly set causing any other endpoint to + # redirect to the login page which in turn starts the OIDC flow again. + # The redirect page allows for the cookie to properly be set on the browser + # and then redirects client-side. + return templates.TemplateResponse( + "redirect.html", + { + "request": request, + "hide_navbar": True, + "redirect_uri": redirect_uri, + }, + ) + + +@router.get("/invalid-oidc") +def invalid_oidc( + request: Request, + session: Annotated[Session, Depends(get_session)], + error: Optional[str] = None, +): + if auth_config.get_login_type(session) != LoginTypeEnum.oidc: + return Response(status_code=status.HTTP_404_NOT_FOUND) + return templates.TemplateResponse( + "invalid_oidc.html", + { + "request": request, + "error": error, + "hide_navbar": True, + }, + status_code=status.HTTP_200_OK, + ) diff --git a/app/routers/root.py b/app/routers/root.py index e5a5475..9cfebb2 100644 --- a/app/routers/root.py +++ b/app/routers/root.py @@ -1,23 +1,18 @@ -from datetime import timedelta -from typing import Annotated, Optional +from typing import Annotated +from urllib.parse import urlencode -from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status +from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response from fastapi.responses import FileResponse, RedirectResponse -from fastapi.security import OAuth2PasswordRequestForm from sqlmodel import Session -from app.internal.models import GroupEnum -from app.util.auth import ( +from app.internal.auth.config import LoginTypeEnum, auth_config +from app.internal.auth.authentication import ( DetailedUser, - LoginTypeEnum, - RequiresLoginException, - auth_config, - authenticate_user, - create_access_token, create_user, get_authenticated_user, raise_for_invalid_password, ) +from app.internal.models import GroupEnum from app.util.db import get_session from app.util.templates import templates @@ -120,65 +115,5 @@ def create_init( @router.get("/login") -async def login( - request: Request, - session: Annotated[Session, Depends(get_session)], - error: Optional[str] = None, -): - login_type = auth_config.get(session, "login_type") - if login_type != LoginTypeEnum.forms: - return RedirectResponse("/") - - try: - await get_authenticated_user()(request, session) - # already logged in - return RedirectResponse("/") - except (HTTPException, RequiresLoginException): - pass - - return templates.TemplateResponse( - "login.html", - {"request": request, "hide_navbar": True, "error": error}, - ) - - -@router.post("/auth/logout") -def logout(user: Annotated[DetailedUser, Depends(get_authenticated_user())]): - return Response( - status_code=status.HTTP_204_NO_CONTENT, - headers={ - "Set-Cookie": "audio_sess=; Path=/; SameSite=Strict; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT", - "HX-Redirect": "/login", - }, - ) - - -@router.post("/auth/token") -def login_access_token( - request: Request, - session: Annotated[Session, Depends(get_session)], - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], -): - user = authenticate_user(session, form_data.username, form_data.password) - if not user: - return templates.TemplateResponse( - "login.html", - {"request": request, "hide_navbar": True, "error": "Invalid login"}, - block_name="error_toast", - ) - - access_token_expires_minues = auth_config.get_access_token_expiry_minutes(session) - access_token_exires = timedelta(minutes=access_token_expires_minues) - access_token = create_access_token( - auth_config.get_auth_secret(session), - {"sub": form_data.username}, - access_token_exires, - ) - - return Response( - status_code=status.HTTP_200_OK, - headers={ - "HX-Redirect": "/", - "Set-Cookie": f"audio_sess={access_token}; Path=/; SameSite=Strict; HttpOnly; ", - }, - ) +def redirect_login(request: Request): + return RedirectResponse("/auth/login?" + urlencode(request.query_params)) diff --git a/app/routers/search.py b/app/routers/search.py index da309db..ec797eb 100644 --- a/app/routers/search.py +++ b/app/routers/search.py @@ -36,7 +36,7 @@ from app.internal.prowlarr.prowlarr import prowlarr_config from app.internal.query import query_sources from app.internal.ranking.quality import quality_config from app.routers.wishlist import get_wishlist_books -from app.util.auth import DetailedUser, get_authenticated_user +from app.internal.auth.authentication import DetailedUser, get_authenticated_user from app.util.connection import get_connection from app.util.db import get_session, open_session from app.util.templates import template_response @@ -170,12 +170,6 @@ async def add_request( except sa.exc.IntegrityError: pass # ignore if already exists - if quality_config.get_auto_download(session) and user.is_above(GroupEnum.trusted): - # start querying and downloading if auto download is enabled - background_task.add_task( - background_start_query, asin=asin, requester_username=user.username - ) - background_task.add_task( send_all_notifications, event_type=EventEnum.on_new_request, @@ -183,6 +177,12 @@ async def add_request( book_asin=asin, ) + if quality_config.get_auto_download(session) and user.is_above(GroupEnum.trusted): + # start querying and downloading if auto download is enabled + background_task.add_task( + background_start_query, asin=asin, requester_username=user.username + ) + if audible_regions.get(region) is None: raise HTTPException(status_code=400, detail="Invalid region") if query: diff --git a/app/routers/settings.py b/app/routers/settings.py index d82323b..4e02a6b 100644 --- a/app/routers/settings.py +++ b/app/routers/settings.py @@ -6,25 +6,28 @@ from aiohttp import ClientResponseError, ClientSession from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response from sqlmodel import Session, select -from app.internal.models import EventEnum, GroupEnum, Notification, User -from app.internal.prowlarr.indexer_categories import indexer_categories -from app.internal.notifications import send_notification -from app.internal.prowlarr.prowlarr import flush_prowlarr_cache, prowlarr_config -from app.internal.indexers.mam import mam_config -from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config -from app.util.auth import ( +from app.internal.auth.authentication import ( DetailedUser, - LoginTypeEnum, - auth_config, create_user, get_authenticated_user, is_correct_password, raise_for_invalid_password, ) +from app.internal.auth.config import LoginTypeEnum, auth_config +from app.internal.auth.oidc_config import oidc_config +from app.internal.env_settings import Settings +from app.internal.models import EventEnum, GroupEnum, Notification, User +from app.internal.notifications import send_notification +from app.internal.prowlarr.indexer_categories import indexer_categories +from app.internal.prowlarr.prowlarr import flush_prowlarr_cache, prowlarr_config +from app.internal.indexers.mam import mam_config +from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config from app.util.connection import get_connection from app.util.db import get_session from app.util.templates import template_response +from app.util.time import Minute +from app.util.toast import ToastException router = APIRouter(prefix="/settings") @@ -35,7 +38,10 @@ def read_account( user: Annotated[DetailedUser, Depends(get_authenticated_user())], ): return template_response( - "settings_page/account.html", request, user, {"page": "account"} + "settings_page/account.html", + request, + user, + {"page": "account", "version": Settings().app.version}, ) @@ -49,25 +55,11 @@ def change_password( user: Annotated[DetailedUser, Depends(get_authenticated_user())], ): if not is_correct_password(user, old_password): - return template_response( - "settings_page/account.html", - request, - user, - {"page": "account", "error": "Old password is incorrect"}, - block_name="error", - headers={"HX-Retarget": "#error"}, - ) + raise ToastException("Old password is incorrect", "error") try: raise_for_invalid_password(session, password, confirm_password) except HTTPException as e: - return template_response( - "settings_page/account.html", - request, - user, - {"page": "account", "error": e.detail}, - block_name="error", - headers={"HX-Retarget": "#error"}, - ) + raise ToastException(e.detail, "error") new_user = create_user(user.username, password, user.group) old_user = session.exec(select(User).where(User.username == user.username)).one() @@ -93,11 +85,17 @@ def read_users( session: Annotated[Session, Depends(get_session)], ): users = session.exec(select(User)).all() + is_oidc = auth_config.get_login_type(session) == LoginTypeEnum.oidc return template_response( "settings_page/users.html", request, admin_user, - {"page": "users", "users": users}, + { + "page": "users", + "users": users, + "is_oidc": is_oidc, + "version": Settings().app.version, + }, ) @@ -113,49 +111,21 @@ def create_new_user( ], ): if username.strip() == "": - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": "Invalid username"}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException("Invalid username", "error") try: raise_for_invalid_password(session, password, ignore_confirm=True) except HTTPException as e: - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": e.detail}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException(e.detail, "error") if group not in GroupEnum.__members__: - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": "Invalid group selected"}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException("Invalid group selected", "error") group = GroupEnum[group] user = session.exec(select(User).where(User.username == username)).first() if user: - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": "Username already exists"}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException("Username already exists", "error") user = create_user(username, password, group) session.add(user) @@ -182,26 +152,11 @@ def delete_user( ], ): if username == admin_user.username: - users = session.exec(select(User)).all() - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": "Cannot delete own user"}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException("Cannot delete own user", "error") user = session.exec(select(User).where(User.username == username)).one_or_none() if user and user.root: - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": "Cannot delete root user"}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException("Cannot delete root user", "error") if user: session.delete(user) @@ -230,14 +185,7 @@ def update_user( ): user = session.exec(select(User).where(User.username == username)).one_or_none() if user and user.root: - return template_response( - "settings_page/users.html", - request, - admin_user, - {"error": "Cannot change root user"}, - block_name="toast_block", - headers={"HX-Retarget": "#toast-block"}, - ) + raise ToastException("Cannot change root user's group", "error") if user: user.group = group @@ -280,6 +228,7 @@ def read_prowlarr( "indexer_categories": indexer_categories, "selected_categories": selected, "prowlarr_misconfigured": True if prowlarr_misconfigured else False, + "version": Settings().app.version, "mam_active": mam_is_active, "mam_id": mam_id, }, @@ -405,6 +354,7 @@ def read_download( "name_ratio": name_ratio, "title_ratio": title_ratio, "indexer_flags": flags, + "version": Settings().app.version, }, ) @@ -514,7 +464,6 @@ def remove_indexer_flag( DetailedUser, Depends(get_authenticated_user(GroupEnum.admin)) ], ): - # TODO: very bad concurrency here flags = quality_config.get_indexer_flags(session) flags = [f for f in flags if f.flag != flag] quality_config.set_indexer_flags(session, flags) @@ -545,6 +494,7 @@ def read_notifications( "page": "notifications", "notifications": notifications, "event_types": event_types, + "version": Settings().app.version, }, ) @@ -649,7 +599,6 @@ async def execute_notification( DetailedUser, Depends(get_authenticated_user(GroupEnum.admin)) ], session: Annotated[Session, Depends(get_session)], - client_session: Annotated[ClientSession, Depends(get_connection)], ): notification = session.exec( select(Notification).where(Notification.id == notification_id) @@ -682,6 +631,15 @@ def read_security( "login_type": auth_config.get_login_type(session), "access_token_expiry": auth_config.get_access_token_expiry_minutes(session), "min_password_length": auth_config.get_min_password_length(session), + "oidc_endpoint": oidc_config.get(session, "oidc_endpoint", ""), + "oidc_client_secret": oidc_config.get(session, "oidc_client_secret", ""), + "oidc_client_id": oidc_config.get(session, "oidc_client_id", ""), + "oidc_scope": oidc_config.get(session, "oidc_scope", ""), + "oidc_username_claim": oidc_config.get(session, "oidc_username_claim", ""), + "oidc_group_claim": oidc_config.get(session, "oidc_group_claim", ""), + "oidc_redirect_https": oidc_config.get_redirect_https(session), + "oidc_logout_url": oidc_config.get(session, "oidc_logout_url", ""), + "version": Settings().app.version, }, ) @@ -698,40 +656,72 @@ def reset_auth_secret( @router.post("/security") -def update_security( +async def update_security( login_type: Annotated[LoginTypeEnum, Form()], - access_token_expiry: Annotated[int, Form()], - min_password_length: Annotated[int, Form()], request: Request, admin_user: Annotated[ DetailedUser, Depends(get_authenticated_user(GroupEnum.admin)) ], session: Annotated[Session, Depends(get_session)], + client_session: Annotated[ClientSession, Depends(get_connection)], + access_token_expiry: Optional[int] = Form(None), + min_password_length: Optional[int] = Form(None), + oidc_endpoint: Optional[str] = Form(None), + oidc_client_id: Optional[str] = Form(None), + oidc_client_secret: Optional[str] = Form(None), + oidc_scope: Optional[str] = Form(None), + oidc_username_claim: Optional[str] = Form(None), + oidc_group_claim: Optional[str] = Form(None), + oidc_redirect_https: Optional[bool] = Form(False), + oidc_logout_url: Optional[str] = Form(None), ): - if access_token_expiry < 1: - return template_response( - "settings_page/security.html", - request, - admin_user, - {"error": "Access token expiry can't be 0 or negative"}, - block_name="error_toast", - headers={"HX-Retarget": "#message"}, - ) + if ( + login_type in [LoginTypeEnum.basic, LoginTypeEnum.forms] + and min_password_length is not None + ): + if min_password_length < 1: + raise ToastException( + "Minimum password length can't be 0 or negative", "error" + ) + else: + auth_config.set_min_password_length(session, min_password_length) - if min_password_length < 1: - return template_response( - "settings_page/security.html", - request, - admin_user, - {"error": "Minimum password length can't be 0 or negative"}, - block_name="error_toast", - headers={"HX-Retarget": "#message"}, - ) + if access_token_expiry is not None: + if access_token_expiry < 1: + raise ToastException("Access token expiry can't be 0 or negative", "error") + else: + auth_config.set_access_token_expiry_minutes( + session, Minute(access_token_expiry) + ) + + if login_type == LoginTypeEnum.oidc: + if oidc_endpoint: + await oidc_config.set_endpoint(session, client_session, oidc_endpoint) + if oidc_client_id: + oidc_config.set(session, "oidc_client_id", oidc_client_id) + if oidc_client_secret: + oidc_config.set(session, "oidc_client_secret", oidc_client_secret) + if oidc_scope: + oidc_config.set(session, "oidc_scope", oidc_scope) + if oidc_username_claim: + oidc_config.set(session, "oidc_username_claim", oidc_username_claim) + if oidc_redirect_https is not None: + oidc_config.set( + session, + "oidc_redirect_https", + "true" if oidc_redirect_https else "", + ) + if oidc_logout_url: + oidc_config.set(session, "oidc_logout_url", oidc_logout_url) + if oidc_group_claim is not None: + oidc_config.set(session, "oidc_group_claim", oidc_group_claim) + + error_message = await oidc_config.validate(session, client_session) + if error_message: + raise ToastException(error_message, "error") old = auth_config.get_login_type(session) auth_config.set_login_type(session, login_type) - auth_config.set_access_token_expiry_minutes(session, access_token_expiry) - auth_config.set_min_password_length(session, min_password_length) return template_response( "settings_page/security.html", request, @@ -740,6 +730,14 @@ def update_security( "page": "security", "login_type": auth_config.get_login_type(session), "access_token_expiry": auth_config.get_access_token_expiry_minutes(session), + "oidc_client_id": oidc_config.get(session, "oidc_client_id", ""), + "oidc_scope": oidc_config.get(session, "oidc_scope", ""), + "oidc_username_claim": oidc_config.get(session, "oidc_username_claim", ""), + "oidc_group_claim": oidc_config.get(session, "oidc_group_claim", ""), + "oidc_client_secret": oidc_config.get(session, "oidc_client_secret", ""), + "oidc_endpoint": oidc_config.get(session, "oidc_endpoint", ""), + "oidc_redirect_https": oidc_config.get_redirect_https(session), + "oidc_logout_url": oidc_config.get(session, "oidc_logout_url", ""), "success": "Settings updated", }, block_name="form", diff --git a/app/routers/wishlist.py b/app/routers/wishlist.py index 3d4eb91..41a9c6d 100644 --- a/app/routers/wishlist.py +++ b/app/routers/wishlist.py @@ -1,3 +1,4 @@ +from collections import defaultdict import uuid from typing import Annotated, Literal, Optional @@ -12,8 +13,7 @@ from fastapi import ( Response, ) from fastapi.responses import RedirectResponse -from sqlalchemy import func -from sqlmodel import Session, col, select +from sqlmodel import Session, asc, col, select from app.internal.models import ( BookRequest, @@ -28,8 +28,7 @@ from app.internal.prowlarr.prowlarr import ( ) from app.internal.indexers.mam import mam_config from app.internal.query import query_sources -from app.internal.ranking.quality import quality_config -from app.util.auth import DetailedUser, get_authenticated_user +from app.internal.auth.authentication import DetailedUser, get_authenticated_user from app.util.connection import get_connection from app.util.db import get_session, open_session from app.util.templates import template_response @@ -42,23 +41,32 @@ def get_wishlist_books( username: Optional[str] = None, response_type: Literal["all", "downloaded", "not_downloaded"] = "all", ) -> list[BookWishlistResult]: - query = select( - BookRequest, func.count(col(BookRequest.user_username)).label("count") - ) + """ + Gets the books that have been requested. If a username is given only the books requested by that + user are returned. If no username is given, all book requests are returned. + """ if username: - query = query.where(BookRequest.user_username == username) + query = select(BookRequest).where(BookRequest.user_username == username) else: - query = query.where(col(BookRequest.user_username).is_not(None)) + query = select(BookRequest).where(col(BookRequest.user_username).is_not(None)) - book_requests = session.exec( - query.select_from(BookRequest).group_by(BookRequest.asin) - ).all() + book_requests = session.exec(query).all() + # group by asin and aggregate all usernames + usernames: dict[str, list[str]] = defaultdict(list) + distinct_books: dict[str, BookRequest] = {} + for book in book_requests: + if book.asin not in distinct_books: + distinct_books[book.asin] = book + if book.user_username: + usernames[book.asin].append(book.user_username) + + # add information of what users requested the book books: list[BookWishlistResult] = [] downloaded: list[BookWishlistResult] = [] - for book, count in book_requests: + for asin, book in distinct_books.items(): b = BookWishlistResult.model_validate(book) - b.amount_requested = count + b.requested_by = usernames[asin] if b.downloaded: downloaded.append(b) else: @@ -103,23 +111,73 @@ async def downloaded( ) +@router.patch("/downloaded/{asin}") +async def update_downloaded( + request: Request, + asin: str, + admin_user: Annotated[ + DetailedUser, Depends(get_authenticated_user(GroupEnum.admin)) + ], + session: Annotated[Session, Depends(get_session)], +): + books = session.exec(select(BookRequest).where(BookRequest.asin == asin)).all() + for book in books: + book.downloaded = True + session.add(book) + session.commit() + + username = None if admin_user.is_admin() else admin_user.username + books = get_wishlist_books(session, username, "not_downloaded") + return template_response( + "wishlist_page/wishlist.html", + request, + admin_user, + {"books": books, "page": "wishlist"}, + block_name="book_wishlist", + ) + + @router.get("/manual") async def manual( request: Request, user: Annotated[DetailedUser, Depends(get_authenticated_user())], session: Annotated[Session, Depends(get_session)], ): - books = session.exec(select(ManualBookRequest)).all() - auto_download = quality_config.get_auto_download(session) + books = session.exec( + select(ManualBookRequest).order_by(asc(ManualBookRequest.downloaded)) + ).all() return template_response( "wishlist_page/manual.html", request, user, - { - "books": books, - "page": "manual", - "auto_download": auto_download, - }, + {"books": books, "page": "manual"}, + ) + + +@router.patch("/manual/{id}") +async def downloaded_manual( + request: Request, + id: uuid.UUID, + admin_user: Annotated[ + DetailedUser, Depends(get_authenticated_user(GroupEnum.admin)) + ], + session: Annotated[Session, Depends(get_session)], +): + book = session.get(ManualBookRequest, id) + if book: + book.downloaded = True + session.add(book) + session.commit() + + books = session.exec( + select(ManualBookRequest).order_by(asc(ManualBookRequest.downloaded)) + ).all() + return template_response( + "wishlist_page/manual.html", + request, + admin_user, + {"books": books, "page": "manual"}, + block_name="book_wishlist", ) @@ -138,12 +196,11 @@ async def delete_manual( session.commit() books = session.exec(select(ManualBookRequest)).all() - auto_download = quality_config.get_auto_download(session) return template_response( "wishlist_page/manual.html", request, admin_user, - {"books": books, "page": "manual", "auto_download": auto_download}, + {"books": books, "page": "manual"}, block_name="book_wishlist", ) diff --git a/app/util/auth.py b/app/util/auth.py deleted file mode 100644 index 8aed984..0000000 --- a/app/util/auth.py +++ /dev/null @@ -1,242 +0,0 @@ -import base64 -import secrets -from datetime import datetime, timedelta, timezone -from enum import Enum -from typing import Annotated, Literal, Optional - -import jwt -from argon2 import PasswordHasher -from argon2.exceptions import VerifyMismatchError -from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBasic, OAuth2PasswordBearer -from sqlmodel import Session - -from app.internal.models import GroupEnum, User -from app.util.cache import StringConfigCache -from app.util.db import get_session - -JWT_ALGORITHM = "HS256" - - -class LoginTypeEnum(str, Enum): - basic = "basic" - forms = "forms" - none = "none" - - def is_basic(self): - return self == LoginTypeEnum.basic - - def is_forms(self): - return self == LoginTypeEnum.forms - - def is_none(self): - return self == LoginTypeEnum.none - - -AuthConfigKey = Literal[ - "login_type", - "access_token_expiry_minutes", - "auth_secret", - "min_password_length", -] - - -class AuthConfig(StringConfigCache[AuthConfigKey]): - def get_login_type(self, session: Session) -> LoginTypeEnum: - login_type = self.get(session, "login_type") - if login_type: - return LoginTypeEnum(login_type) - return LoginTypeEnum.basic - - def set_login_type(self, session: Session, login_Type: LoginTypeEnum): - self.set(session, "login_type", login_Type.value) - - def reset_auth_secret(self, session: Session): - auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8") - self.set(session, "auth_secret", auth_secret) - - def get_auth_secret(self, session: Session) -> str: - auth_secret = self.get(session, "auth_secret") - if auth_secret: - return auth_secret - auth_secret = base64.encodebytes(secrets.token_bytes(64)).decode("utf-8") - self.set(session, "auth_secret", auth_secret) - return auth_secret - - def get_access_token_expiry_minutes(self, session: Session): - return self.get_int(session, "access_token_expiry_minutes", 60 * 24 * 7) - - def set_access_token_expiry_minutes(self, session: Session, expiry: int): - self.set_int(session, "access_token_expiry_minutes", expiry) - - def get_min_password_length(self, session: Session) -> int: - return self.get_int(session, "min_password_length", 1) - - def set_min_password_length(self, session: Session, min_password_length: int): - self.set_int(session, "min_password_length", min_password_length) - - -class DetailedUser(User): - login_type: LoginTypeEnum - - def can_logout(self): - return self.login_type == LoginTypeEnum.forms - - -security = HTTPBasic() -ph = PasswordHasher() -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) -auth_config = AuthConfig() - - -def raise_for_invalid_password( - session: Session, - password: str, - confirm_password: Optional[str] = None, - ignore_confirm: bool = False, -): - if not ignore_confirm and password != confirm_password: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Passwords must be equal", - ) - - min_password_length = auth_config.get_min_password_length(session) - if not len(password) >= min_password_length: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Password must be at least {min_password_length} characters long", - ) - - -def is_correct_password(user: User, password: str) -> bool: - try: - return ph.verify(user.password, password) - except VerifyMismatchError: - return False - - -def authenticate_user(session: Session, username: str, password: str) -> Optional[User]: - user = session.get(User, username) - if not user: - return None - - try: - ph.verify(user.password, password) - except VerifyMismatchError: - return None - - if ph.check_needs_rehash(user.password): - user.password = ph.hash(password) - session.add(user) - session.commit() - - return user - - -def create_access_token( - auth_secret: str, data: dict[str, str | datetime], expires_delta: timedelta -): - to_encode = data.copy() - expires = datetime.now(timezone.utc) + expires_delta - to_encode.update({"exp": expires}) - encoded_jwt = jwt.encode(to_encode, auth_secret, algorithm=JWT_ALGORITHM) # pyright: ignore[reportUnknownMemberType] - return encoded_jwt - - -def create_user( - username: str, - password: str, - group: GroupEnum = GroupEnum.untrusted, - root: bool = False, -) -> User: - password_hash = ph.hash(password) - return User(username=username, password=password_hash, group=group, root=root) - - -def get_authenticated_user(lowest_allowed_group: GroupEnum = GroupEnum.untrusted): - async def get_user( - request: Request, - session: Annotated[Session, Depends(get_session)], - ) -> DetailedUser: - login_type = auth_config.get_login_type(session) - - if login_type == LoginTypeEnum.forms: - user = await _get_forms_auth(request, session) - elif login_type == LoginTypeEnum.none: - user = await _get_none_auth() - else: - user = await _get_basic_auth(request, session) - - if not user.is_above(lowest_allowed_group): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden" - ) - - user = DetailedUser.model_validate(user, update={"login_type": login_type}) - - return user - - return get_user - - -async def _get_basic_auth( - request: Request, - session: Session, -) -> User: - invalid_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials", - headers={"WWW-Authenticate": "Basic"}, - ) - - credentials = await security(request) - - if not credentials: - raise invalid_exception - - user = authenticate_user(session, credentials.username, credentials.password) - if not user: - raise invalid_exception - - return user - - -class RequiresLoginException(Exception): - def __init__(self, detail: Optional[str] = None, **kwargs: object): - super().__init__(**kwargs) - self.detail = detail - - -async def _get_forms_auth( - request: Request, - session: Session, -) -> User: - # Authentication is either through Authorization header or cookie - token = await oauth2_scheme(request) - if not token: - token = request.cookies.get("audio_sess") - if not token: - raise RequiresLoginException() - - try: - payload = jwt.decode( # pyright: ignore[reportUnknownMemberType] - token, auth_config.get_auth_secret(session), algorithms=[JWT_ALGORITHM] - ) - except jwt.InvalidTokenError: - raise RequiresLoginException("Token is expired/invalid") - - username = payload.get("sub") - if username is None: - raise RequiresLoginException("Token is invalid") - - user = session.get(User, username) - if not user: - raise RequiresLoginException("User does not exist") - - return user - - -async def _get_none_auth() -> User: - """Treats every request as being root / turns off authentication""" - return User(username="no-login", password="", group=GroupEnum.admin, root=True) diff --git a/app/util/cache.py b/app/util/cache.py index 257fe55..3705012 100644 --- a/app/util/cache.py +++ b/app/util/cache.py @@ -35,10 +35,23 @@ L = TypeVar("L", bound=str) class StringConfigCache(Generic[L], ABC): _cache: dict[L, str] = {} + @overload def get(self, session: Session, key: L) -> Optional[str]: + pass + + @overload + def get(self, session: Session, key: L, default: str) -> str: + pass + + def get( + self, session: Session, key: L, default: Optional[str] = None + ) -> Optional[str]: if key in self._cache: return self._cache[key] - return session.exec(select(Config.value).where(Config.key == key)).one_or_none() + return ( + session.exec(select(Config.value).where(Config.key == key)).one_or_none() + or default + ) def set(self, session: Session, key: L, value: str): old = session.exec(select(Config).where(Config.key == key)).one_or_none() @@ -59,7 +72,7 @@ class StringConfigCache(Generic[L], ABC): del self._cache[key] @overload - def get_int(self, session: Session, key: L, default: None = None) -> Optional[int]: + def get_int(self, session: Session, key: L) -> Optional[int]: pass @overload diff --git a/app/util/templates.py b/app/util/templates.py index 6fbe467..45f6ba3 100644 --- a/app/util/templates.py +++ b/app/util/templates.py @@ -5,7 +5,7 @@ from fastapi import Request, Response from jinja2_fragments.fastapi import Jinja2Blocks from starlette.background import BackgroundTask -from app.util.auth import DetailedUser +from app.internal.auth.authentication import DetailedUser templates = Jinja2Blocks(directory="templates") templates.env.filters["quote_plus"] = lambda u: quote_plus(u) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportUnknownArgumentType] diff --git a/app/util/time.py b/app/util/time.py new file mode 100644 index 0000000..4b63012 --- /dev/null +++ b/app/util/time.py @@ -0,0 +1,4 @@ +from typing import NewType + +Second = NewType("Second", int) +Minute = NewType("Minute", int) diff --git a/app/util/toast.py b/app/util/toast.py new file mode 100644 index 0000000..5a31aad --- /dev/null +++ b/app/util/toast.py @@ -0,0 +1,11 @@ +from typing import Literal + + +class ToastException(Exception): + """Shows a toast on the frontend if raised on an HTMX endpoint""" + + def __init__( + self, message: str, type: Literal["error", "success", "info"] = "error" + ): + self.message = message + self.type = type diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..7e096aa --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,10 @@ +services: + web: + build: + context: . + args: + - VERSION=local + volumes: + - ./config:/config + ports: + - "8000:8000" diff --git a/pyproject.toml b/pyproject.toml index ae01a25..99260ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [tool.pyright] include = ["**/*.py"] -exclude = ["**/__pycache__", "**/.venv"] +exclude = ["**/__pycache__", "**/.venv", "**/.direnv"] ignore = [] typeCheckingMode = "strict" diff --git a/requirements.txt b/requirements.txt index ca4b8b8..09a306d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ httpcore==1.0.7 httptools==0.6.4 httpx==0.28.1 idna==3.10 +itsdangerous==2.2.0 Jinja2==3.1.6 jinja2_fragments==1.8.0 Mako==1.3.9 diff --git a/styles/globals.css b/styles/globals.css index 4c9abe0..44c5e1e 100644 --- a/styles/globals.css +++ b/styles/globals.css @@ -1,6 +1,9 @@ @import "tailwindcss"; @plugin "daisyui"; +/* Enables dark mode specific styling with "dark:" */ +@custom-variant dark (&:where(.dark, .dark *)); + @plugin "daisyui/theme" { name: "nord"; default: true; diff --git a/templates/base.html b/templates/base.html index 8cc80d0..0b00a5c 100644 --- a/templates/base.html +++ b/templates/base.html @@ -27,12 +27,14 @@ "light-dark-toggle", )) { elem.classList.add("DARKCLASS"); + document.documentElement.classList.add("dark"); } } else { for (const elem of document.getElementsByClassName( "light-dark-toggle", )) { elem.classList.remove("DARKCLASS"); + document.documentElement.classList.remove("dark"); } } }; @@ -49,9 +51,27 @@ + + {% include 'scripts/toast.html' %} - {% if not hide_navbar %} + {% block toast_block %} + + {% endblock %} {% if not hide_navbar %}