feat: all website operations can now be handled using the REST API (Closes #135) (#176)

fix: make inputs on download settings page editable
This commit is contained in:
Mark
2026-01-11 17:10:17 +01:00
committed by GitHub
parent 6727dc67ba
commit 0dff1f38c8
33 changed files with 1791 additions and 871 deletions
+117
View File
@@ -0,0 +1,117 @@
from typing import Literal, Sequence, cast
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import InstrumentedAttribute, selectinload
from sqlmodel import Session, asc, col, not_, select
from app.internal.models import (
Audiobook,
AudiobookRequest,
AudiobookWishlistResult,
ManualBookRequest,
User,
)
class WishlistCounts(BaseModel):
requests: int
downloaded: int
manual: int
def get_wishlist_counts(session: Session, user: User | None = None) -> WishlistCounts:
"""
If a non-admin user is given, only count requests for that user.
Admins can see and get counts for all requests.
"""
username = None if user is None or user.is_admin() else user.username
rows = session.exec(
select(Audiobook.downloaded, func.count("*"))
.where(not username or AudiobookRequest.user_username == username)
.select_from(Audiobook)
.join(AudiobookRequest)
.group_by(col(Audiobook.downloaded))
).all()
requests = 0
downloaded = 0
for downloaded_status, count in rows:
if downloaded_status:
downloaded = count
else:
requests = count
manual = session.exec(
select(func.count())
.select_from(ManualBookRequest)
.where(
not username or ManualBookRequest.user_username == username,
col(ManualBookRequest.user_username).is_not(None),
)
).one()
return WishlistCounts(
requests=requests,
downloaded=downloaded,
manual=manual,
)
def get_wishlist_results(
session: Session,
username: str | None = None,
response_type: Literal["all", "downloaded", "not_downloaded"] = "all",
) -> list[AudiobookWishlistResult]:
"""
Gets the books that have been requested. If a username is given only the books requested by that
user are returned. If no username is given, all book requests are returned.
"""
match response_type:
case "downloaded":
clause = Audiobook.downloaded
case "not_downloaded":
clause = not_(Audiobook.downloaded)
case _:
clause = True
results = session.exec(
select(Audiobook)
.where(
clause,
col(Audiobook.asin).in_(
select(AudiobookRequest.asin).where(
not username or AudiobookRequest.user_username == username
)
),
)
.options(
selectinload(
cast(
InstrumentedAttribute[list[AudiobookRequest]],
cast(object, Audiobook.requests),
)
)
)
).all()
return [
AudiobookWishlistResult(
book=book,
requests=book.requests,
)
for book in results
]
def get_all_manual_requests(
session: Session, user: User
) -> Sequence[ManualBookRequest]:
return session.exec(
select(ManualBookRequest)
.where(
user.is_admin() or ManualBookRequest.user_username == user.username,
col(ManualBookRequest.user_username).is_not(None),
)
.order_by(asc(ManualBookRequest.downloaded))
).all()
+62 -1
View File
@@ -1,8 +1,10 @@
# pyright: reportExplicitAny=false
from typing import Any
from typing import Any, Mapping
from aiohttp import ClientSession
from pydantic import BaseModel
from sqlmodel import Session
from app.internal.indexers.abstract import AbstractIndexer, SessionContainer
from app.internal.indexers.configuration import (
@@ -11,8 +13,11 @@ from app.internal.indexers.configuration import (
IndexerConfiguration,
ValuedConfigurations,
create_valued_configuration,
indexer_configuration_cache,
)
from app.internal.indexers.indexers import indexers
from app.internal.prowlarr.util import flush_prowlarr_cache
from app.util.json_type import get_bool
from app.util.log import logger
@@ -70,3 +75,59 @@ async def get_indexer_contexts(
)
return contexts
async def update_single_indexer(
indexer_select: str,
values: Mapping[str, object],
session: Session,
client_session: ClientSession,
ignore_missing_booleans: bool = False,
):
"""
Update a single indexer with the given values.
`ignore_missing_booleans` can be set to true to ignore missing boolean values. By default, missing booleans are treated as false.
"""
session_container = SessionContainer(session=session, client_session=client_session)
contexts = await get_indexer_contexts(
session_container, check_required=False, return_disabled=True
)
updated_context: IndexerContext | None = None
for context in contexts:
if context.indexer.name == indexer_select:
updated_context = context
break
if not updated_context:
raise ValueError("Indexer not found")
for key, context in updated_context.configuration.items():
value = values.get(key)
if value is None:
# forms do not include false checkboxes, so we handle missing booleans as false
if context.type_ is bool and not ignore_missing_booleans:
value = False
else:
logger.warning("Value is missing for key", key=key)
continue
if context.type_ is bool:
indexer_configuration_cache.set_bool(session, key, value == "on")
else:
indexer_configuration_cache.set(session, key, str(value))
if "enabled" in values and (
isinstance(e := values["enabled"], str)
or isinstance(e, bool)
or isinstance(e, int)
):
logger.debug("Setting enabled state", enabled=values["enabled"])
enabled = get_bool(e) or False
await updated_context.indexer.set_enabled(
session_container,
enabled,
)
flush_prowlarr_cache()
+8 -1
View File
@@ -1,3 +1,4 @@
import html
from sqlmodel._compat import SQLModelConfig
import json
import uuid
@@ -253,7 +254,13 @@ class Notification(BaseSQLModel, table=True):
@property
def serialized_headers(self):
return json.dumps(self.headers)
return html.escape(json.dumps(self.headers))
class APIKeyResponse(BaseModel):
id: uuid.UUID
name: str
enabled: bool
class APIKey(BaseSQLModel, table=True):
+3 -3
View File
@@ -12,7 +12,7 @@ from app.internal.models import (
User,
)
from app.util import json_type
from app.util.db import open_session
from app.util.db import get_session
from app.util.log import logger
@@ -143,7 +143,7 @@ async def send_all_notifications(
):
if other_replacements is None:
other_replacements = {}
with open_session() as session:
with next(get_session()) as session:
notifications = session.exec(
select(Notification).where(
Notification.event == event_type, Notification.enabled
@@ -221,7 +221,7 @@ async def send_all_manual_notifications(
):
if other_replacements is None:
other_replacements = {}
with open_session() as session:
with next(get_session()) as session:
user = session.exec(
select(User).where(User.username == book_request.user_username)
).first()
+14 -80
View File
@@ -1,3 +1,4 @@
import html
import json
import posixpath
from datetime import datetime
@@ -22,87 +23,14 @@ from app.internal.models import (
)
from app.internal.notifications import send_all_notifications
from app.internal.prowlarr.source_metadata import edit_source_metadata
from app.util.cache import SimpleCache, StringConfigCache
from app.internal.prowlarr.util import (
prowlarr_config,
prowlarr_indexer_cache,
prowlarr_source_cache,
)
from app.util.log import logger
class ProwlarrMisconfigured(ValueError):
pass
ProwlarrConfigKey = Literal[
"prowlarr_api_key",
"prowlarr_base_url",
"prowlarr_source_ttl",
"prowlarr_categories",
"prowlarr_indexers",
]
class ProwlarrConfig(StringConfigCache[ProwlarrConfigKey]):
def raise_if_invalid(self, session: Session):
if not self.get_base_url(session):
raise ProwlarrMisconfigured("Prowlarr base url not set")
if not self.get_api_key(session):
raise ProwlarrMisconfigured("Prowlarr base url not set")
def is_valid(self, session: Session) -> bool:
return (
self.get_base_url(session) is not None
and self.get_api_key(session) is not None
)
def get_api_key(self, session: Session) -> str | None:
return self.get(session, "prowlarr_api_key")
def set_api_key(self, session: Session, api_key: str):
self.set(session, "prowlarr_api_key", api_key)
def get_base_url(self, session: Session) -> str | None:
path = self.get(session, "prowlarr_base_url")
if path:
return path.rstrip("/")
return None
def set_base_url(self, session: Session, base_url: str):
self.set(session, "prowlarr_base_url", base_url)
def get_source_ttl(self, session: Session) -> int:
return self.get_int(session, "prowlarr_source_ttl", 24 * 60 * 60)
def set_source_ttl(self, session: Session, source_ttl: int):
self.set_int(session, "prowlarr_source_ttl", source_ttl)
def get_categories(self, session: Session) -> list[int]:
categories = self.get(session, "prowlarr_categories")
if categories is None:
return [3030]
return json.loads(categories) # pyright: ignore[reportAny]
def set_categories(self, session: Session, categories: list[int]):
self.set(session, "prowlarr_categories", json.dumps(categories))
def get_indexers(self, session: Session) -> list[int]:
indexers = self.get(session, "prowlarr_indexers")
if indexers is None:
return []
return json.loads(indexers) # pyright: ignore[reportAny]
def set_indexers(self, session: Session, indexers: list[int]):
self.set(session, "prowlarr_indexers", json.dumps(indexers))
prowlarr_config = ProwlarrConfig()
prowlarr_source_cache = SimpleCache[list[ProwlarrSource], str]()
prowlarr_indexer_cache = SimpleCache[Indexer, str]()
def flush_prowlarr_cache():
logger.info("Flushing prowlarr caches")
prowlarr_source_cache.flush()
prowlarr_indexer_cache.flush()
async def _get_torrent_info_hash(
client_session: ClientSession, download_url: str
) -> str | None:
@@ -365,8 +293,10 @@ class IndexerResponse(BaseModel):
@property
def json_string(self) -> str:
return json.dumps(
{id: indexer.model_dump() for id, indexer in self.indexers.items()}
return html.escape(
json.dumps(
{id: indexer.model_dump() for id, indexer in self.indexers.items()}
)
)
@property
@@ -417,6 +347,10 @@ async def get_indexers(
indexers = _IndexerList.validate_python(await response.json())
for indexer in indexers:
prowlarr_indexer_cache.set(indexer, str(indexer.id))
logger.info(
"Successfully fetched indexers from Prowlarr",
count=len(indexers),
)
return IndexerResponse(
indexers={
+85
View File
@@ -0,0 +1,85 @@
import json
from typing import Literal
from sqlmodel import Session
from app.internal.models import Indexer, ProwlarrSource
from app.util.cache import SimpleCache, StringConfigCache
from app.util.log import logger
class ProwlarrMisconfigured(ValueError):
pass
ProwlarrConfigKey = Literal[
"prowlarr_api_key",
"prowlarr_base_url",
"prowlarr_source_ttl",
"prowlarr_categories",
"prowlarr_indexers",
]
class ProwlarrConfig(StringConfigCache[ProwlarrConfigKey]):
def raise_if_invalid(self, session: Session):
if not self.get_base_url(session):
raise ProwlarrMisconfigured("Prowlarr base url not set")
if not self.get_api_key(session):
raise ProwlarrMisconfigured("Prowlarr base url not set")
def is_valid(self, session: Session) -> bool:
return (
self.get_base_url(session) is not None
and self.get_api_key(session) is not None
)
def get_api_key(self, session: Session) -> str | None:
return self.get(session, "prowlarr_api_key")
def set_api_key(self, session: Session, api_key: str):
self.set(session, "prowlarr_api_key", api_key)
def get_base_url(self, session: Session) -> str | None:
path = self.get(session, "prowlarr_base_url")
if path:
return path.rstrip("/")
return None
def set_base_url(self, session: Session, base_url: str):
self.set(session, "prowlarr_base_url", base_url)
def get_source_ttl(self, session: Session) -> int:
return self.get_int(session, "prowlarr_source_ttl", 24 * 60 * 60)
def set_source_ttl(self, session: Session, source_ttl: int):
self.set_int(session, "prowlarr_source_ttl", source_ttl)
def get_categories(self, session: Session) -> list[int]:
categories = self.get(session, "prowlarr_categories")
if categories is None:
return [3030]
return json.loads(categories) # pyright: ignore[reportAny]
def set_categories(self, session: Session, categories: list[int]):
self.set(session, "prowlarr_categories", json.dumps(categories))
def get_indexers(self, session: Session) -> list[int]:
indexers = self.get(session, "prowlarr_indexers")
if indexers is None:
return []
return json.loads(indexers) # pyright: ignore[reportAny]
def set_indexers(self, session: Session, indexers: list[int]):
self.set(session, "prowlarr_indexers", json.dumps(indexers))
prowlarr_config = ProwlarrConfig()
prowlarr_source_cache = SimpleCache[list[ProwlarrSource], str]()
prowlarr_indexer_cache = SimpleCache[Indexer, str]()
def flush_prowlarr_cache():
logger.info("Flushing prowlarr caches")
prowlarr_source_cache.flush()
prowlarr_indexer_cache.flush()
+17 -6
View File
@@ -1,18 +1,17 @@
# what is currently being queried
# To dermine what is currently being queried:
from contextlib import contextmanager
from typing import Literal
import pydantic
import aiohttp
from aiohttp import ClientSession
from fastapi import HTTPException
from sqlmodel import Session, select
from app.internal.prowlarr.util import prowlarr_config
from app.util.db import get_session
from app.internal.models import Audiobook, ProwlarrSource, User
from app.internal.prowlarr.prowlarr import (
prowlarr_config,
query_prowlarr,
start_download,
)
from app.internal.prowlarr.prowlarr import query_prowlarr, start_download
from app.internal.ranking.download_ranking import rank_sources
querying: set[str] = set()
@@ -108,3 +107,15 @@ async def query_sources(
book=book,
state="ok",
)
async def background_start_query(asin: str, requester: User, auto_download: bool):
with next(get_session()) as session:
async with ClientSession(timeout=aiohttp.ClientTimeout(60)) as client_session:
await query_sources(
asin=asin,
session=session,
client_session=client_session,
start_auto_download=auto_download,
requester=requester,
)
+1 -1
View File
@@ -8,7 +8,7 @@ from aiohttp import ClientSession
from sqlmodel import Session
from app.internal.models import Audiobook, ProwlarrSource
from app.internal.prowlarr.prowlarr import prowlarr_config
from app.internal.prowlarr.util import prowlarr_config
from app.internal.ranking.quality import FileFormat
# HACK: Disabled because it doesn't work well with ratelimiting
+3 -3
View File
@@ -19,7 +19,7 @@ from app.internal.book_search import clear_old_book_caches
from app.internal.env_settings import Settings
from app.internal.models import User
from app.routers import api, auth, root, search, settings, wishlist
from app.util.db import open_session
from app.util.db import get_session
from app.util.fetch_js import fetch_scripts
from app.util.redirect import BaseUrlRedirectResponse
from app.util.templates import templates
@@ -28,7 +28,7 @@ from app.util.toast import ToastException
# intialize js dependencies or throw an error if not in debug mode
fetch_scripts(Settings().app.debug)
with open_session() as session:
with next(get_session()) as session:
auth_secret = auth_config.get_auth_secret(session)
initialize_force_login_type(session)
clear_old_book_caches(session)
@@ -119,7 +119,7 @@ async def redirect_to_init(
and not request.url.path.startswith("/static")
and request.method == "GET"
):
with open_session() as session:
with next(get_session()) as session:
user_count = session.exec(select(func.count()).select_from(User)).one()
if user_count == 0:
return BaseUrlRedirectResponse("/init")
+7 -1
View File
@@ -5,12 +5,18 @@ from sqlmodel import Session, select, text
from app.routers.api.indexers import router as indexers_router
from app.routers.api.users import router as users_router
from app.routers.api.search import router as search_router
from app.routers.api.requests import router as requests_router
from app.routers.api.settings import router as settings_router
from app.util.db import get_session
from app.util.log import logger
router = APIRouter(prefix="/api", tags=["API"])
router = APIRouter(prefix="/api")
router.include_router(indexers_router)
router.include_router(users_router)
router.include_router(search_router)
router.include_router(requests_router)
router.include_router(settings_router)
@router.get("/health", tags=["System"])
+6 -3
View File
@@ -1,18 +1,21 @@
import json
from typing import Annotated, cast
from aiohttp import ClientSession
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response, Security
from sqlmodel import Session
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.indexers.abstract import SessionContainer
from app.internal.indexers.indexer_util import get_indexer_contexts
from app.internal.indexers.indexer_util import (
get_indexer_contexts,
update_single_indexer,
)
from app.internal.models import BaseSQLModel, GroupEnum
from app.routers.settings.indexers import update_single_indexer
from app.util.connection import get_connection
from app.util.db import get_session
from app.util.toast import ToastException
from app.util.log import logger
from app.util.toast import ToastException
router = APIRouter(prefix="/indexers", tags=["Indexers"])
+370
View File
@@ -0,0 +1,370 @@
import uuid
from typing import Annotated, Literal
from aiohttp import ClientSession
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
HTTPException,
Security,
Response,
)
from pydantic import BaseModel
from sqlmodel import Session, asc, col, select, delete
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.book_search import (
get_book_by_asin,
audible_region_type,
get_region_from_settings,
audible_regions,
)
from app.internal.models import (
Audiobook,
AudiobookRequest,
AudiobookWishlistResult,
EventEnum,
GroupEnum,
ManualBookRequest,
User,
)
from app.internal.notifications import (
send_all_manual_notifications,
send_all_notifications,
)
from app.internal.prowlarr.prowlarr import start_download
from app.internal.prowlarr.util import ProwlarrMisconfigured, prowlarr_config
from app.internal.query import query_sources, QueryResult, background_start_query
from app.internal.ranking.quality import quality_config
from app.internal.db_queries import get_wishlist_results
from app.util.connection import get_connection
from app.util.db import get_session
from app.util.log import logger
from app.util.toast import ToastException
router = APIRouter(prefix="/requests", tags=["Requests"])
class DownloadSourceBody(BaseModel):
guid: str
indexer_id: int
@router.post("/{asin}", status_code=201)
async def create_request(
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
background_task: BackgroundTasks,
asin: str,
region: audible_region_type | None = None,
):
if region is None:
region = get_region_from_settings()
if audible_regions.get(region) is None:
raise HTTPException(status_code=400, detail="Invalid region")
book = await get_book_by_asin(client_session, asin, region)
if not book:
raise HTTPException(status_code=404, detail="Book not found")
if not session.exec(
select(AudiobookRequest).where(
AudiobookRequest.asin == asin,
AudiobookRequest.user_username == user.username,
)
).first():
book_request = AudiobookRequest(asin=asin, user_username=user.username)
session.add(book_request)
session.commit()
logger.info(
"Added new audiobook request",
username=user.username,
asin=asin,
)
else:
raise HTTPException(status_code=409, detail="Book already requested")
background_task.add_task(
send_all_notifications,
event_type=EventEnum.on_new_request,
requester=User.model_validate(user),
book_asin=asin,
)
if quality_config.get_auto_download(session) and user.is_above(GroupEnum.trusted):
# start querying and downloading if auto download is enabled
background_task.add_task(
background_start_query,
asin=asin,
requester=User.model_validate(user),
auto_download=True,
)
return Response(status_code=201)
@router.get("", response_model=list[AudiobookWishlistResult])
async def list_requests(
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
filter: Literal["all", "downloaded", "not_downloaded"] = "all",
):
username = None if user.is_admin() else user.username
results = get_wishlist_results(session, username, filter)
return results
@router.delete("/{asin}")
async def delete_request(
asin: str,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
if user.is_admin():
session.execute(
delete(AudiobookRequest).where(col(AudiobookRequest.asin) == asin)
)
else:
session.execute(
delete(AudiobookRequest).where(
(col(AudiobookRequest.asin) == asin)
& (col(AudiobookRequest.user_username) == user.username)
)
)
session.commit()
return Response(status_code=204)
@router.patch("/{asin}/downloaded")
async def mark_downloaded(
asin: str,
session: Annotated[Session, Depends(get_session)],
background_task: BackgroundTasks,
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
book = session.exec(select(Audiobook).where(Audiobook.asin == asin)).first()
if book:
book.downloaded = True
session.add(book)
session.commit()
background_task.add_task(
send_all_notifications,
event_type=EventEnum.on_successful_download,
requester=None,
book_asin=asin,
)
return Response(status_code=204)
raise HTTPException(status_code=404, detail="Book not found")
@router.get("/manual", response_model=list[ManualBookRequest])
async def list_manual_requests(
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
return session.exec(
select(ManualBookRequest)
.where(
user.is_admin() or ManualBookRequest.user_username == user.username,
col(ManualBookRequest.user_username).is_not(None),
)
.order_by(asc(ManualBookRequest.downloaded))
).all()
class ManualRequest(BaseModel):
title: str
author: str
narrator: str | None = None
subtitle: str | None = None
publish_date: str | None = None
info: str | None = None
@router.post("/manual", status_code=201)
async def create_manual_request(
body: ManualRequest,
session: Annotated[Session, Depends(get_session)],
background_task: BackgroundTasks,
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
book_request = ManualBookRequest(
user_username=user.username,
title=body.title,
authors=body.author.split(","),
narrators=body.narrator.split(",") if body.narrator else [],
subtitle=body.subtitle,
publish_date=body.publish_date,
additional_info=body.info,
)
session.add(book_request)
session.commit()
background_task.add_task(
send_all_manual_notifications,
event_type=EventEnum.on_new_request,
book_request=ManualBookRequest.model_validate(book_request),
)
return Response(status_code=201)
@router.put("/manual/{id}", status_code=204)
async def update_manual_request(
id: uuid.UUID,
body: ManualRequest,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
book_request = session.get(ManualBookRequest, id)
if not book_request:
raise HTTPException(status_code=404, detail="Book request not found")
if not user.is_admin() and book_request.user_username != user.username:
raise HTTPException(status_code=403, detail="Not authorized")
book_request.title = body.title
book_request.subtitle = body.subtitle
book_request.authors = body.author.split(",")
book_request.narrators = body.narrator.split(",") if body.narrator else []
book_request.publish_date = body.publish_date
book_request.additional_info = body.info
session.add(book_request)
session.commit()
return Response(status_code=204)
@router.patch("/manual/{id}/downloaded")
async def mark_manual_downloaded(
id: uuid.UUID,
session: Annotated[Session, Depends(get_session)],
background_task: BackgroundTasks,
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
book_request = session.get(ManualBookRequest, id)
if book_request:
book_request.downloaded = True
session.add(book_request)
session.commit()
background_task.add_task(
send_all_manual_notifications,
event_type=EventEnum.on_successful_download,
book_request=ManualBookRequest.model_validate(book_request),
)
return Response(status_code=204)
raise HTTPException(status_code=404, detail="Request not found")
@router.delete("/manual/{id}")
async def delete_manual_request(
id: uuid.UUID,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
book = session.get(ManualBookRequest, id)
if book:
session.delete(book)
session.commit()
return Response(status_code=204)
raise HTTPException(status_code=404, detail="Request not found")
@router.post(
"/{asin}/refresh",
description="Refresh the sources from prowlarr for a book",
)
async def refresh_source(
asin: str,
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
force_refresh: bool = False,
):
# causes the sources to be placed into cache once they're done
await query_sources(
asin=asin,
session=session,
client_session=client_session,
force_refresh=force_refresh,
requester=User.model_validate(user),
)
return Response(status_code=202)
@router.get("/{asin}/sources", response_model=QueryResult)
async def list_sources(
asin: str,
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
admin_user: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
only_cached: bool = False,
):
try:
prowlarr_config.raise_if_invalid(session)
except ProwlarrMisconfigured:
raise HTTPException(status_code=400, detail="Prowlarr misconfigured")
result = await query_sources(
asin,
session=session,
client_session=client_session,
requester=admin_user,
only_return_if_cached=only_cached,
)
return result
@router.post("/{asin}/download")
async def download_book(
asin: str,
body: DownloadSourceBody,
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
admin_user: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
try:
resp = await start_download(
session=session,
client_session=client_session,
guid=body.guid,
indexer_id=body.indexer_id,
requester=admin_user,
book_asin=asin,
)
except ProwlarrMisconfigured as e:
raise HTTPException(status_code=500, detail=str(e))
if not resp.ok:
raise HTTPException(status_code=500, detail="Failed to start download")
book = session.exec(select(Audiobook).where(Audiobook.asin == asin)).first()
if book:
book.downloaded = True
session.add(book)
session.commit()
return Response(status_code=204)
@router.post("/{asin}/auto-download")
async def start_auto_download_endpoint(
asin: str,
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
user: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.trusted))],
):
try:
await query_sources(
asin=asin,
start_auto_download=True,
session=session,
client_session=client_session,
requester=user,
)
except HTTPException as e:
raise ToastException(e.detail) from None
return Response(status_code=204)
+69
View File
@@ -0,0 +1,69 @@
from typing import Annotated
from aiohttp import ClientSession
from fastapi import APIRouter, Depends, HTTPException, Query, Security
from sqlmodel import Session
from app.internal import book_search
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.book_search import (
audible_region_type,
audible_regions,
clear_old_book_caches,
get_region_from_settings,
list_audible_books,
)
from app.internal.models import AudiobookSearchResult
from app.util.connection import get_connection
from app.util.db import get_session
router = APIRouter(prefix="/search", tags=["Search"])
@router.get("", response_model=list[AudiobookSearchResult])
async def search_books(
client_session: Annotated[ClientSession, Depends(get_connection)],
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
query: Annotated[str | None, Query(alias="q")] = None,
num_results: int = 20,
page: int = 0,
region: audible_region_type | None = None,
):
if region is None:
region = get_region_from_settings()
if audible_regions.get(region) is None:
raise HTTPException(status_code=400, detail="Invalid region")
if query:
clear_old_book_caches(session)
results = await list_audible_books(
session=session,
client_session=client_session,
query=query,
num_results=num_results,
page=page,
audible_region=region,
)
else:
results = []
return [
AudiobookSearchResult(
book=book,
requests=book.requests,
username=user.username,
)
for book in results
]
@router.get("/suggestions", response_model=list[str])
async def search_suggestions(
query: Annotated[str, Query(alias="q")],
_: Annotated[DetailedUser, Security(APIKeyAuth())],
region: audible_region_type | None = None,
):
if region is None:
region = get_region_from_settings()
async with ClientSession() as client_session:
return await book_search.get_search_suggestions(client_session, query, region)
+14
View File
@@ -0,0 +1,14 @@
from fastapi import APIRouter
from app.routers.api.settings.account import router as account_router
from app.routers.api.settings.download import router as download_router
from app.routers.api.settings.notifications import router as notifications_router
from app.routers.api.settings.prowlarr import router as prowlarr_router
from app.routers.api.settings.security import router as security_router
router = APIRouter(prefix="/settings", tags=["Settings"])
router.include_router(account_router)
router.include_router(download_router)
router.include_router(notifications_router)
router.include_router(prowlarr_router)
router.include_router(security_router)
+139
View File
@@ -0,0 +1,139 @@
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Response, Security
from pydantic import BaseModel, Field
from sqlmodel import Session, select
from app.internal.auth.authentication import (
APIKeyAuth,
DetailedUser,
create_api_key,
create_user,
is_correct_password,
raise_for_invalid_password,
)
from app.internal.models import APIKey, APIKeyResponse, User
from app.util.db import get_session
router = APIRouter(prefix="/account", tags=["Account"])
class CreateAPIKeyRequest(BaseModel):
name: str = Field(min_length=1)
class CreateAPIKeyResponse(BaseModel):
name: str
key: str
class ChangePasswordRequest(BaseModel):
old_password: str
new_password: str
confirm_password: str
@router.get("/api-keys", response_model=list[APIKeyResponse])
def list_api_keys(
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
api_keys = session.exec(
select(APIKey).where(APIKey.user_username == user.username)
).all()
return api_keys
@router.post("/api-keys", response_model=CreateAPIKeyResponse)
def create_new_api_key(
body: CreateAPIKeyRequest,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
name = body.name.strip()
same_name_key = session.exec(
select(APIKey).where(
APIKey.user_username == user.username,
APIKey.name == name,
)
).first()
if same_name_key:
raise HTTPException(status_code=400, detail="API key name must be unique")
api_key, private_key = create_api_key(user, name)
session.add(api_key)
session.commit()
return CreateAPIKeyResponse(name=name, key=private_key)
@router.delete("/api-keys/{id}")
def delete_api_key(
id: str,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
try:
uuid_id = uuid.UUID(id)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid UUID")
api_key = session.exec(
select(APIKey).where(
APIKey.user_username == user.username,
APIKey.id == uuid_id,
)
).first()
if not api_key:
raise HTTPException(status_code=404, detail="API key not found")
session.delete(api_key)
session.commit()
return Response(status_code=204)
@router.patch("/api-keys/{id}/toggle")
def toggle_api_key(
id: str,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
try:
uuid_id = uuid.UUID(id)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid UUID")
api_key = session.exec(
select(APIKey).where(
APIKey.user_username == user.username,
APIKey.id == uuid_id,
)
).first()
if not api_key:
raise HTTPException(status_code=404, detail="API key not found")
api_key.enabled = not api_key.enabled
session.add(api_key)
session.commit()
return Response(status_code=204)
@router.put("/password")
def change_password(
body: ChangePasswordRequest,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(APIKeyAuth())],
):
if not is_correct_password(user, body.old_password):
raise HTTPException(status_code=400, detail="Old password is incorrect")
raise_for_invalid_password(session, body.new_password, body.confirm_password)
new_user = create_user(user.username, body.new_password, user.group)
old_user = session.exec(select(User).where(User.username == user.username)).one()
old_user.password = new_user.password
session.add(old_user)
session.commit()
return Response(status_code=204)
+75
View File
@@ -0,0 +1,75 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Security, Response
from pydantic import BaseModel
from sqlmodel import Session
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.models import GroupEnum
from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config
from app.util.db import get_session
router = APIRouter(prefix="/download")
class DownloadSettings(BaseModel):
auto_download: bool
flac_range: QualityRange
m4b_range: QualityRange
mp3_range: QualityRange
unknown_audio_range: QualityRange
unknown_range: QualityRange
min_seeders: int
name_ratio: int
title_ratio: int
indexer_flags: list[IndexerFlag]
@router.get("", response_model=DownloadSettings)
def get_download_settings(
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
return DownloadSettings(
auto_download=quality_config.get_auto_download(session),
flac_range=quality_config.get_range(session, "quality_flac"),
m4b_range=quality_config.get_range(session, "quality_m4b"),
mp3_range=quality_config.get_range(session, "quality_mp3"),
unknown_audio_range=quality_config.get_range(session, "quality_unknown_audio"),
unknown_range=quality_config.get_range(session, "quality_unknown"),
min_seeders=quality_config.get_min_seeders(session),
name_ratio=quality_config.get_name_exists_ratio(session),
title_ratio=quality_config.get_title_exists_ratio(session),
indexer_flags=quality_config.get_indexer_flags(session),
)
class UpdateDownloadSettings(BaseModel):
auto_download: bool
flac_range: QualityRange
m4b_range: QualityRange
mp3_range: QualityRange
unknown_audio_range: QualityRange
unknown_range: QualityRange
min_seeders: int
name_ratio: int
title_ratio: int
@router.patch("", status_code=204)
def update_download_settings(
body: UpdateDownloadSettings,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
quality_config.set_auto_download(session, body.auto_download)
quality_config.set_range(session, "quality_flac", body.flac_range)
quality_config.set_range(session, "quality_m4b", body.m4b_range)
quality_config.set_range(session, "quality_mp3", body.mp3_range)
quality_config.set_range(session, "quality_unknown_audio", body.unknown_audio_range)
quality_config.set_range(session, "quality_unknown", body.unknown_range)
quality_config.set_min_seeders(session, body.min_seeders)
quality_config.set_name_exists_ratio(session, body.name_ratio)
quality_config.set_title_exists_ratio(session, body.title_ratio)
return Response(status_code=204)
+207
View File
@@ -0,0 +1,207 @@
import json
import uuid
from typing import Annotated, cast
from fastapi import APIRouter, Depends, HTTPException, Security, Response
from pydantic import BaseModel, Field
from sqlmodel import Session, select
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.models import (
EventEnum,
GroupEnum,
Notification,
NotificationBodyTypeEnum,
)
from app.internal.notifications import send_notification
from app.util.db import get_session
router = APIRouter(prefix="/notifications", tags=["Notifications"])
class NotificationRequest(BaseModel):
id: uuid.UUID | None = None
name: str = Field(min_length=1)
url: str = Field(min_length=1)
event_type: str
body: str
body_type: NotificationBodyTypeEnum
headers: str = "{}"
@router.get("", response_model=list[Notification])
def list_notifications(
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
return session.exec(select(Notification)).all()
def _validate_headers(headers: str) -> dict[str, str]:
try:
headers_json = json.loads(headers or "{}") # pyright: ignore[reportAny]
if not isinstance(headers_json, dict) or any(
not isinstance(v, str)
for v in cast(dict[str, object], headers_json).values()
):
raise HTTPException(400, "Invalid headers JSON. Not of type object/dict")
headers_json = cast(dict[str, str], headers_json)
return headers_json
except (json.JSONDecodeError, ValueError):
raise HTTPException(400, "Invalid headers JSON")
def _upsert_notification(
name: str,
url: str,
event_type: str,
body: str,
body_type: NotificationBodyTypeEnum,
headers: str,
session: Session,
notification_id: uuid.UUID | None = None,
):
headers_json = _validate_headers(headers)
try:
if body_type == NotificationBodyTypeEnum.json:
json_body = json.loads(body, strict=False) # pyright: ignore[reportAny]
if not isinstance(json_body, dict):
raise HTTPException(422, "Invalid body. Not a JSON object")
body = json.dumps(json_body, indent=2)
except (json.JSONDecodeError, ValueError):
raise HTTPException(422, "Body is invalid JSON")
try:
event_enum = EventEnum(event_type)
except ValueError:
raise HTTPException(400, "Invalid event type")
try:
body_enum = NotificationBodyTypeEnum(body_type)
except ValueError:
raise HTTPException(400, "Invalid notification service type")
if notification_id:
notification = session.get(Notification, notification_id)
if not notification:
raise HTTPException(404, "Notification not found")
notification.name = name
notification.url = url
notification.event = event_enum
notification.body_type = body_enum
notification.body = body
notification.headers = headers_json
notification.enabled = True
else:
notification = Notification(
name=name,
url=url,
event=event_enum,
body_type=body_enum,
body=body,
headers=headers_json,
enabled=True,
)
session.add(notification)
session.commit()
session.refresh(notification)
return notification
@router.post("", response_model=Notification)
def create_notification(
body: NotificationRequest,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
return _upsert_notification(
notification_id=body.id,
name=body.name,
url=body.url,
event_type=body.event_type,
body=body.body,
body_type=body.body_type,
headers=body.headers,
session=session,
)
@router.delete("/{id}")
def delete_notification(
id: uuid.UUID,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
notif = session.get(Notification, id)
if not notif:
raise HTTPException(status_code=404, detail="Notification not found")
session.delete(notif)
session.commit()
return Response(status_code=204)
@router.post("/{id}/test")
async def test_notification_id(
id: uuid.UUID,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
notif = session.get(Notification, id)
if not notif:
raise HTTPException(status_code=404, detail="Notification not found")
try:
await send_notification(session, notif)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return Response(status_code=200)
@router.patch("/{id}/enable")
def toggle_notification(
id: uuid.UUID,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
notif = session.get(Notification, id)
if not notif:
raise HTTPException(status_code=404, detail="Notification not found")
notif.enabled = not notif.enabled
session.add(notif)
session.commit()
session.refresh(notif)
return notif
@router.post("/test")
async def test_notification(
body: NotificationRequest,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
headers_json = _validate_headers(body.headers)
try:
event_enum = EventEnum(body.event_type)
except ValueError:
raise HTTPException(400, "Invalid event type")
try:
await send_notification(
session,
Notification(
name=body.name,
url=body.url,
event=event_enum,
body=body.body,
body_type=body.body_type,
headers=headers_json,
enabled=True,
),
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return Response(status_code=200)
+87
View File
@@ -0,0 +1,87 @@
from typing import Annotated
from aiohttp import ClientSession
from fastapi import APIRouter, Depends, Response, Security
from pydantic import BaseModel
from sqlmodel import Session
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.models import GroupEnum
from app.internal.prowlarr.indexer_categories import indexer_categories
from app.internal.prowlarr.prowlarr import IndexerResponse, get_indexers
from app.internal.prowlarr.util import flush_prowlarr_cache, prowlarr_config
from app.util.connection import get_connection
from app.util.db import get_session
router = APIRouter(prefix="/prowlarr")
class ProwlarrSettings(BaseModel):
base_url: str
api_key: str
selected_categories: list[int]
selected_indexers: list[int]
all_categories: dict[int, str]
indexers: IndexerResponse
@router.get("", response_model=ProwlarrSettings)
async def get_prowlarr_settings(
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
indexers = await get_indexers(session, client_session)
return ProwlarrSettings(
base_url=prowlarr_config.get_base_url(session) or "",
api_key=prowlarr_config.get_api_key(session) or "",
selected_categories=prowlarr_config.get_categories(session),
selected_indexers=prowlarr_config.get_indexers(session),
all_categories=indexer_categories,
indexers=indexers,
)
class UpdateApiKey(BaseModel):
api_key: str
@router.put("/api-key", status_code=204)
def update_prowlarr_api_key(
body: UpdateApiKey,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
prowlarr_config.set_api_key(session, body.api_key)
flush_prowlarr_cache()
return Response(status_code=204)
class UpdateBaseUrl(BaseModel):
base_url: str
@router.put("/base-url", status_code=204)
def update_prowlarr_base_url(
body: UpdateBaseUrl,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
prowlarr_config.set_base_url(session, body.base_url)
flush_prowlarr_cache()
return Response(status_code=204)
class UpdateCategories(BaseModel):
categories: list[int]
@router.put("/categories", status_code=204)
def update_indexer_categories(
body: UpdateCategories,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
prowlarr_config.set_categories(session, body.categories)
flush_prowlarr_cache()
return Response(status_code=204)
+161
View File
@@ -0,0 +1,161 @@
from typing import Annotated, Optional
from aiohttp import ClientSession
from fastapi import APIRouter, Depends, Security, Response, HTTPException
from pydantic import BaseModel
from sqlmodel import Session
from app.internal.auth.authentication import APIKeyAuth, DetailedUser
from app.internal.auth.config import auth_config
from app.internal.auth.login_types import LoginTypeEnum
from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config
from app.internal.env_settings import Settings
from app.internal.models import GroupEnum
from app.util.connection import get_connection
from app.util.db import get_session
from app.util.log import logger
from app.util.time import Minute
router = APIRouter(prefix="/security")
class SecuritySettings(BaseModel):
login_type: LoginTypeEnum
access_token_expiry: int
min_password_length: int
oidc_endpoint: str
oidc_client_secret: str
oidc_client_id: str
oidc_scope: str
oidc_username_claim: str
oidc_group_claim: str
oidc_redirect_https: bool
oidc_logout_url: str
force_login_type: LoginTypeEnum | None
@router.get("", response_model=SecuritySettings)
def get_security_settings(
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
try:
force_login_type = Settings().app.get_force_login_type()
except ValueError as e:
logger.error("Invalid force login type", exc_info=e)
force_login_type = None
return SecuritySettings(
login_type=auth_config.get_login_type(session),
access_token_expiry=auth_config.get_access_token_expiry_minutes(session),
min_password_length=auth_config.get_min_password_length(session),
oidc_endpoint=oidc_config.get(session, "oidc_endpoint", ""),
oidc_client_secret=oidc_config.get(session, "oidc_client_secret", ""),
oidc_client_id=oidc_config.get(session, "oidc_client_id", ""),
oidc_scope=oidc_config.get(session, "oidc_scope", ""),
oidc_username_claim=oidc_config.get(session, "oidc_username_claim", ""),
oidc_group_claim=oidc_config.get(session, "oidc_group_claim", ""),
oidc_redirect_https=oidc_config.get_redirect_https(session),
oidc_logout_url=oidc_config.get(session, "oidc_logout_url", ""),
force_login_type=force_login_type,
)
@router.post("/reset-auth", status_code=204)
def reset_auth_secret(
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
auth_config.reset_auth_secret(session)
return Response(status_code=204)
class UpdateSecuritySettings(BaseModel):
login_type: LoginTypeEnum
access_token_expiry: Optional[int] = None
min_password_length: Optional[int] = None
oidc_endpoint: Optional[str] = None
oidc_client_id: Optional[str] = None
oidc_client_secret: Optional[str] = None
oidc_scope: Optional[str] = None
oidc_username_claim: Optional[str] = None
oidc_group_claim: Optional[str] = None
oidc_redirect_https: Optional[bool] = None
oidc_logout_url: Optional[str] = None
@router.patch("", status_code=204)
async def update_security_settings(
body: UpdateSecuritySettings,
session: Annotated[Session, Depends(get_session)],
client_session: Annotated[ClientSession, Depends(get_connection)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
if (
body.login_type in [LoginTypeEnum.basic, LoginTypeEnum.forms]
and body.min_password_length is not None
):
if body.min_password_length < 1:
raise HTTPException(
status_code=400, detail="Minimum password length can't be 0 or negative"
)
else:
auth_config.set_min_password_length(session, body.min_password_length)
if body.access_token_expiry is not None:
if body.access_token_expiry < 1:
raise HTTPException(
status_code=400, detail="Access token expiry can't be 0 or negative"
)
else:
auth_config.set_access_token_expiry_minutes(
session, Minute(body.access_token_expiry)
)
if body.login_type == LoginTypeEnum.oidc:
if body.oidc_endpoint:
try:
await oidc_config.set_endpoint(
session, client_session, body.oidc_endpoint
)
except InvalidOIDCConfiguration as e:
raise HTTPException(
status_code=400, detail=f"Invalid OIDC endpoint: {e.detail}"
)
if body.oidc_client_id:
oidc_config.set(session, "oidc_client_id", body.oidc_client_id)
if body.oidc_client_secret:
oidc_config.set(session, "oidc_client_secret", body.oidc_client_secret)
if body.oidc_scope:
oidc_config.set(session, "oidc_scope", body.oidc_scope)
if body.oidc_username_claim:
oidc_config.set(session, "oidc_username_claim", body.oidc_username_claim)
if body.oidc_redirect_https:
oidc_config.set(
session,
"oidc_redirect_https",
"true" if body.oidc_redirect_https else "",
)
if body.oidc_logout_url:
oidc_config.set(session, "oidc_logout_url", body.oidc_logout_url)
if body.oidc_group_claim is not None:
oidc_config.set(session, "oidc_group_claim", body.oidc_group_claim)
error_message = await oidc_config.validate(session, client_session)
if error_message:
raise HTTPException(status_code=400, detail=error_message)
try:
force_login_type = Settings().app.get_force_login_type()
except ValueError as e:
logger.error("Invalid force login type", exc_info=e)
force_login_type = None
if force_login_type and body.login_type != force_login_type:
raise HTTPException(
status_code=400,
detail=f"Cannot change login type to '{body.login_type.value}' when force login type is set to '{force_login_type.value}'",
)
auth_config.set_login_type(session, body.login_type)
return Response(status_code=204)
+12 -3
View File
@@ -153,10 +153,10 @@ def create_new_user(
@router.put("/{username}", response_model=UserResponse)
def update_user(
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
session: Annotated[Session, Depends(get_session)],
username: str,
user_data: UserUpdate,
session: Annotated[Session, Depends(get_session)],
_: Annotated[DetailedUser, Security(APIKeyAuth(GroupEnum.admin))],
):
"""
Updates the specified user's password and/or group.
@@ -186,11 +186,20 @@ def update_user(
detail=e.detail,
)
# create user to get pw hash
updated_user = create_user(
username, user_data.password, user.group, extra_data=user.extra_data
username,
user_data.password,
user.group,
extra_data=user.extra_data,
)
user.password = updated_user.password
if user_data.extra_data is not None:
user.extra_data = (
user_data.extra_data.strip() if user_data.extra_data.strip() != "" else None
)
if user_data.group is not None:
user.group = user_data.group
+66 -164
View File
@@ -1,5 +1,3 @@
import aiohttp
from app.internal.models import AudiobookSearchResult
import uuid
from typing import Annotated
@@ -14,37 +12,36 @@ from fastapi import (
Request,
Security,
)
from sqlmodel import Session, col, delete, select
from sqlmodel import Session
from app.internal import book_search
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.book_search import (
audible_region_type,
audible_regions,
clear_old_book_caches,
get_book_by_asin,
get_region_from_settings,
list_audible_books,
)
from app.internal.models import (
AudiobookRequest,
EventEnum,
GroupEnum,
ManualBookRequest,
User,
)
from app.internal.notifications import (
send_all_manual_notifications,
send_all_notifications,
)
from app.internal.prowlarr.prowlarr import prowlarr_config
from app.internal.query import query_sources
from app.internal.prowlarr.util import prowlarr_config
from app.internal.ranking.quality import quality_config
from app.routers.wishlist import get_wishlist_results, get_wishlist_counts
from app.internal.db_queries import get_wishlist_results, get_wishlist_counts
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.db import get_session
from app.util.log import logger
from app.util.templates import template_response
from app.routers.api.search import (
search_books,
search_suggestions as api_search_suggestions,
)
from app.routers.api.requests import (
create_request,
delete_request as api_delete_request,
create_manual_request,
ManualRequest,
update_manual_request,
)
router = APIRouter(prefix="/search")
@@ -60,32 +57,18 @@ async def read_search(
page: int = 0,
region: audible_region_type | None = None,
):
if region is None:
region = get_region_from_settings()
try:
if region is None:
region = get_region_from_settings()
if audible_regions.get(region) is None:
raise HTTPException(status_code=400, detail="Invalid region")
if query:
clear_old_book_caches(session)
results = await list_audible_books(
session=session,
client_session=client_session,
query=query,
num_results=num_results,
page=page,
audible_region=region,
)
else:
results = []
results = [
AudiobookSearchResult(
book=book,
requests=book.requests,
username=user.username,
)
for book in results
]
results = await search_books(
client_session=client_session,
session=session,
user=user,
query=query,
num_results=num_results,
page=page,
region=region,
)
prowlarr_configured = prowlarr_config.is_valid(session)
@@ -117,31 +100,14 @@ async def search_suggestions(
user: Annotated[DetailedUser, Security(ABRAuth())],
region: audible_region_type | None = None,
):
if region is None:
region = get_region_from_settings()
async with ClientSession() as client_session:
suggestions = await book_search.get_search_suggestions(
client_session, query, region
)
return template_response(
"search.html",
request,
user,
{"suggestions": suggestions},
block_name="search_suggestions",
)
async def background_start_query(asin: str, requester: User, auto_download: bool):
with open_session() as session:
async with ClientSession(timeout=aiohttp.ClientTimeout(60)) as client_session:
_ = await query_sources(
asin=asin,
session=session,
client_session=client_session,
start_auto_download=auto_download,
requester=requester,
)
suggestions = await api_search_suggestions(query, user, region)
return template_response(
"search.html",
request,
user,
{"suggestions": suggestions},
block_name="search_suggestions",
)
@router.post("/request/{asin}")
@@ -157,69 +123,33 @@ async def add_request(
user: Annotated[DetailedUser, Security(ABRAuth())],
num_results: Annotated[int, Form()] = 20,
):
book = await get_book_by_asin(client_session, asin, region)
if not book:
raise HTTPException(status_code=404, detail="Book not found")
if not session.exec(
select(AudiobookRequest).where(
AudiobookRequest.asin == asin,
AudiobookRequest.user_username == user.username,
)
).first():
book_request = AudiobookRequest(asin=asin, user_username=user.username)
session.add(book_request)
session.commit()
logger.info(
"Added new audiobook request",
username=user.username,
try:
await create_request(
asin=asin,
)
else:
logger.warning(
"User has already requested this book",
username=user.username,
asin=asin,
)
background_task.add_task(
send_all_notifications,
event_type=EventEnum.on_new_request,
requester=User.model_validate(user),
book_asin=asin,
)
if quality_config.get_auto_download(session) and user.is_above(GroupEnum.trusted):
# start querying and downloading if auto download is enabled
background_task.add_task(
background_start_query,
asin=asin,
requester=User.model_validate(user),
auto_download=True,
)
if audible_regions.get(region) is None:
raise HTTPException(status_code=400, detail="Invalid region")
if query:
results = await list_audible_books(
session=session,
client_session=client_session,
background_task=background_task,
user=user,
region=region,
)
except HTTPException as e:
logger.warning(
e.detail,
username=user.username,
asin=asin,
)
results = []
if query:
results = await search_books(
client_session=client_session,
session=session,
user=user,
query=query,
num_results=num_results,
page=page,
audible_region=region,
region=region,
)
else:
results = []
results = [
AudiobookSearchResult(
book=book,
requests=book.requests,
username=user.username,
)
for book in results
]
prowlarr_configured = prowlarr_config.is_valid(session)
@@ -249,19 +179,7 @@ async def delete_request(
user: Annotated[DetailedUser, Security(ABRAuth())],
downloaded: bool | None = None,
):
if user.is_admin():
session.execute(
delete(AudiobookRequest).where(col(AudiobookRequest.asin) == asin)
)
session.commit()
else:
session.execute(
delete(AudiobookRequest).where(
(col(AudiobookRequest.asin) == asin)
& (col(AudiobookRequest.user_username) == user.username)
)
)
session.commit()
await api_delete_request(asin, session, user)
results = get_wishlist_results(
session,
@@ -315,34 +233,18 @@ async def add_manual(
info: Annotated[str | None, Form()] = None,
id: uuid.UUID | None = None,
):
if id:
book_request = session.get(ManualBookRequest, id)
if not book_request:
raise HTTPException(status_code=404, detail="Book request not found")
book_request.title = title
book_request.subtitle = subtitle
book_request.authors = author.split(",")
book_request.narrators = narrator.split(",") if narrator else []
book_request.publish_date = publish_date
book_request.additional_info = info
else:
book_request = ManualBookRequest(
user_username=user.username,
title=title,
authors=author.split(","),
narrators=narrator.split(",") if narrator else [],
subtitle=subtitle,
publish_date=publish_date,
additional_info=info,
)
session.add(book_request)
session.commit()
background_task.add_task(
send_all_manual_notifications,
event_type=EventEnum.on_new_request,
book_request=ManualBookRequest.model_validate(book_request),
req_body = ManualRequest(
title=title,
author=author,
narrator=narrator,
subtitle=subtitle,
publish_date=publish_date,
info=info,
)
if id:
await update_manual_request(id, req_body, session, user)
else:
await create_manual_request(req_body, session, background_task, user)
auto_download = quality_config.get_auto_download(session)
+36 -55
View File
@@ -4,18 +4,20 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Security
from sqlmodel import Session, select
from app.internal.auth.authentication import (
ABRAuth,
DetailedUser,
create_api_key,
create_user,
is_correct_password,
raise_for_invalid_password,
)
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.models import (
APIKey,
User,
)
from app.routers.api.settings.account import (
ChangePasswordRequest,
CreateAPIKeyRequest,
)
from app.routers.api.settings.account import change_password as api_change_password
from app.routers.api.settings.account import (
create_new_api_key as api_create_new_api_key,
)
from app.routers.api.settings.account import delete_api_key as api_delete_api_key
from app.routers.api.settings.account import toggle_api_key as api_toggle_api_key
from app.util.db import get_session
from app.util.templates import template_response
from app.util.toast import ToastException
@@ -49,19 +51,19 @@ def change_password(
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(ABRAuth())],
):
if not is_correct_password(user, old_password):
raise ToastException("Old password is incorrect", "error")
try:
raise_for_invalid_password(session, password, confirm_password)
api_change_password(
ChangePasswordRequest(
old_password=old_password,
new_password=password,
confirm_password=confirm_password,
),
session,
user,
)
except HTTPException as e:
raise ToastException(e.detail, "error")
new_user = create_user(user.username, password, user.group)
old_user = session.exec(select(User).where(User.username == user.username)).one()
old_user.password = new_user.password
session.add(old_user)
session.commit()
return template_response(
"settings_page/account.html",
request,
@@ -81,18 +83,11 @@ def create_new_api_key(
if not name.strip():
raise ToastException("API key name cannot be empty", "error")
api_key, private_key = create_api_key(user, name.strip())
same_name_key = session.exec(
select(APIKey).where(
APIKey.user_username == user.username, APIKey.name == name.strip()
)
).first()
if same_name_key:
raise ToastException("API key name must be unique", "error")
session.add(api_key)
session.commit()
try:
resp = api_create_new_api_key(CreateAPIKeyRequest(name=name), session, user)
private_key = resp.key
except HTTPException as e:
raise ToastException(e.detail, "error")
api_keys = session.exec(
select(APIKey).where(APIKey.user_username == user.username)
@@ -120,17 +115,10 @@ def delete_api_key(
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(ABRAuth())],
):
api_key = session.exec(
select(APIKey).where(
APIKey.id == api_key_id, APIKey.user_username == user.username
)
).first()
if not api_key:
raise ToastException("API key not found", "error", cause_refresh=True)
session.delete(api_key)
session.commit()
try:
api_delete_api_key(str(api_key_id), session, user)
except HTTPException as e:
raise ToastException(e.detail, "error")
api_keys = session.exec(
select(APIKey).where(APIKey.user_username == user.username)
@@ -155,23 +143,16 @@ def toggle_api_key(
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(ABRAuth())],
):
api_key = session.exec(
select(APIKey).where(
APIKey.id == api_key_id,
APIKey.user_username == user.username,
)
).first()
if not api_key:
raise ToastException("API key not found", "error")
api_key.enabled = not api_key.enabled
session.add(api_key)
session.commit()
try:
api_toggle_api_key(str(api_key_id), session, user)
except HTTPException as e:
raise ToastException(e.detail, "error")
api_keys = session.exec(
select(APIKey).where(APIKey.user_username == user.username)
).all()
enabled = next((k.enabled for k in api_keys if k.id == api_key_id), False)
return template_response(
"settings_page/account.html",
request,
@@ -179,7 +160,7 @@ def toggle_api_key(
{
"page": "account",
"api_keys": api_keys,
"success": f"API key {'enabled' if api_key.enabled else 'disabled'}",
"success": f"API key {'enabled' if enabled else 'disabled'}",
},
block_name="api_keys",
)
+21 -9
View File
@@ -6,6 +6,12 @@ from sqlmodel import Session
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.models import GroupEnum
from app.internal.ranking.quality import IndexerFlag, QualityRange, quality_config
from app.routers.api.settings.download import (
UpdateDownloadSettings,
)
from app.routers.api.settings.download import (
update_download_settings as api_update_download_settings,
)
from app.util.db import get_session
from app.util.templates import template_response
@@ -77,15 +83,21 @@ def update_download(
)
unknown = QualityRange(from_kbits=unknown_from, to_kbits=unknown_to)
quality_config.set_auto_download(session, auto_download)
quality_config.set_range(session, "quality_flac", flac)
quality_config.set_range(session, "quality_m4b", m4b)
quality_config.set_range(session, "quality_mp3", mp3)
quality_config.set_range(session, "quality_unknown_audio", unknown_audio)
quality_config.set_range(session, "quality_unknown", unknown)
quality_config.set_min_seeders(session, min_seeders)
quality_config.set_name_exists_ratio(session, name_ratio)
quality_config.set_title_exists_ratio(session, title_ratio)
api_update_download_settings(
UpdateDownloadSettings(
auto_download=auto_download,
flac_range=flac,
m4b_range=m4b,
mp3_range=mp3,
unknown_audio_range=unknown_audio,
unknown_range=unknown,
min_seeders=min_seeders,
name_ratio=name_ratio,
title_ratio=title_ratio,
),
session,
admin_user,
)
return template_response(
"settings_page/download.html",
+6 -61
View File
@@ -11,13 +11,14 @@ from sqlmodel import Session
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.indexers.abstract import SessionContainer
from app.internal.indexers.configuration import indexer_configuration_cache
from app.internal.indexers.indexer_util import IndexerContext, get_indexer_contexts
from app.internal.indexers.indexer_util import (
get_indexer_contexts,
update_single_indexer,
)
from app.internal.models import GroupEnum
from app.internal.prowlarr.prowlarr import flush_prowlarr_cache
from app.util.cache import StringConfigCache
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.json_type import get_bool
from app.util.db import get_session
from app.util.log import logger
from app.util.templates import template_response
from app.util.toast import ToastException
@@ -28,7 +29,7 @@ last_modified = 0
async def check_indexer_file_changes():
with open_session() as session:
with next(get_session()) as session:
async with ClientSession() as client_session:
try:
await read_indexer_file(session, client_session)
@@ -49,62 +50,6 @@ async def lifespan(app: FastAPI):
router = APIRouter(prefix="/indexers", lifespan=lifespan)
async def update_single_indexer(
indexer_select: str,
values: Mapping[str, object],
session: Session,
client_session: ClientSession,
ignore_missing_booleans: bool = False,
):
"""
Update a single indexer with the given values.
`ignore_missing_booleans` can be set to true to ignore missing boolean values. By default, missing booleans are treated as false.
"""
session_container = SessionContainer(session=session, client_session=client_session)
contexts = await get_indexer_contexts(
session_container, check_required=False, return_disabled=True
)
updated_context: IndexerContext | None = None
for context in contexts:
if context.indexer.name == indexer_select:
updated_context = context
break
if not updated_context:
raise ValueError("Indexer not found")
for key, context in updated_context.configuration.items():
value = values.get(key)
if value is None:
# forms do not include false checkboxes, so we handle missing booleans as false
if context.type_ is bool and not ignore_missing_booleans:
value = False
else:
logger.warning("Value is missing for key", key=key)
continue
if context.type_ is bool:
indexer_configuration_cache.set_bool(session, key, value == "on")
else:
indexer_configuration_cache.set(session, key, str(value))
if "enabled" in values and (
isinstance(e := values["enabled"], str)
or isinstance(e, bool)
or isinstance(e, int)
):
logger.debug("Setting enabled state", enabled=values["enabled"])
enabled = get_bool(e) or False
await updated_context.indexer.set_enabled(
session_container,
enabled,
)
flush_prowlarr_cache()
async def read_indexer_file(
session: Session, client_session: ClientSession, *, file_path: str | None = None
):
+66 -123
View File
@@ -1,19 +1,31 @@
import json
import uuid
from typing import Annotated, cast
from typing import Annotated
from aiohttp import ClientResponseError
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, Security
from sqlmodel import Session, select
from sqlmodel import Session
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.models import (
EventEnum,
GroupEnum,
Notification,
NotificationBodyTypeEnum,
)
from app.internal.notifications import send_notification
from app.routers.api.settings.notifications import (
NotificationRequest,
list_notifications,
)
from app.routers.api.settings.notifications import (
create_notification as api_create_notification,
)
from app.routers.api.settings.notifications import (
delete_notification as api_delete_notification,
)
from app.routers.api.settings.notifications import (
test_notification_id as api_test_notification_id,
)
from app.routers.api.settings.notifications import (
toggle_notification as api_toggle_notification,
)
from app.util.db import get_session
from app.util.templates import template_response
from app.util.toast import ToastException
@@ -27,7 +39,7 @@ def read_notifications(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
notifications = session.exec(select(Notification)).all()
notifications = list_notifications(session, admin_user)
event_types = [e.value for e in EventEnum]
body_types = [e.value for e in NotificationBodyTypeEnum]
return template_response(
@@ -44,7 +56,7 @@ def read_notifications(
def _list_notifications(request: Request, session: Session, admin_user: DetailedUser):
notifications = session.exec(select(Notification)).all()
notifications = list_notifications(session, admin_user)
event_types = [e.value for e in EventEnum]
body_types = [e.value for e in NotificationBodyTypeEnum]
return template_response(
@@ -61,78 +73,6 @@ def _list_notifications(request: Request, session: Session, admin_user: Detailed
)
def _upsert_notification(
request: Request,
*,
name: str,
url: str,
event_type: str,
body: str,
body_type: NotificationBodyTypeEnum,
headers: str,
admin_user: DetailedUser,
session: Session,
notification_id: uuid.UUID | None = None,
):
try:
headers_json = json.loads(headers or "{}") # pyright: ignore[reportAny]
if not isinstance(headers_json, dict) or any(
not isinstance(v, str)
for v in cast(dict[str, object], headers_json).values()
):
raise ToastException(
"Invalid headers JSON. Not of type object/dict", "error"
)
headers_json = cast(dict[str, str], headers_json)
except (json.JSONDecodeError, ValueError):
raise ToastException("Invalid headers JSON", "error")
try:
if body_type == NotificationBodyTypeEnum.json:
json_body = json.loads(body, strict=False) # pyright: ignore[reportAny]
if not isinstance(json_body, dict):
raise ToastException("Invalid body. Not a JSON object", "error")
body = json.dumps(json_body, indent=2)
except (json.JSONDecodeError, ValueError):
raise ToastException("Body is invalid JSON", "error")
try:
event_enum = EventEnum(event_type)
except ValueError:
raise ToastException("Invalid event type", "error")
try:
body_enum = NotificationBodyTypeEnum(body_type)
except ValueError:
raise ToastException("Invalid notification service type", "error")
if notification_id:
notification = session.get(Notification, notification_id)
if not notification:
raise ToastException("Notification not found", "error")
notification.name = name
notification.url = url
notification.event = event_enum
notification.body_type = body_enum
notification.body = body
notification.headers = headers_json
notification.enabled = True
else:
notification = Notification(
name=name,
url=url,
event=event_enum,
body_type=body_enum,
body=body,
headers=headers_json,
enabled=True,
)
session.add(notification)
session.commit()
return _list_notifications(request, session, admin_user)
@router.post("")
def add_notification(
request: Request,
@@ -140,22 +80,28 @@ def add_notification(
url: Annotated[str, Form()],
event_type: Annotated[str, Form()],
body_type: Annotated[NotificationBodyTypeEnum, Form()],
headers: Annotated[str, Form()],
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
headers: Annotated[str, Form()] = "{}",
body: Annotated[str, Form()] = "{}",
):
return _upsert_notification(
request=request,
name=name,
url=url,
event_type=event_type,
body=body,
body_type=body_type,
headers=headers,
admin_user=admin_user,
session=session,
)
try:
api_create_notification(
NotificationRequest(
id=None,
name=name,
url=url,
event_type=event_type,
body=body,
body_type=body_type,
headers=headers,
),
session,
admin_user,
)
except HTTPException as e:
raise ToastException(e.detail, "error")
return _list_notifications(request, session, admin_user)
@router.put("/{notification_id}")
@@ -171,18 +117,23 @@ def update_notification(
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
body: Annotated[str, Form()] = "{}",
):
return _upsert_notification(
notification_id=notification_id,
request=request,
name=name,
url=url,
event_type=event_type,
body=body,
body_type=body_type,
headers=headers,
admin_user=admin_user,
session=session,
)
try:
api_create_notification(
NotificationRequest(
id=notification_id,
name=name,
url=url,
event_type=event_type,
body=body,
body_type=body_type,
headers=headers,
),
session,
admin_user,
)
except HTTPException as e:
raise ToastException(e.detail, "error")
return _list_notifications(request, session, admin_user)
@router.patch("/{notification_id}/enable")
@@ -192,12 +143,10 @@ def toggle_notification(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
notification = session.get_one(Notification, notification_id)
if not notification:
raise ToastException("Notification not found", "error")
notification.enabled = not notification.enabled
session.add(notification)
session.commit()
try:
api_toggle_notification(notification_id, session, admin_user)
except HTTPException as e:
raise ToastException(e.detail, "error")
return _list_notifications(request, session, admin_user)
@@ -209,11 +158,10 @@ def delete_notification(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
notification = session.get_one(Notification, notification_id)
if not notification:
raise ToastException("Notification not found", "error")
session.delete(notification)
session.commit()
try:
api_delete_notification(notification_id, session, admin_user)
except HTTPException as e:
raise ToastException(e.detail, "error")
return _list_notifications(request, session, admin_user)
@@ -224,14 +172,9 @@ async def test_notification(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
_ = admin_user
notification = session.get(Notification, notification_id)
if not notification:
raise HTTPException(status_code=404, detail="Notification not found")
try:
await send_notification(session, notification)
except ClientResponseError:
raise HTTPException(status_code=500, detail="Failed to send notification")
await api_test_notification_id(notification_id, session, admin_user)
except HTTPException as e:
raise ToastException(e.detail, "error") from None
return Response(status_code=204)
+22 -12
View File
@@ -7,10 +7,21 @@ from sqlmodel import Session
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.models import GroupEnum
from app.internal.prowlarr.indexer_categories import indexer_categories
from app.internal.prowlarr.prowlarr import (
flush_prowlarr_cache,
get_indexers,
prowlarr_config,
from app.internal.prowlarr.prowlarr import get_indexers
from app.internal.prowlarr.util import flush_prowlarr_cache, prowlarr_config
from app.routers.api.settings.prowlarr import (
UpdateApiKey,
UpdateBaseUrl,
UpdateCategories,
)
from app.routers.api.settings.prowlarr import (
update_indexer_categories as api_update_indexer_categories,
)
from app.routers.api.settings.prowlarr import (
update_prowlarr_api_key as api_update_prowlarr_api_key,
)
from app.routers.api.settings.prowlarr import (
update_prowlarr_base_url as api_update_prowlarr_base_url,
)
from app.util.connection import get_connection
from app.util.db import get_session
@@ -56,9 +67,7 @@ def update_prowlarr_api_key(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
_ = admin_user
prowlarr_config.set_api_key(session, api_key)
flush_prowlarr_cache()
api_update_prowlarr_api_key(UpdateApiKey(api_key=api_key), session, admin_user)
return Response(status_code=204, headers={"HX-Refresh": "true"})
@@ -68,9 +77,7 @@ def update_prowlarr_base_url(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
_ = admin_user
prowlarr_config.set_base_url(session, base_url)
flush_prowlarr_cache()
api_update_prowlarr_base_url(UpdateBaseUrl(base_url=base_url), session, admin_user)
return Response(status_code=204, headers={"HX-Refresh": "true"})
@@ -83,9 +90,12 @@ def update_indexer_categories(
):
if categories is None:
categories = []
prowlarr_config.set_categories(session, categories)
api_update_indexer_categories(
UpdateCategories(categories=categories), session, admin_user
)
selected = set(categories)
flush_prowlarr_cache()
return template_response(
"settings_page/prowlarr.html",
+29 -58
View File
@@ -1,21 +1,25 @@
from typing import Annotated
from aiohttp import ClientSession
from fastapi import APIRouter, Depends, Form, Request, Response, Security
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, Security
from sqlmodel import Session
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.auth.config import auth_config
from app.internal.auth.login_types import LoginTypeEnum
from app.internal.auth.oidc_config import InvalidOIDCConfiguration, oidc_config
from app.internal.auth.oidc_config import oidc_config
from app.internal.env_settings import Settings
from app.internal.models import GroupEnum
from app.util.connection import get_connection
from app.util.db import get_session
from app.util.log import logger
from app.util.templates import template_response
from app.util.time import Minute
from app.util.toast import ToastException
from app.routers.api.settings.security import (
reset_auth_secret as api_reset_auth_secret,
update_security_settings as api_update_security_settings,
UpdateSecuritySettings,
)
router = APIRouter(prefix="/security")
@@ -59,8 +63,7 @@ def reset_auth_secret(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
_ = admin_user
auth_config.reset_auth_secret(session)
api_reset_auth_secret(session, admin_user)
return Response(status_code=204, headers={"HX-Refresh": "true"})
@@ -82,67 +85,35 @@ async def update_security(
oidc_redirect_https: Annotated[bool | None, Form()] = None,
oidc_logout_url: Annotated[str | None, Form()] = None,
):
if (
login_type in [LoginTypeEnum.basic, LoginTypeEnum.forms]
and min_password_length is not None
):
if min_password_length < 1:
raise ToastException(
"Minimum password length can't be 0 or negative", "error"
)
else:
auth_config.set_min_password_length(session, min_password_length)
if access_token_expiry is not None:
if access_token_expiry < 1:
raise ToastException("Access token expiry can't be 0 or negative", "error")
else:
auth_config.set_access_token_expiry_minutes(
session, Minute(access_token_expiry)
)
if login_type == LoginTypeEnum.oidc:
if oidc_endpoint:
try:
await oidc_config.set_endpoint(session, client_session, oidc_endpoint)
except InvalidOIDCConfiguration as e:
raise ToastException(f"Invalid OIDC endpoint: {e.detail}", "error")
if oidc_client_id:
oidc_config.set(session, "oidc_client_id", oidc_client_id)
if oidc_client_secret:
oidc_config.set(session, "oidc_client_secret", oidc_client_secret)
if oidc_scope:
oidc_config.set(session, "oidc_scope", oidc_scope)
if oidc_username_claim:
oidc_config.set(session, "oidc_username_claim", oidc_username_claim)
if oidc_redirect_https:
oidc_config.set(
session,
"oidc_redirect_https",
"true" if oidc_redirect_https else "",
)
if oidc_logout_url:
oidc_config.set(session, "oidc_logout_url", oidc_logout_url)
if oidc_group_claim is not None:
oidc_config.set(session, "oidc_group_claim", oidc_group_claim)
error_message = await oidc_config.validate(session, client_session)
if error_message:
raise ToastException(error_message, "error")
try:
await api_update_security_settings(
UpdateSecuritySettings(
login_type=login_type,
access_token_expiry=access_token_expiry,
min_password_length=min_password_length,
oidc_endpoint=oidc_endpoint,
oidc_client_id=oidc_client_id,
oidc_client_secret=oidc_client_secret,
oidc_scope=oidc_scope,
oidc_username_claim=oidc_username_claim,
oidc_group_claim=oidc_group_claim,
oidc_redirect_https=oidc_redirect_https,
oidc_logout_url=oidc_logout_url,
),
session,
client_session,
admin_user,
)
except HTTPException as e:
raise ToastException(e.detail, "error") from None
try:
force_login_type = Settings().app.get_force_login_type()
except ValueError as e:
logger.error("Invalid force login type", exc_info=e)
force_login_type = None
if force_login_type and login_type != force_login_type:
raise ToastException(
f"Cannot change login type to '{login_type.value}' when force login type is set to '{force_login_type.value}'",
"error",
)
old = auth_config.get_login_type(session)
auth_config.set_login_type(session, login_type)
return template_response(
"settings_page/security.html",
+45 -47
View File
@@ -3,18 +3,20 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Security
from sqlmodel import Session, select
from app.internal.auth.authentication import (
ABRAuth,
DetailedUser,
create_user,
raise_for_invalid_password,
)
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.auth.config import auth_config
from app.internal.auth.login_types import LoginTypeEnum
from app.internal.models import GroupEnum, User
from app.util.db import get_session
from app.util.templates import template_response
from app.util.toast import ToastException
from app.routers.api.users import (
UserUpdate,
create_new_user as api_create_new_user,
delete_user as api_delete_user,
update_user as api_update_user,
UserCreate,
)
router = APIRouter(prefix="/users")
@@ -51,23 +53,23 @@ def create_new_user(
if username.strip() == "":
raise ToastException("Invalid username", "error")
try:
raise_for_invalid_password(session, password, ignore_confirm=True)
except HTTPException as e:
raise ToastException(e.detail, "error")
if group not in GroupEnum.__members__:
raise ToastException("Invalid group selected", "error")
group = GroupEnum[group]
user = session.exec(select(User).where(User.username == username)).first()
if user:
raise ToastException("Username already exists", "error")
user = create_user(username, password, group)
session.add(user)
session.commit()
try:
api_create_new_user(
UserCreate(
username=username,
password=password,
group=GroupEnum[group],
root=False,
extra_data=None,
),
session,
admin_user,
)
except HTTPException as e:
raise ToastException(e.detail, "error")
users = session.exec(select(User)).all()
@@ -87,16 +89,10 @@ def delete_user(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
if username == admin_user.username:
raise ToastException("Cannot delete own user", "error")
user = session.exec(select(User).where(User.username == username)).one_or_none()
if user and user.root:
raise ToastException("Cannot delete root user", "error")
if user:
session.delete(user)
session.commit()
try:
api_delete_user(username, session, admin_user)
except HTTPException as e:
raise ToastException(e.detail, "error")
users = session.exec(select(User)).all()
@@ -118,26 +114,28 @@ def update_user(
group: Annotated[GroupEnum | None, Form()] = None,
extra_data: Annotated[str | None, Form()] = None,
):
updated: list[str] = []
user = session.exec(select(User).where(User.username == username)).one_or_none()
if user:
if extra_data is not None:
updated.append("extra data")
user.extra_data = extra_data.strip() if extra_data.strip() != "" else None
if group is not None:
if user.root:
raise ToastException("Cannot change root user's group", "error")
user.group = group
updated.append("group")
session.add(user)
session.commit()
try:
api_update_user(
admin_user,
session=session,
username=username,
user_data=UserUpdate(
password=None,
group=group,
extra_data=extra_data,
),
)
except HTTPException as e:
raise ToastException(e.detail, "error")
if not updated:
if group is None and extra_data is None:
success_msg = "No changes made"
elif updated == ["extra data"]:
success_msg = "Updated user extra data"
elif updated == ["group"]:
elif group is not None and extra_data is not None:
success_msg = "Updated group and extra data"
elif group is not None:
success_msg = "Updated group"
elif extra_data is not None:
success_msg = "Updated extra data"
else:
success_msg = "Updated user"
+38 -218
View File
@@ -1,10 +1,5 @@
import aiohttp
from app.util.toast import ToastException
from sqlalchemy.orm import InstrumentedAttribute, selectinload
from app.internal.models import AudiobookWishlistResult
from app.internal.models import Audiobook
import uuid
from typing import Annotated, Literal, cast
from typing import Annotated
from aiohttp import ClientSession
from fastapi import (
@@ -14,129 +9,35 @@ from fastapi import (
Form,
HTTPException,
Request,
Response,
Security,
)
from pydantic import BaseModel
from sqlalchemy import func
from sqlmodel import Session, asc, col, not_, select
from sqlmodel import Session
from app.internal.auth.authentication import ABRAuth, DetailedUser
from app.internal.models import (
AudiobookRequest,
EventEnum,
GroupEnum,
ManualBookRequest,
User,
from app.internal.db_queries import (
get_all_manual_requests,
get_wishlist_counts,
get_wishlist_results,
)
from app.internal.notifications import (
send_all_manual_notifications,
send_all_notifications,
from app.internal.models import GroupEnum
from app.routers.api.requests import (
DownloadSourceBody,
delete_manual_request,
mark_manual_downloaded,
start_auto_download_endpoint,
)
from app.internal.prowlarr.prowlarr import (
ProwlarrMisconfigured,
prowlarr_config,
start_download,
)
from app.internal.query import query_sources
from app.routers.api.requests import download_book as api_download_book
from app.routers.api.requests import list_sources as api_list_sources
from app.routers.api.requests import mark_downloaded as api_mark_downloaded
from app.util.connection import get_connection
from app.util.db import get_session, open_session
from app.util.db import get_session
from app.util.redirect import BaseUrlRedirectResponse
from app.util.templates import template_response
from app.util.toast import ToastException
router = APIRouter(prefix="/wishlist")
class WishlistCounts(BaseModel):
requests: int
downloaded: int
manual: int
def get_wishlist_counts(session: Session, user: User | None = None) -> WishlistCounts:
"""
If a non-admin user is given, only count requests for that user.
Admins can see and get counts for all requests.
"""
username = None if user is None or user.is_admin() else user.username
rows = session.exec(
select(Audiobook.downloaded, func.count("*"))
.where(not username or AudiobookRequest.user_username == username)
.select_from(Audiobook)
.join(AudiobookRequest)
.group_by(col(Audiobook.downloaded))
).all()
requests = 0
downloaded = 0
for downloaded_status, count in rows:
if downloaded_status:
downloaded = count
else:
requests = count
manual = session.exec(
select(func.count())
.select_from(ManualBookRequest)
.where(
not username or ManualBookRequest.user_username == username,
col(ManualBookRequest.user_username).is_not(None),
)
).one()
return WishlistCounts(
requests=requests,
downloaded=downloaded,
manual=manual,
)
def get_wishlist_results(
session: Session,
username: str | None = None,
response_type: Literal["all", "downloaded", "not_downloaded"] = "all",
) -> list[AudiobookWishlistResult]:
"""
Gets the books that have been requested. If a username is given only the books requested by that
user are returned. If no username is given, all book requests are returned.
"""
match response_type:
case "downloaded":
clause = Audiobook.downloaded
case "not_downloaded":
clause = not_(Audiobook.downloaded)
case _:
clause = True
results = session.exec(
select(Audiobook)
.where(
clause,
col(Audiobook.asin).in_(
select(AudiobookRequest.asin).where(
not username or AudiobookRequest.user_username == username
)
),
)
.options(
selectinload(
cast(
InstrumentedAttribute[list[AudiobookRequest]],
cast(object, Audiobook.requests),
)
)
)
).all()
return [
AudiobookWishlistResult(
book=book,
requests=book.requests,
)
for book in results
]
@router.get("")
async def wishlist(
request: Request,
@@ -179,18 +80,7 @@ async def update_downloaded(
background_task: BackgroundTasks,
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
book = session.exec(select(Audiobook).where(Audiobook.asin == asin)).first()
if book:
book.downloaded = True
session.add(book)
session.commit()
background_task.add_task(
send_all_notifications,
event_type=EventEnum.on_successful_download,
requester=None,
book_asin=asin,
)
await api_mark_downloaded(asin, session, background_task, admin_user)
username = None if admin_user.is_admin() else admin_user.username
results = get_wishlist_results(session, username, "not_downloaded")
@@ -210,24 +100,13 @@ async def update_downloaded(
)
def _get_all_manual_requests(session: Session, user: User):
return session.exec(
select(ManualBookRequest)
.where(
user.is_admin() or ManualBookRequest.user_username == user.username,
col(ManualBookRequest.user_username).is_not(None),
)
.order_by(asc(ManualBookRequest.downloaded))
).all()
@router.get("/manual")
async def manual(
request: Request,
session: Annotated[Session, Depends(get_session)],
user: Annotated[DetailedUser, Security(ABRAuth())],
):
books = _get_all_manual_requests(session, user)
books = get_all_manual_requests(session, user)
counts = get_wishlist_counts(session, user)
return template_response(
"wishlist_page/manual.html",
@@ -245,19 +124,9 @@ async def downloaded_manual(
background_task: BackgroundTasks,
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
book_request = session.get(ManualBookRequest, id)
if book_request:
book_request.downloaded = True
session.add(book_request)
session.commit()
await mark_manual_downloaded(id, session, background_task, admin_user)
background_task.add_task(
send_all_manual_notifications,
event_type=EventEnum.on_successful_download,
book_request=ManualBookRequest.model_validate(book_request),
)
books = _get_all_manual_requests(session, admin_user)
books = get_all_manual_requests(session, admin_user)
counts = get_wishlist_counts(session, admin_user)
return template_response(
@@ -281,12 +150,9 @@ async def delete_manual(
session: Annotated[Session, Depends(get_session)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
book = session.get(ManualBookRequest, id)
if book:
session.delete(book)
session.commit()
await delete_manual_request(id, session, admin_user)
books = _get_all_manual_requests(session, admin_user)
books = get_all_manual_requests(session, admin_user)
counts = get_wishlist_counts(session, admin_user)
return template_response(
@@ -303,27 +169,6 @@ async def delete_manual(
)
@router.post("/refresh/{asin}")
async def refresh_source(
asin: str,
background_task: BackgroundTasks,
user: Annotated[DetailedUser, Security(ABRAuth())],
force_refresh: bool = False,
):
# causes the sources to be placed into cache once they're done
with open_session() as session:
async with ClientSession(timeout=aiohttp.ClientTimeout(30)) as client_session:
background_task.add_task(
query_sources,
asin=asin,
session=session,
client_session=client_session,
force_refresh=force_refresh,
requester=User.model_validate(user),
)
return Response(status_code=202)
@router.get("/sources/{asin}")
async def list_sources(
request: Request,
@@ -334,19 +179,19 @@ async def list_sources(
only_body: bool = False,
):
try:
prowlarr_config.raise_if_invalid(session)
except ProwlarrMisconfigured:
return BaseUrlRedirectResponse(
"/settings/prowlarr?prowlarr_misconfigured=1", status_code=302
result = await api_list_sources(
asin,
session,
client_session,
admin_user,
only_cached=not only_body,
)
result = await query_sources(
asin,
session=session,
client_session=client_session,
requester=admin_user,
only_return_if_cached=not only_body, # on initial load we want to respond quickly
)
except HTTPException as e:
if e.detail == "Prowlarr misconfigured":
return BaseUrlRedirectResponse(
"/settings/prowlarr?prowlarr_misconfigured=1", status_code=302
)
raise e
if only_body:
return template_response(
@@ -373,27 +218,8 @@ async def download_book(
client_session: Annotated[ClientSession, Depends(get_connection)],
admin_user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.admin))],
):
try:
resp = await start_download(
session=session,
client_session=client_session,
guid=guid,
indexer_id=indexer_id,
requester=admin_user,
book_asin=asin,
)
except ProwlarrMisconfigured as e:
raise HTTPException(status_code=500, detail=str(e))
if not resp.ok:
raise HTTPException(status_code=500, detail="Failed to start download")
book = session.exec(select(Audiobook).where(Audiobook.asin == asin)).first()
if book:
book.downloaded = True
session.add(book)
session.commit()
return Response(status_code=204)
body = DownloadSourceBody(guid=guid, indexer_id=indexer_id)
return await api_download_book(asin, body, session, client_session, admin_user)
@router.post("/auto-download/{asin}")
@@ -405,13 +231,7 @@ async def start_auto_download(
user: Annotated[DetailedUser, Security(ABRAuth(GroupEnum.trusted))],
):
try:
await query_sources(
asin=asin,
start_auto_download=True,
session=session,
client_session=client_session,
requester=user,
)
await start_auto_download_endpoint(asin, session, client_session, user)
except HTTPException as e:
raise ToastException(e.detail) from None
-11
View File
@@ -1,5 +1,3 @@
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlmodel import Session, text
@@ -20,12 +18,3 @@ def get_session():
if not Settings().db.use_postgres:
session.execute(text("PRAGMA foreign_keys=ON"))
yield session
# TODO: couldn't get a single function to work with FastAPI and allow for session creation wherever
@contextmanager
def open_session():
with Session(engine) as session:
if not Settings().db.use_postgres:
session.execute(text("PRAGMA foreign_keys=ON"))
yield session
+2 -1
View File
@@ -1,5 +1,6 @@
# pyright: reportUnknownMemberType=false
import html
from jinja2_htmlmin import minify_loader
from jinja2 import Environment, FileSystemLoader
from typing import Any, Mapping, overload
@@ -30,7 +31,7 @@ def _zfill(val: str | int | float, num: int) -> str:
def _to_js_string(val: str | int | float) -> str:
return f"'{str(val).replace("'", "\\'").replace('\n', '\\n')}'"
return html.escape(f"'{str(val).replace("'", "\\'").replace('\n', '\\n')}'")
templates.env.filters["zfill"] = _zfill
+3
View File
@@ -74,6 +74,9 @@
{% endif %}
{% if api_keys %}
{% if success %}
<script>toast("{{success|safe}}", "success");</script>
{% endif %}
<div class="overflow-x-auto">
<table class="table w-full">
<thead>
-10
View File
@@ -68,7 +68,6 @@
name="flac_from"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ flac_range.from_kbits }}"
/>
<div
@@ -80,7 +79,6 @@
name="flac_to"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ flac_range.to_kbits }}"
/>
<script>
@@ -102,7 +100,6 @@
name="m4b_from"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ m4b_range.from_kbits }}"
/>
<div id="m4b-range" class="w-full slider-no-overlap slider-round"></div>
@@ -111,7 +108,6 @@
name="m4b_to"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ m4b_range.to_kbits }}"
/>
<script>
@@ -133,7 +129,6 @@
name="mp3_from"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ mp3_range.from_kbits }}"
/>
<div id="mp3-range" class="w-full slider-no-overlap slider-round"></div>
@@ -142,7 +137,6 @@
name="mp3_to"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ mp3_range.to_kbits }}"
/>
<script>
@@ -164,7 +158,6 @@
name="unknown_audio_from"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ unknown_audio_range.from_kbits }}"
/>
<div
@@ -176,7 +169,6 @@
name="unknown_audio_to"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ unknown_audio_range.to_kbits }}"
/>
<script>
@@ -198,7 +190,6 @@
name="unknown_from"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ unknown_range.from_kbits }}"
/>
<div
@@ -210,7 +201,6 @@
name="unknown_to"
type="text"
class="border-none w-[4rem] text-center"
readonly
value="{{ unknown_range.to_kbits }}"
/>
<script>