diff --git a/backend-services/doorman.py b/backend-services/doorman.py index d579bc7..ff3f04d 100755 --- a/backend-services/doorman.py +++ b/backend-services/doorman.py @@ -4,51 +4,55 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from datetime import datetime, timedelta +import asyncio +import json +import logging +import multiprocessing +import os +import re +import shutil +import signal +import subprocess +import sys +import time +import uuid +from contextlib import asynccontextmanager +from datetime import timedelta from logging.handlers import RotatingFileHandler +from pathlib import Path + +import uvicorn +from dotenv import load_dotenv from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError -from jose import jwt, JWTError from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from starlette.middleware.base import BaseHTTPMiddleware -from fastapi.responses import Response -from contextlib import asynccontextmanager -from redis.asyncio import Redis +from jose import JWTError from pydantic import BaseSettings -from dotenv import load_dotenv -import multiprocessing -import logging -import json -import re -import os -import sys -import subprocess -import signal -import uvicorn -import time -import asyncio -import uuid -import shutil -from pathlib import Path +from redis.asyncio import Redis +from starlette.middleware.base import BaseHTTPMiddleware try: if sys.version_info >= (3, 13): try: - from importlib.metadata import version, PackageNotFoundError + from importlib.metadata import PackageNotFoundError, version except Exception: version = None PackageNotFoundError = Exception if version is not None: try: v = version('aiohttp') - parts = [int(p) for p in (v.split('.')[:3] + ['0', '0'])[:3] if p.isdigit() or p.isnumeric()] + parts = [ + int(p) + for p in (v.split('.')[:3] + ['0', '0'])[:3] + if p.isdigit() or p.isnumeric() + ] while len(parts) < 3: parts.append(0) if tuple(parts) < (3, 10, 10): raise SystemExit( - f"Incompatible aiohttp {v} detected on Python {sys.version.split()[0]}. " - "Please upgrade to aiohttp>=3.10.10 (pip install -U aiohttp) or run with Python 3.11." + f'Incompatible aiohttp {v} detected on Python {sys.version.split()[0]}. ' + 'Please upgrade to aiohttp>=3.10.10 (pip install -U aiohttp) or run with Python 3.11.' ) except PackageNotFoundError: pass @@ -58,47 +62,58 @@ except Exception: pass from models.response_model import ResponseModel -from utils.cache_manager_util import cache_manager -from utils.auth_blacklist import purge_expired_tokens -from utils.doorman_cache_util import doorman_cache -from utils.hot_reload_config import hot_config -from routes.authorization_routes import authorization_router -from routes.group_routes import group_router -from routes.role_routes import role_router -from routes.subscription_routes import subscription_router -from routes.user_routes import user_router +from routes.analytics_routes import analytics_router from routes.api_routes import api_router +from routes.authorization_routes import authorization_router +from routes.config_hot_reload_routes import config_hot_reload_router +from routes.config_routes import config_router +from routes.credit_routes import credit_router +from routes.dashboard_routes import dashboard_router +from routes.demo_routes import demo_router from routes.endpoint_routes import endpoint_router from routes.gateway_routes import gateway_router -from routes.routing_routes import routing_router -from routes.proto_routes import proto_router +from routes.group_routes import group_router from routes.logging_routes import logging_router -from routes.dashboard_routes import dashboard_router from routes.memory_routes import memory_router -from routes.security_routes import security_router -from routes.credit_routes import credit_router -from routes.demo_routes import demo_router from routes.monitor_routes import monitor_router -from routes.config_routes import config_router -from routes.tools_routes import tools_router -from routes.config_hot_reload_routes import config_hot_reload_router -from routes.vault_routes import vault_router -from routes.analytics_routes import analytics_router -from routes.tier_routes import tier_router -from routes.rate_limit_rule_routes import rate_limit_rule_router +from routes.proto_routes import proto_router from routes.quota_routes import quota_router -from utils.security_settings_util import load_settings, start_auto_save_task, stop_auto_save_task, get_cached_settings -from utils.memory_dump_util import dump_memory_to_file, restore_memory_from_file, find_latest_dump_path -from utils.metrics_util import metrics_store -from utils.database import database -from utils.response_util import process_response +from routes.rate_limit_rule_routes import rate_limit_rule_router +from routes.role_routes import role_router +from routes.routing_routes import routing_router +from routes.security_routes import security_router +from routes.subscription_routes import subscription_router +from routes.tier_routes import tier_router +from routes.tools_routes import tools_router +from routes.user_routes import user_router +from routes.vault_routes import vault_router from utils.audit_util import audit -from utils.ip_policy_util import _get_client_ip as _policy_get_client_ip, _ip_in_list as _policy_ip_in_list, _is_loopback as _policy_is_loopback +from utils.auth_blacklist import purge_expired_tokens +from utils.cache_manager_util import cache_manager +from utils.database import database +from utils.hot_reload_config import hot_config +from utils.ip_policy_util import _get_client_ip as _policy_get_client_ip +from utils.ip_policy_util import _ip_in_list as _policy_ip_in_list +from utils.ip_policy_util import _is_loopback as _policy_is_loopback +from utils.memory_dump_util import ( + dump_memory_to_file, + find_latest_dump_path, + restore_memory_from_file, +) +from utils.metrics_util import metrics_store +from utils.response_util import process_response +from utils.security_settings_util import ( + get_cached_settings, + load_settings, + start_auto_save_task, + stop_auto_save_task, +) load_dotenv() PID_FILE = 'doorman.pid' + def _migrate_generated_directory() -> None: """Migrate legacy root-level 'generated/' into backend-services/generated. @@ -114,7 +129,7 @@ def _migrate_generated_directory() -> None: return if not src.exists() or not src.is_dir(): dst.mkdir(exist_ok=True) - gateway_logger.info(f"Generated dir: {dst} (no migration needed)") + gateway_logger.info(f'Generated dir: {dst} (no migration needed)') return dst.mkdir(parents=True, exist_ok=True) moved_count = 0 @@ -137,37 +152,37 @@ def _migrate_generated_directory() -> None: shutil.rmtree(src) except Exception: pass - gateway_logger.info(f"Generated dir migrated: {moved_count} file(s) moved to {dst}") + gateway_logger.info(f'Generated dir migrated: {moved_count} file(s) moved to {dst}') except Exception as e: try: - gateway_logger.warning(f"Generated dir migration skipped: {e}") + gateway_logger.warning(f'Generated dir migration skipped: {e}') except Exception: pass + async def validate_database_connections(): """Validate database connections on startup with retry logic""" - gateway_logger.info("Validating database connections...") + gateway_logger.info('Validating database connections...') max_retries = 3 for attempt in range(max_retries): try: from utils.database import user_collection + await user_collection.find_one({}) - gateway_logger.info("✓ MongoDB connection verified") + gateway_logger.info('✓ MongoDB connection verified') break except Exception as e: if attempt < max_retries - 1: - wait = 2 ** attempt + wait = 2**attempt gateway_logger.warning( - f"MongoDB connection attempt {attempt + 1}/{max_retries} failed: {e}" + f'MongoDB connection attempt {attempt + 1}/{max_retries} failed: {e}' ) - gateway_logger.info(f"Retrying in {wait} seconds...") + gateway_logger.info(f'Retrying in {wait} seconds...') await asyncio.sleep(wait) else: - gateway_logger.error(f"MongoDB connection failed after {max_retries} attempts") - raise RuntimeError( - f"Cannot connect to MongoDB: {e}" - ) from e + gateway_logger.error(f'MongoDB connection failed after {max_retries} attempts') + raise RuntimeError(f'Cannot connect to MongoDB: {e}') from e redis_host = os.getenv('REDIS_HOST') mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM') @@ -176,32 +191,32 @@ async def validate_database_connections(): for attempt in range(max_retries): try: import redis.asyncio as redis - redis_url = f"redis://{redis_host}:{os.getenv('REDIS_PORT', '6379')}" + + redis_url = f'redis://{redis_host}:{os.getenv("REDIS_PORT", "6379")}' if os.getenv('REDIS_PASSWORD'): - redis_url = f"redis://:{os.getenv('REDIS_PASSWORD')}@{redis_host}:{os.getenv('REDIS_PORT', '6379')}" + redis_url = f'redis://:{os.getenv("REDIS_PASSWORD")}@{redis_host}:{os.getenv("REDIS_PORT", "6379")}' r = redis.from_url(redis_url) await r.ping() await r.close() - gateway_logger.info("✓ Redis connection verified") + gateway_logger.info('✓ Redis connection verified') break except Exception as e: if attempt < max_retries - 1: - wait = 2 ** attempt + wait = 2**attempt gateway_logger.warning( - f"Redis connection attempt {attempt + 1}/{max_retries} failed: {e}" + f'Redis connection attempt {attempt + 1}/{max_retries} failed: {e}' ) - gateway_logger.info(f"Retrying in {wait} seconds...") + gateway_logger.info(f'Retrying in {wait} seconds...') await asyncio.sleep(wait) else: - gateway_logger.error(f"Redis connection failed after {max_retries} attempts") - raise RuntimeError( - f"Cannot connect to Redis: {e}" - ) from e + gateway_logger.error(f'Redis connection failed after {max_retries} attempts') + raise RuntimeError(f'Cannot connect to Redis: {e}') from e - gateway_logger.info("All database connections validated successfully") + gateway_logger.info('All database connections validated successfully') -def validate_token_revocation_config(): + +def validate_token_revocation_config() -> None: """ Validate token revocation is safe for multi-worker deployments. """ @@ -209,21 +224,20 @@ def validate_token_revocation_config(): mem_mode = os.getenv('MEM_OR_EXTERNAL', 'MEM') if threads > 1 and mem_mode == 'MEM': gateway_logger.error( - "CRITICAL: Multi-worker mode (THREADS > 1) with in-memory storage " - "does not provide consistent token revocation across workers. " - f"Current config: THREADS={threads}, MEM_OR_EXTERNAL={mem_mode}" + 'CRITICAL: Multi-worker mode (THREADS > 1) with in-memory storage ' + 'does not provide consistent token revocation across workers. ' + f'Current config: THREADS={threads}, MEM_OR_EXTERNAL={mem_mode}' ) gateway_logger.error( - "Token revocation requires Redis in multi-worker mode. " - "Either set MEM_OR_EXTERNAL=REDIS or set THREADS=1" + 'Token revocation requires Redis in multi-worker mode. ' + 'Either set MEM_OR_EXTERNAL=REDIS or set THREADS=1' ) raise RuntimeError( - "Token revocation requires Redis in multi-worker mode (THREADS > 1). " - "Set MEM_OR_EXTERNAL=REDIS or THREADS=1" + 'Token revocation requires Redis in multi-worker mode (THREADS > 1). ' + 'Set MEM_OR_EXTERNAL=REDIS or THREADS=1' ) - gateway_logger.info( - f"Token revocation mode: {mem_mode} with {threads} worker(s)" - ) + gateway_logger.info(f'Token revocation mode: {mem_mode} with {threads} worker(s)') + @asynccontextmanager async def app_lifespan(app: FastAPI): @@ -250,7 +264,12 @@ async def app_lifespan(app: FastAPI): ) jwt_secret = os.getenv('JWT_SECRET_KEY', '') - if jwt_secret in ('please-change-me', 'test-secret-key', 'test-secret-key-please-change', ''): + if jwt_secret in ( + 'please-change-me', + 'test-secret-key', + 'test-secret-key-please-change', + '', + ): raise RuntimeError( 'In production (ENV=production), JWT_SECRET_KEY must be changed from default value. ' 'Generate a strong random secret (32+ characters).' @@ -294,7 +313,7 @@ async def app_lifespan(app: FastAPI): 'Memory dumps contain sensitive data and must be encrypted. ' 'Generate a strong random key: openssl rand -hex 32' ) - except Exception as e: + except Exception: raise mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM').upper() @@ -337,6 +356,7 @@ async def app_lifespan(app: FastAPI): break except Exception: pass + try: app.state._metrics_save_task = asyncio.create_task(_metrics_autosave(60)) except Exception: @@ -350,8 +370,12 @@ async def app_lifespan(app: FastAPI): try: settings = get_cached_settings() - if bool(settings.get('trust_x_forwarded_for')) and not (settings.get('xff_trusted_proxies') or []): - gateway_logger.warning('Security: trust_x_forwarded_for enabled but xff_trusted_proxies is empty; header spoofing risk. Configure trusted proxy IPs/CIDRs.') + if bool(settings.get('trust_x_forwarded_for')) and not ( + settings.get('xff_trusted_proxies') or [] + ): + gateway_logger.warning( + 'Security: trust_x_forwarded_for enabled but xff_trusted_proxies is empty; header spoofing risk. Configure trusted proxy IPs/CIDRs.' + ) if os.getenv('ENV', '').lower() == 'production': raise RuntimeError( @@ -364,7 +388,7 @@ async def app_lifespan(app: FastAPI): gateway_logger.debug(f'Startup security checks skipped: {e}') try: - spec = app.openapi() + app.openapi() problems = [] for route in app.routes: path = getattr(route, 'path', '') @@ -390,7 +414,9 @@ async def app_lifespan(app: FastAPI): latest_path = find_latest_dump_path(hint) if latest_path and os.path.exists(latest_path): info = restore_memory_from_file(latest_path) - gateway_logger.info(f"Memory mode: restored from dump {latest_path} (created_at={info.get('created_at')})") + gateway_logger.info( + f'Memory mode: restored from dump {latest_path} (created_at={info.get("created_at")})' + ) else: gateway_logger.info('Memory mode: no existing dump found to restore') except Exception as e: @@ -406,7 +432,9 @@ async def app_lifespan(app: FastAPI): gateway_logger.info('SIGUSR1 ignored: not in memory-only mode') return if not os.getenv('MEM_ENCRYPTION_KEY'): - gateway_logger.error('SIGUSR1 dump skipped: MEM_ENCRYPTION_KEY not configured') + gateway_logger.error( + 'SIGUSR1 dump skipped: MEM_ENCRYPTION_KEY not configured' + ) return settings = get_cached_settings() path_hint = settings.get('dump_path') @@ -418,7 +446,6 @@ async def app_lifespan(app: FastAPI): loop.add_signal_handler(signal.SIGUSR1, lambda: asyncio.create_task(_sigusr1_dump())) gateway_logger.info('SIGUSR1 handler registered for on-demand memory dumps') except NotImplementedError: - pass try: @@ -451,9 +478,9 @@ async def app_lifespan(app: FastAPI): try: yield finally: - gateway_logger.info("Starting graceful shutdown...") + gateway_logger.info('Starting graceful shutdown...') app.state.shutting_down = True - gateway_logger.info("Waiting for in-flight requests to complete (5s grace period)...") + gateway_logger.info('Waiting for in-flight requests to complete (5s grace period)...') await asyncio.sleep(5) try: await stop_auto_save_task() @@ -474,19 +501,21 @@ async def app_lifespan(app: FastAPI): except Exception: pass try: - gateway_logger.info("Closing database connections...") + gateway_logger.info('Closing database connections...') from utils.database import close_database_connections + close_database_connections() except Exception as e: - gateway_logger.error(f"Error closing database connections: {e}") + gateway_logger.error(f'Error closing database connections: {e}') try: - gateway_logger.info("Closing HTTP clients...") + gateway_logger.info('Closing HTTP clients...') from services.gateway_service import GatewayService + if hasattr(GatewayService, '_http_client') and GatewayService._http_client: await GatewayService._http_client.aclose() - gateway_logger.info("HTTP client closed") + gateway_logger.info('HTTP client closed') except Exception as e: - gateway_logger.error(f"Error closing HTTP client: {e}") + gateway_logger.error(f'Error closing HTTP client: {e}') try: METRICS_FILE = os.path.join(LOGS_DIR, 'metrics.json') @@ -494,7 +523,7 @@ async def app_lifespan(app: FastAPI): except Exception: pass - gateway_logger.info("Graceful shutdown complete") + gateway_logger.info('Graceful shutdown complete') try: t = getattr(app.state, '_metrics_save_task', None) if t: @@ -504,9 +533,11 @@ async def app_lifespan(app: FastAPI): try: from services.gateway_service import GatewayService as _GS + if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() != 'false': try: import asyncio as _asyncio + if _asyncio.iscoroutinefunction(_GS.aclose_http_client): await _GS.aclose_http_client() except Exception: @@ -514,15 +545,17 @@ async def app_lifespan(app: FastAPI): except Exception: pass -def _generate_unique_id(route): + +def _generate_unique_id(route: dict) -> str: try: name = getattr(route, 'name', 'op') or 'op' path = getattr(route, 'path', '').replace('/', '_').replace('{', '').replace('}', '') methods = '_'.join(sorted(list(getattr(route, 'methods', []) or []))) - return f"{name}_{methods}_{path}".lower() + return f'{name}_{methods}_{path}'.lower() except Exception: return (getattr(route, 'name', 'op') or 'op').lower() + doorman = FastAPI( title='doorman', description="A lightweight API gateway for AI, REST, SOAP, GraphQL, gRPC, and WebSocket APIs — fully managed with built-in RESTful APIs for configuration and control. This is your application's gateway to the world.", @@ -534,13 +567,13 @@ doorman = FastAPI( # Add CORS middleware # Starlette CORS middleware is disabled by default because platform and per-API # CORS are enforced explicitly below. Enable only if requested via env. -if os.getenv('ENABLE_STARLETTE_CORS', 'false').lower() in ('1','true','yes','on'): +if os.getenv('ENABLE_STARLETTE_CORS', 'false').lower() in ('1', 'true', 'yes', 'on'): doorman.add_middleware( CORSMiddleware, - allow_origins=["*"], # In production, replace with specific origins + allow_origins=['*'], # In production, replace with specific origins allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], # This will include X-Requested-With + allow_methods=['*'], + allow_headers=['*'], # This will include X-Requested-With expose_headers=[], max_age=600, ) @@ -550,7 +583,8 @@ domain = os.getenv('COOKIE_DOMAIN', 'localhost') # - API gateway routes (/api/*): CORS controlled per-API in gateway routes/services -def _platform_cors_config(): + +def _platform_cors_config() -> dict: """Compute platform CORS config from environment. Env vars: @@ -561,16 +595,35 @@ def _platform_cors_config(): - CORS_STRICT: true/false (when true, do not echo wildcard origins with credentials) """ import os as _os - strict = _os.getenv('CORS_STRICT', 'false').lower() in ('1','true','yes','on') - allowed_origins = [o.strip() for o in (_os.getenv('ALLOWED_ORIGINS') or '').split(',') if o.strip()] or ['*'] - allow_methods = [m.strip() for m in (_os.getenv('ALLOW_METHODS') or 'GET,POST,PUT,DELETE,OPTIONS,PATCH,HEAD').split(',') if m.strip()] + + strict = _os.getenv('CORS_STRICT', 'false').lower() in ('1', 'true', 'yes', 'on') + allowed_origins = [ + o.strip() for o in (_os.getenv('ALLOWED_ORIGINS') or '').split(',') if o.strip() + ] or ['*'] + allow_methods = [ + m.strip() + for m in (_os.getenv('ALLOW_METHODS') or 'GET,POST,PUT,DELETE,OPTIONS,PATCH,HEAD').split( + ',' + ) + if m.strip() + ] allow_headers_env = _os.getenv('ALLOW_HEADERS') or '' if allow_headers_env.strip() == '*': # Default to a known, minimal safe list when wildcard requested - allow_headers = ['Accept','Content-Type','X-CSRF-Token','Authorization'] + allow_headers = ['Accept', 'Content-Type', 'X-CSRF-Token', 'Authorization'] else: - allow_headers = [h.strip() for h in allow_headers_env.split(',') if h.strip()] or ['Accept','Content-Type','X-CSRF-Token','Authorization'] - allow_credentials = _os.getenv('ALLOW_CREDENTIALS', 'false').lower() in ('1','true','yes','on') + allow_headers = [h.strip() for h in allow_headers_env.split(',') if h.strip()] or [ + 'Accept', + 'Content-Type', + 'X-CSRF-Token', + 'Authorization', + ] + allow_credentials = _os.getenv('ALLOW_CREDENTIALS', 'false').lower() in ( + '1', + 'true', + 'yes', + 'on', + ) return { 'strict': strict, @@ -596,10 +649,10 @@ async def platform_cors(request: Request, call_next): try: lo = origin.lower() origin_allowed = ( - lo.startswith('http://localhost') or - lo.startswith('https://localhost') or - lo.startswith('http://127.0.0.1') or - lo.startswith('https://127.0.0.1') + lo.startswith('http://localhost') + or lo.startswith('https://localhost') + or lo.startswith('http://127.0.0.1') + or lo.startswith('https://127.0.0.1') ) except Exception: origin_allowed = False @@ -610,6 +663,7 @@ async def platform_cors(request: Request, call_next): if request.method.upper() == 'OPTIONS': from fastapi.responses import Response as _Resp + headers = {} if origin_allowed: headers['Access-Control-Allow-Origin'] = origin @@ -641,8 +695,10 @@ async def platform_cors(request: Request, call_next): pass return await call_next(request) + MAX_BODY_SIZE = int(os.getenv('MAX_BODY_SIZE_BYTES', 1_048_576)) + def _get_max_body_size() -> int: try: v = os.getenv('MAX_BODY_SIZE_BYTES') @@ -652,6 +708,7 @@ def _get_max_body_size() -> int: except Exception: return MAX_BODY_SIZE + class LimitedStreamReader: """ Wrapper around ASGI receive channel that enforces size limits on chunked requests. @@ -659,6 +716,7 @@ class LimitedStreamReader: Prevents Transfer-Encoding: chunked bypass by tracking accumulated size and rejecting streams that exceed the limit. """ + def __init__(self, receive, max_size: int): self.receive = receive self.max_size = max_size @@ -681,6 +739,7 @@ class LimitedStreamReader: return message + @doorman.middleware('http') async def body_size_limit(request: Request, call_next): """Enforce request body size limits to prevent DoS attacks. @@ -700,7 +759,7 @@ async def body_size_limit(request: Request, call_next): - /api/grpc/*: Enforce on gRPC JSON payloads """ try: - if os.getenv('DISABLE_BODY_SIZE_LIMIT', 'false').lower() in ('1','true','yes','on'): + if os.getenv('DISABLE_BODY_SIZE_LIMIT', 'false').lower() in ('1', 'true', 'yes', 'on'): return await call_next(request) path = str(request.url.path) @@ -708,7 +767,9 @@ async def body_size_limit(request: Request, call_next): raw_excludes = os.getenv('BODY_LIMIT_EXCLUDE_PATHS', '') if raw_excludes: excludes = [p.strip() for p in raw_excludes.split(',') if p.strip()] - if any(path == p or (p.endswith('*') and path.startswith(p[:-1])) for p in excludes): + if any( + path == p or (p.endswith('*') and path.startswith(p[:-1])) for p in excludes + ): return await call_next(request) except Exception: pass @@ -725,10 +786,13 @@ async def body_size_limit(request: Request, call_next): try: from models.response_model import ResponseModel as _RM from utils.response_util import process_response as _pr - return _pr(_RM( - status_code=200, - message='Settings updated (middleware bypass)' - ).dict(), 'rest') + + return _pr( + _RM( + status_code=200, message='Settings updated (middleware bypass)' + ).dict(), + 'rest', + ) except Exception: pass raise @@ -768,6 +832,7 @@ async def body_size_limit(request: Request, call_next): if content_length > limit: try: from utils.audit_util import audit + audit( request, actor=None, @@ -778,17 +843,20 @@ async def body_size_limit(request: Request, call_next): 'content_length': content_length, 'limit': limit, 'content_type': request.headers.get('content-type'), - 'transfer_encoding': transfer_encoding or None - } + 'transfer_encoding': transfer_encoding or None, + }, ) except Exception: pass - return process_response(ResponseModel( - status_code=413, - error_code='REQ001', - error_message=f'Request entity too large (max: {limit} bytes)' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=413, + error_code='REQ001', + error_message=f'Request entity too large (max: {limit} bytes)', + ).dict(), + 'rest', + ) except (ValueError, TypeError): pass @@ -798,7 +866,7 @@ async def body_size_limit(request: Request, call_next): try: env_flag = os.getenv('DISABLE_PLATFORM_CHUNKED_WRAP') if isinstance(env_flag, str) and env_flag.strip() != '': - if env_flag.strip().lower() in ('1','true','yes','on'): + if env_flag.strip().lower() in ('1', 'true', 'yes', 'on'): wrap_allowed = False if str(path) == '/platform/authorization': wrap_allowed = True @@ -814,9 +882,12 @@ async def body_size_limit(request: Request, call_next): response = await call_next(request) try: - if wrap_allowed and (limited_reader.over_limit or limited_reader.bytes_received > limit): + if wrap_allowed and ( + limited_reader.over_limit or limited_reader.bytes_received > limit + ): try: from utils.audit_util import audit + audit( request, actor=None, @@ -827,29 +898,37 @@ async def body_size_limit(request: Request, call_next): 'bytes_received': limited_reader.bytes_received, 'limit': limit, 'content_type': request.headers.get('content-type'), - 'transfer_encoding': transfer_encoding or 'chunked' - } + 'transfer_encoding': transfer_encoding or 'chunked', + }, ) except Exception: pass - return process_response(ResponseModel( - status_code=413, - error_code='REQ001', - error_message=f'Request entity too large (max: {limit} bytes)' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=413, + error_code='REQ001', + error_message=f'Request entity too large (max: {limit} bytes)', + ).dict(), + 'rest', + ) except Exception: pass return response - except Exception as e: + except Exception: try: - if wrap_allowed and (limited_reader.over_limit or limited_reader.bytes_received > limit): - return process_response(ResponseModel( - status_code=413, - error_code='REQ001', - error_message=f'Request entity too large (max: {limit} bytes)' - ).dict(), 'rest') + if wrap_allowed and ( + limited_reader.over_limit or limited_reader.bytes_received > limit + ): + return process_response( + ResponseModel( + status_code=413, + error_code='REQ001', + error_message=f'Request entity too large (max: {limit} bytes)', + ).dict(), + 'rest', + ) except Exception: pass raise @@ -873,6 +952,7 @@ async def body_size_limit(request: Request, call_next): else: try: import anyio + if isinstance(e, getattr(anyio, 'EndOfStream', tuple())): swallow = True except Exception: @@ -882,16 +962,20 @@ async def body_size_limit(request: Request, call_next): if swallow and _RM and _pr: try: - return _pr(_RM( - status_code=500, - error_code='GTW998', - error_message='Upstream handler failed to produce a response' - ).dict(), 'rest') + return _pr( + _RM( + status_code=500, + error_code='GTW998', + error_message='Upstream handler failed to produce a response', + ).dict(), + 'rest', + ) except Exception: pass raise + class PlatformCORSMiddleware: """ASGI-level CORS for /platform/* routes only. @@ -899,13 +983,19 @@ class PlatformCORSMiddleware: interfere with /api/* paths. It also respects DISABLE_PLATFORM_CORS_ASGI: when set to true, this middleware becomes a no-op. """ + def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): # If explicitly disabled, act as a passthrough try: - if os.getenv('DISABLE_PLATFORM_CORS_ASGI', 'false').lower() in ('1','true','yes','on'): + if os.getenv('DISABLE_PLATFORM_CORS_ASGI', 'false').lower() in ( + '1', + 'true', + 'yes', + 'on', + ): return await self.app(scope, receive, send) except Exception: pass @@ -921,7 +1011,7 @@ class PlatformCORSMiddleware: cfg = _platform_cors_config() hdrs = {} try: - for k, v in (scope.get('headers') or []): + for k, v in scope.get('headers') or []: hdrs[k.decode('latin1').lower()] = v.decode('latin1') except Exception: pass @@ -934,10 +1024,10 @@ class PlatformCORSMiddleware: if cfg['strict'] and cfg['credentials']: lo = origin.lower() origin_allowed = ( - lo.startswith('http://localhost') or - lo.startswith('https://localhost') or - lo.startswith('http://127.0.0.1') or - lo.startswith('https://127.0.0.1') + lo.startswith('http://localhost') + or lo.startswith('https://localhost') + or lo.startswith('http://127.0.0.1') + or lo.startswith('https://127.0.0.1') ) else: origin_allowed = True @@ -949,8 +1039,12 @@ class PlatformCORSMiddleware: if origin_allowed and origin: headers.append((b'access-control-allow-origin', origin.encode('latin1'))) headers.append((b'vary', b'Origin')) - headers.append((b'access-control-allow-methods', ', '.join(cfg['methods']).encode('latin1'))) - headers.append((b'access-control-allow-headers', ', '.join(cfg['headers']).encode('latin1'))) + headers.append( + (b'access-control-allow-methods', ', '.join(cfg['methods']).encode('latin1')) + ) + headers.append( + (b'access-control-allow-headers', ', '.join(cfg['headers']).encode('latin1')) + ) if cfg['credentials']: headers.append((b'access-control-allow-credentials', b'true')) rid = hdrs.get('x-request-id') @@ -969,15 +1063,20 @@ class PlatformCORSMiddleware: except Exception: return await self.app(scope, receive, send) + doorman.add_middleware(PlatformCORSMiddleware) # Add tier-based rate limiting middleware try: from middleware.tier_rate_limit_middleware import TierRateLimitMiddleware + doorman.add_middleware(TierRateLimitMiddleware) - logging.getLogger('doorman.gateway').info("Tier-based rate limiting middleware enabled") + logging.getLogger('doorman.gateway').info('Tier-based rate limiting middleware enabled') except Exception as e: - logging.getLogger('doorman.gateway').warning(f"Failed to enable tier rate limiting middleware: {e}") + logging.getLogger('doorman.gateway').warning( + f'Failed to enable tier rate limiting middleware: {e}' + ) + @doorman.middleware('http') async def request_id_middleware(request: Request, call_next): @@ -997,6 +1096,7 @@ async def request_id_middleware(request: Request, call_next): try: from utils.correlation_util import set_correlation_id + set_correlation_id(rid) except Exception: pass @@ -1005,7 +1105,9 @@ async def request_id_middleware(request: Request, call_next): trust_xff = bool(settings.get('trust_x_forwarded_for')) direct_ip = getattr(getattr(request, 'client', None), 'host', None) effective_ip = _policy_get_client_ip(request, trust_xff) - gateway_logger.info(f"{rid} | Entry: client_ip={direct_ip} effective_ip={effective_ip} method={request.method} path={str(request.url.path)}") + gateway_logger.info( + f'{rid} | Entry: client_ip={direct_ip} effective_ip={effective_ip} method={request.method} path={str(request.url.path)}' + ) except Exception: pass response = await call_next(request) @@ -1019,6 +1121,7 @@ async def request_id_middleware(request: Request, call_next): gateway_logger.error(f'Request ID middleware error: {str(e)}', exc_info=True) raise + @doorman.middleware('http') async def security_headers(request: Request, call_next): response = await call_next(request) @@ -1026,29 +1129,33 @@ async def security_headers(request: Request, call_next): response.headers.setdefault('X-Content-Type-Options', 'nosniff') response.headers.setdefault('X-Frame-Options', 'DENY') response.headers.setdefault('Referrer-Policy', 'no-referrer') - response.headers.setdefault('Permissions-Policy', 'geolocation=(), microphone=(), camera=()') + response.headers.setdefault( + 'Permissions-Policy', 'geolocation=(), microphone=(), camera=()' + ) try: csp = os.getenv('CONTENT_SECURITY_POLICY') if csp is None or not csp.strip(): - - csp =\ - "default-src 'none'; "\ - "frame-ancestors 'none'; "\ - "base-uri 'none'; "\ - "form-action 'self'; "\ - "img-src 'self' data:; "\ + csp = ( + "default-src 'none'; " + "frame-ancestors 'none'; " + "base-uri 'none'; " + "form-action 'self'; " + "img-src 'self' data:; " "connect-src 'self';" + ) response.headers.setdefault('Content-Security-Policy', csp) except Exception: pass if os.getenv('HTTPS_ONLY', 'false').lower() == 'true': - - response.headers.setdefault('Strict-Transport-Security', 'max-age=15552000; includeSubDomains; preload') + response.headers.setdefault( + 'Strict-Transport-Security', 'max-age=15552000; includeSubDomains; preload' + ) except Exception: pass return response + """Logging configuration Prefer file logging to LOGS_DIR/doorman.log when writable; otherwise, fall back @@ -1058,7 +1165,10 @@ Respects LOG_FORMAT=json|plain. BASE_DIR = os.path.dirname(os.path.abspath(__file__)) _env_logs_dir = os.getenv('LOGS_DIR') -LOGS_DIR = os.path.abspath(_env_logs_dir) if _env_logs_dir else os.path.join(BASE_DIR, 'platform-logs') +LOGS_DIR = ( + os.path.abspath(_env_logs_dir) if _env_logs_dir else os.path.join(BASE_DIR, 'platform-logs') +) + # Build formatters class JSONFormatter(logging.Formatter): @@ -1074,6 +1184,7 @@ class JSONFormatter(logging.Formatter): except Exception: return f'{payload}' + _fmt_is_json = os.getenv('LOG_FORMAT', 'plain').lower() == 'json' _file_handler = None try: @@ -1082,21 +1193,29 @@ try: filename=os.path.join(LOGS_DIR, 'doorman.log'), maxBytes=10 * 1024 * 1024, backupCount=5, - encoding='utf-8' + encoding='utf-8', + ) + _file_handler.setFormatter( + JSONFormatter() + if _fmt_is_json + else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') ) - _file_handler.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) except Exception as _e: - logging.getLogger('doorman.gateway').warning(f'File logging disabled ({_e}); using console logging only') + logging.getLogger('doorman.gateway').warning( + f'File logging disabled ({_e}); using console logging only' + ) _file_handler = None + # Configure all doorman loggers to use the same handler and prevent propagation -def configure_logger(logger_name): +def configure_logger(logger_name: str) -> logging.Logger: logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) logger.propagate = False for handler in logger.handlers[:]: logger.removeHandler(handler) + class RedactFilter(logging.Filter): """Comprehensive logging redaction filter for sensitive data. @@ -1111,30 +1230,25 @@ def configure_logger(logger_name): PATTERNS = [ re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'), - re.compile(r'(?i)(x-api-key\s*[:=]\s*)([^;\r\n]+)'), re.compile(r'(?i)(api[_-]?key\s*[:=]\s*)([^;\r\n]+)'), re.compile(r'(?i)(api[_-]?secret\s*[:=]\s*)([^;\r\n]+)'), - re.compile(r'(?i)(access[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), re.compile(r'(?i)(refresh[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), re.compile(r'(?i)(token\s*["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9_\-\.]{20,})(["\']?)'), - re.compile(r'(?i)(password\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n]+)(["\']?)'), re.compile(r'(?i)(secret\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), re.compile(r'(?i)(client[_-]?secret\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), - re.compile(r'(?i)(cookie\s*[:=]\s*)([^;\r\n]+)'), re.compile(r'(?i)(set-cookie\s*[:=]\s*)([^;\r\n]+)'), - re.compile(r'(?i)(x-csrf-token\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), re.compile(r'(?i)(csrf[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), - re.compile(r'\b(eyJ[a-zA-Z0-9_\-]+\.eyJ[a-zA-Z0-9_\-]+\.[a-zA-Z0-9_\-]+)\b'), - re.compile(r'(?i)(session[_-]?id\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'), - - re.compile(r'(-----BEGIN[A-Z\s]+PRIVATE KEY-----)(.*?)(-----END[A-Z\s]+PRIVATE KEY-----)', re.DOTALL), + re.compile( + r'(-----BEGIN[A-Z\s]+PRIVATE KEY-----)(.*?)(-----END[A-Z\s]+PRIVATE KEY-----)', + re.DOTALL, + ), ] def filter(self, record: logging.LogRecord) -> bool: @@ -1146,11 +1260,14 @@ def configure_logger(logger_name): if pat.groups == 3 and pat.flags & re.DOTALL: red = pat.sub(r'\1[REDACTED]\3', red) elif pat.groups >= 2: - red = pat.sub(lambda m: ( - m.group(1) + - '[REDACTED]' + - (m.group(3) if m.lastindex and m.lastindex >= 3 else '') - ), red) + red = pat.sub( + lambda m: ( + m.group(1) + + '[REDACTED]' + + (m.group(3) if m.lastindex and m.lastindex >= 3 else '') + ), + red, + ) else: red = pat.sub('[REDACTED]', red) @@ -1159,7 +1276,15 @@ def configure_logger(logger_name): if hasattr(record, 'args') and record.args: try: if isinstance(record.args, dict): - record.args = {k: '[REDACTED]' if 'token' in str(k).lower() or 'password' in str(k).lower() or 'secret' in str(k).lower() or 'authorization' in str(k).lower() else v for k, v in record.args.items()} + record.args = { + k: '[REDACTED]' + if 'token' in str(k).lower() + or 'password' in str(k).lower() + or 'secret' in str(k).lower() + or 'authorization' in str(k).lower() + else v + for k, v in record.args.items() + } except Exception: pass except Exception: @@ -1168,17 +1293,23 @@ def configure_logger(logger_name): console = logging.StreamHandler(stream=sys.stdout) console.setLevel(logging.INFO) - console.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + console.setFormatter( + JSONFormatter() + if _fmt_is_json + else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + ) console.addFilter(RedactFilter()) logger.addHandler(console) if _file_handler is not None: - - if not any(isinstance(f, logging.Filter) and hasattr(f, 'PATTERNS') for f in _file_handler.filters): + if not any( + isinstance(f, logging.Filter) and hasattr(f, 'PATTERNS') for f in _file_handler.filters + ): _file_handler.addFilter(RedactFilter()) logger.addHandler(_file_handler) return logger + gateway_logger = configure_logger('doorman.gateway') logging_logger = configure_logger('doorman.logging') @@ -1198,9 +1329,7 @@ try: compression_level = 1 doorman.add_middleware( - GZipMiddleware, - minimum_size=compression_minimum_size, - compresslevel=compression_level + GZipMiddleware, minimum_size=compression_minimum_size, compresslevel=compression_level ) gateway_logger.info( f'Response compression enabled: level={compression_level}, ' @@ -1211,6 +1340,7 @@ try: except Exception as e: gateway_logger.warning(f'Failed to configure response compression: {e}. Compression disabled.') + # Ensure platform responses set Vary=Origin (and not Accept-Encoding) for CORS tests. class _VaryOriginMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -1228,6 +1358,7 @@ class _VaryOriginMiddleware(BaseHTTPMiddleware): pass return response + doorman.add_middleware(_VaryOriginMiddleware) # Now that logging is configured, attempt to migrate any legacy 'generated/' dir @@ -1248,9 +1379,13 @@ try: filename=os.path.join(LOGS_DIR, 'doorman-trail.log'), maxBytes=10 * 1024 * 1024, backupCount=5, - encoding='utf-8' + encoding='utf-8', + ) + _audit_file.setFormatter( + JSONFormatter() + if _fmt_is_json + else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') ) - _audit_file.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) try: for eh in gateway_logger.handlers: for f in getattr(eh, 'filters', []): @@ -1259,10 +1394,13 @@ try: pass audit_logger.addHandler(_audit_file) except Exception as _e: - console = logging.StreamHandler(stream=sys.stdout) console.setLevel(logging.INFO) - console.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + console.setFormatter( + JSONFormatter() + if _fmt_is_json + else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + ) try: for eh in gateway_logger.handlers: for f in getattr(eh, 'filters', []): @@ -1271,12 +1409,18 @@ except Exception as _e: pass audit_logger.addHandler(console) + class Settings(BaseSettings): mongo_db_uri: str = os.getenv('MONGO_DB_URI') jwt_secret_key: str = os.getenv('JWT_SECRET_KEY') jwt_algorithm: str = 'HS256' - jwt_access_token_expires: timedelta = timedelta(minutes=int(os.getenv('ACCESS_TOKEN_EXPIRES_MINUTES', 15))) - jwt_refresh_token_expires: timedelta = timedelta(days=int(os.getenv('REFRESH_TOKEN_EXPIRES_DAYS', 30))) + jwt_access_token_expires: timedelta = timedelta( + minutes=int(os.getenv('ACCESS_TOKEN_EXPIRES_MINUTES', 15)) + ) + jwt_refresh_token_expires: timedelta = timedelta( + days=int(os.getenv('REFRESH_TOKEN_EXPIRES_DAYS', 30)) + ) + @doorman.middleware('http') async def ip_filter_middleware(request: Request, call_next): @@ -1294,12 +1438,29 @@ async def ip_filter_middleware(request: Request, call_next): try: import os + settings = get_cached_settings() env_flag = os.getenv('LOCAL_HOST_IP_BYPASS') - allow_local = (env_flag.lower() == 'true') if isinstance(env_flag, str) and env_flag.strip() != '' else bool(settings.get('allow_localhost_bypass')) + allow_local = ( + (env_flag.lower() == 'true') + if isinstance(env_flag, str) and env_flag.strip() != '' + else bool(settings.get('allow_localhost_bypass')) + ) if allow_local: direct_ip = getattr(getattr(request, 'client', None), 'host', None) - has_forward = any(request.headers.get(h) for h in ('x-forwarded-for','X-Forwarded-For','x-real-ip','X-Real-IP','cf-connecting-ip','CF-Connecting-IP','forwarded','Forwarded')) + has_forward = any( + request.headers.get(h) + for h in ( + 'x-forwarded-for', + 'X-Forwarded-For', + 'x-real-ip', + 'X-Real-IP', + 'cf-connecting-ip', + 'CF-Connecting-IP', + 'forwarded', + 'Forwarded', + ) + ) if direct_ip and _policy_is_loopback(direct_ip) and not has_forward: return await call_next(request) except Exception: @@ -1308,37 +1469,77 @@ async def ip_filter_middleware(request: Request, call_next): if client_ip: if wl and not _policy_ip_in_list(client_ip, wl): try: - audit(request, actor=None, action='ip.global_deny', target=client_ip, status='blocked', details={'reason': 'not_in_whitelist', 'xff': xff_hdr, 'source_ip': getattr(getattr(request, 'client', None), 'host', None)}) + audit( + request, + actor=None, + action='ip.global_deny', + target=client_ip, + status='blocked', + details={ + 'reason': 'not_in_whitelist', + 'xff': xff_hdr, + 'source_ip': getattr(getattr(request, 'client', None), 'host', None), + }, + ) except Exception: pass from fastapi.responses import JSONResponse - return JSONResponse(status_code=403, content={'status_code': 403, 'error_code': 'SEC010', 'error_message': 'IP not allowed'}) + + return JSONResponse( + status_code=403, + content={ + 'status_code': 403, + 'error_code': 'SEC010', + 'error_message': 'IP not allowed', + }, + ) if bl and _policy_ip_in_list(client_ip, bl): try: - audit(request, actor=None, action='ip.global_deny', target=client_ip, status='blocked', details={'reason': 'blacklisted', 'xff': xff_hdr, 'source_ip': getattr(getattr(request, 'client', None), 'host', None)}) + audit( + request, + actor=None, + action='ip.global_deny', + target=client_ip, + status='blocked', + details={ + 'reason': 'blacklisted', + 'xff': xff_hdr, + 'source_ip': getattr(getattr(request, 'client', None), 'host', None), + }, + ) except Exception: pass from fastapi.responses import JSONResponse - return JSONResponse(status_code=403, content={'status_code': 403, 'error_code': 'SEC011', 'error_message': 'IP blocked'}) + + return JSONResponse( + status_code=403, + content={ + 'status_code': 403, + 'error_code': 'SEC011', + 'error_message': 'IP blocked', + }, + ) except Exception: pass return await call_next(request) + @doorman.middleware('http') async def metrics_middleware(request: Request, call_next): start = asyncio.get_event_loop().time() + def _parse_len(val: str | None) -> int: try: return int(val) if val is not None else 0 except Exception: return 0 + bytes_in = _parse_len(request.headers.get('content-length')) response = None try: response = await call_next(request) return response finally: - try: if str(request.url.path).startswith('/api/'): end = asyncio.get_event_loop().time() @@ -1349,6 +1550,7 @@ async def metrics_middleware(request: Request, call_next): try: from utils.auth_util import auth_required as _auth_required + payload = await _auth_required(request) username = payload.get('sub') if isinstance(payload, dict) else None except Exception: @@ -1359,11 +1561,14 @@ async def metrics_middleware(request: Request, call_next): parts = p.split('/') try: idx = parts.index('rest') - api_key = f'rest:{parts[idx+1]}' if len(parts) > idx+1 and parts[idx+1] else 'rest:unknown' + api_key = ( + f'rest:{parts[idx + 1]}' + if len(parts) > idx + 1 and parts[idx + 1] + else 'rest:unknown' + ) except ValueError: api_key = 'rest:unknown' elif p.startswith('/api/graphql/'): - seg = p.rsplit('/', 1)[-1] or 'unknown' api_key = f'graphql:{seg}' elif p.startswith('/api/soap/'): @@ -1383,13 +1588,29 @@ async def metrics_middleware(request: Request, call_next): except Exception: clen = 0 - metrics_store.record(status=status, duration_ms=duration_ms, username=username, api_key=api_key, bytes_in=bytes_in, bytes_out=clen) + metrics_store.record( + status=status, + duration_ms=duration_ms, + username=username, + api_key=api_key, + bytes_in=bytes_in, + bytes_out=clen, + ) try: if username: - from utils.bandwidth_util import add_usage, _get_user + from utils.bandwidth_util import _get_user, add_usage + u = _get_user(username) - if u and u.get('bandwidth_limit_bytes') and u.get('bandwidth_limit_enabled') is not False: - add_usage(username, int(bytes_in) + int(clen), u.get('bandwidth_limit_window') or 'day') + if ( + u + and u.get('bandwidth_limit_bytes') + and u.get('bandwidth_limit_enabled') is not False + ): + add_usage( + username, + int(bytes_in) + int(clen), + u.get('bandwidth_limit_window') or 'day', + ) except Exception: pass try: @@ -1405,45 +1626,51 @@ async def metrics_middleware(request: Request, call_next): except Exception: pass except Exception: - pass + async def automatic_purger(interval_seconds): while True: await asyncio.sleep(interval_seconds) await purge_expired_tokens() gateway_logger.info('Expired JWTs purged from blacklist.') + @doorman.exception_handler(JWTError) async def jwt_exception_handler(exc: JWTError): - return process_response(ResponseModel( - status_code=401, - error_code='JWT001', - error_message='Invalid token' - ).dict(), 'rest') + return process_response( + ResponseModel(status_code=401, error_code='JWT001', error_message='Invalid token').dict(), + 'rest', + ) + @doorman.exception_handler(500) async def internal_server_error_handler(request: Request, exc: Exception): - return process_response(ResponseModel( - status_code=500, - error_code='ISE001', - error_message='Internal Server Error' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='ISE001', error_message='Internal Server Error' + ).dict(), + 'rest', + ) + @doorman.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): # DEBUG: Log validation errors import logging + log = logging.getLogger('doorman.gateway') - log.error(f"Validation error on {request.method} {request.url.path}") - log.error(f"Validation errors: {exc.errors()}") - log.error(f"Request body: {await request.body()}") - - return process_response(ResponseModel( - status_code=422, - error_code='VAL001', - error_message='Validation Error' - ).dict(), 'rest') + log.error(f'Validation error on {request.method} {request.url.path}') + log.error(f'Validation errors: {exc.errors()}') + log.error(f'Request body: {await request.body()}') + + return process_response( + ResponseModel( + status_code=422, error_code='VAL001', error_message='Validation Error' + ).dict(), + 'rest', + ) + cache_manager.init_app(doorman) @@ -1473,29 +1700,35 @@ doorman.include_router(tier_router, prefix='/platform/tiers', tags=['Tiers']) doorman.include_router(rate_limit_rule_router, prefix='/platform/rate-limits', tags=['Rate Limits']) doorman.include_router(quota_router, prefix='/platform/quota', tags=['Quota']) -def start(): + +def start() -> None: if os.path.exists(PID_FILE): gateway_logger.info('doorman is already running!') sys.exit(0) if os.name == 'nt': - process = subprocess.Popen([sys.executable, __file__, 'run'], - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) + process = subprocess.Popen( + [sys.executable, __file__, 'run'], + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) else: - process = subprocess.Popen([sys.executable, __file__, 'run'], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - preexec_fn=os.setsid) + process = subprocess.Popen( + [sys.executable, __file__, 'run'], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid, + ) with open(PID_FILE, 'w') as f: f.write(str(process.pid)) gateway_logger.info(f'Starting doorman with PID {process.pid}.') -def stop(): + +def stop() -> None: if not os.path.exists(PID_FILE): gateway_logger.info('No running instance found') return - with open(PID_FILE, 'r') as f: + with open(PID_FILE) as f: pid = int(f.read()) try: if os.name == 'nt': @@ -1506,7 +1739,6 @@ def stop(): deadline = time.time() + 15 while time.time() < deadline: try: - os.kill(pid, 0) time.sleep(0.5) except ProcessLookupError: @@ -1518,7 +1750,8 @@ def stop(): if os.path.exists(PID_FILE): os.remove(PID_FILE) -def restart(): + +def restart() -> None: """Restart the doorman process using PID-based supervisor. This function is intended to be invoked from a detached helper process. """ @@ -1533,7 +1766,8 @@ def restart(): except Exception as e: gateway_logger.error(f'Error during start phase of restart: {e}') -def run(): + +def run() -> None: server_port = int(os.getenv('PORT', 5001)) max_threads = multiprocessing.cpu_count() env_threads = int(os.getenv('THREADS', max_threads)) @@ -1548,7 +1782,9 @@ def run(): 'Set THREADS=1 for single-process memory mode or switch to MEM_OR_EXTERNAL=REDIS for multi-worker.' ) gateway_logger.info(f'Started doorman with {num_threads} threads on port {server_port}') - gateway_logger.info('TLS termination should be handled at reverse proxy (Nginx, Traefik, ALB, etc.)') + gateway_logger.info( + 'TLS termination should be handled at reverse proxy (Nginx, Traefik, ALB, etc.)' + ) uvicorn.run( 'doorman:doorman', host='0.0.0.0', @@ -1556,10 +1792,11 @@ def run(): reload=os.getenv('DEV_RELOAD', 'false').lower() == 'true', reload_excludes=['venv/*', 'logs/*'], workers=num_threads, - log_level='info' + log_level='info', ) -def main(): + +def main() -> None: host = os.getenv('HOST', '0.0.0.0') port = int(os.getenv('PORT', '8000')) try: @@ -1567,39 +1804,55 @@ def main(): 'doorman:doorman', host=host, port=port, - reload=os.getenv('DEBUG', 'false').lower() == 'true' + reload=os.getenv('DEBUG', 'false').lower() == 'true', ) except Exception as e: gateway_logger.error(f'Failed to start server: {str(e)}') raise -def seed_command(): + +def seed_command() -> None: """Run the demo seeder from command line""" import argparse + from utils.demo_seed_util import run_seed - + parser = argparse.ArgumentParser(description='Seed the database with demo data') - parser.add_argument('--users', type=int, default=60, help='Number of users to create (default: 60)') - parser.add_argument('--apis', type=int, default=20, help='Number of APIs to create (default: 20)') - parser.add_argument('--endpoints', type=int, default=6, help='Number of endpoints per API (default: 6)') - parser.add_argument('--groups', type=int, default=10, help='Number of groups to create (default: 10)') - parser.add_argument('--protos', type=int, default=6, help='Number of proto files to create (default: 6)') - parser.add_argument('--logs', type=int, default=2000, help='Number of log entries to create (default: 2000)') - parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility (optional)') - + parser.add_argument( + '--users', type=int, default=60, help='Number of users to create (default: 60)' + ) + parser.add_argument( + '--apis', type=int, default=20, help='Number of APIs to create (default: 20)' + ) + parser.add_argument( + '--endpoints', type=int, default=6, help='Number of endpoints per API (default: 6)' + ) + parser.add_argument( + '--groups', type=int, default=10, help='Number of groups to create (default: 10)' + ) + parser.add_argument( + '--protos', type=int, default=6, help='Number of proto files to create (default: 6)' + ) + parser.add_argument( + '--logs', type=int, default=2000, help='Number of log entries to create (default: 2000)' + ) + parser.add_argument( + '--seed', type=int, default=None, help='Random seed for reproducibility (optional)' + ) + args = parser.parse_args(sys.argv[2:]) # Skip 'doorman.py' and 'seed' - - print(f"Starting demo seed with:") - print(f" Users: {args.users}") - print(f" APIs: {args.apis}") - print(f" Endpoints per API: {args.endpoints}") - print(f" Groups: {args.groups}") - print(f" Protos: {args.protos}") - print(f" Logs: {args.logs}") + + print('Starting demo seed with:') + print(f' Users: {args.users}') + print(f' APIs: {args.apis}') + print(f' Endpoints per API: {args.endpoints}') + print(f' Groups: {args.groups}') + print(f' Protos: {args.protos}') + print(f' Logs: {args.logs}') if args.seed is not None: - print(f" Random Seed: {args.seed}") + print(f' Random Seed: {args.seed}') print() - + try: result = run_seed( users=args.users, @@ -1608,16 +1861,18 @@ def seed_command(): groups=args.groups, protos=args.protos, logs=args.logs, - seed=args.seed + seed=args.seed, ) - print("\n✓ Seeding completed successfully!") - print(f"Result: {result}") + print('\n✓ Seeding completed successfully!') + print(f'Result: {result}') except Exception as e: - print(f"\n✗ Seeding failed: {str(e)}") + print(f'\n✗ Seeding failed: {str(e)}') import traceback + traceback.print_exc() sys.exit(1) + if __name__ == '__main__': if len(sys.argv) > 1 and sys.argv[1] == 'stop': stop() diff --git a/backend-services/live-tests/client.py b/backend-services/live-tests/client.py index 1853d6e..77f5459 100644 --- a/backend-services/live-tests/client.py +++ b/backend-services/live-tests/client.py @@ -1,7 +1,10 @@ from __future__ import annotations -import requests + from urllib.parse import urljoin +import requests + + class LiveClient: def __init__(self, base_url: str): self.base_url = base_url.rstrip('/') + '/' @@ -30,7 +33,9 @@ class LiveClient: def post(self, path: str, json=None, data=None, files=None, headers=None, **kwargs): url = urljoin(self.base_url, path.lstrip('/')) hdrs = self._headers_with_csrf(headers) - return self.sess.post(url, json=json, data=data, files=files, headers=hdrs, allow_redirects=False, **kwargs) + return self.sess.post( + url, json=json, data=data, files=files, headers=hdrs, allow_redirects=False, **kwargs + ) def put(self, path: str, json=None, headers=None, **kwargs): url = urljoin(self.base_url, path.lstrip('/')) diff --git a/backend-services/live-tests/config.py b/backend-services/live-tests/config.py index b9849c1..bd7eaba 100644 --- a/backend-services/live-tests/config.py +++ b/backend-services/live-tests/config.py @@ -2,15 +2,24 @@ import os BASE_URL = os.getenv('DOORMAN_BASE_URL', 'http://localhost:3001').rstrip('/') ADMIN_EMAIL = os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev') + +# Resolve admin password from environment or the repo root .env. +# Search order: +# 1) Environment variable DOORMAN_ADMIN_PASSWORD +# 2) Repo root .env (two levels up from live-tests) +# 3) Default test password ADMIN_PASSWORD = os.getenv('DOORMAN_ADMIN_PASSWORD') if not ADMIN_PASSWORD: - env_file = os.path.join(os.path.dirname(__file__), '..', '.env') - if os.path.exists(env_file): - with open(env_file) as f: - for line in f: - if line.startswith('DOORMAN_ADMIN_PASSWORD='): - ADMIN_PASSWORD = line.split('=', 1)[1].strip() - break + env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '.env')) + try: + if os.path.exists(env_path): + with open(env_path, encoding='utf-8') as f: + for line in f: + if line.startswith('DOORMAN_ADMIN_PASSWORD='): + ADMIN_PASSWORD = line.split('=', 1)[1].strip() + break + except Exception: + pass if not ADMIN_PASSWORD: ADMIN_PASSWORD = 'test-only-password-12chars' @@ -18,6 +27,7 @@ ENABLE_GRAPHQL = True ENABLE_GRPC = True STRICT_HEALTH = True + def require_env(): missing = [] if not BASE_URL: @@ -25,4 +35,4 @@ def require_env(): if not ADMIN_EMAIL: missing.append('DOORMAN_ADMIN_EMAIL') if missing: - raise RuntimeError(f"Missing required env vars: {', '.join(missing)}") + raise RuntimeError(f'Missing required env vars: {", ".join(missing)}') diff --git a/backend-services/live-tests/conftest.py b/backend-services/live-tests/conftest.py index 206f5b1..7445fd7 100644 --- a/backend-services/live-tests/conftest.py +++ b/backend-services/live-tests/conftest.py @@ -1,19 +1,21 @@ import os import sys import time + import pytest -import requests sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from config import BASE_URL, ADMIN_EMAIL, ADMIN_PASSWORD, require_env, STRICT_HEALTH from client import LiveClient +from config import ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL, STRICT_HEALTH, require_env + @pytest.fixture(scope='session') def base_url() -> str: require_env() return BASE_URL + @pytest.fixture(scope='session') def client(base_url) -> LiveClient: c = LiveClient(base_url) @@ -32,7 +34,7 @@ def client(base_url) -> LiveClient: ok = data.get('status') in ('online', 'healthy') if ok: break - last_err = f"status json={j}" + last_err = f'status json={j}' except Exception as e: last_err = f'json parse error: {e}' else: @@ -48,6 +50,7 @@ def client(base_url) -> LiveClient: assert 'access_token' in auth.get('response', auth), 'login did not return access_token' return c + @pytest.fixture(autouse=True) def ensure_session_and_relaxed_limits(client: LiveClient): """Per-test guard: ensure we're authenticated and not rate-limited. @@ -59,23 +62,29 @@ def ensure_session_and_relaxed_limits(client: LiveClient): r = client.get('/platform/authorization/status') if r.status_code not in (200, 204): from config import ADMIN_EMAIL, ADMIN_PASSWORD + client.login(ADMIN_EMAIL, ADMIN_PASSWORD) except Exception: from config import ADMIN_EMAIL, ADMIN_PASSWORD + client.login(ADMIN_EMAIL, ADMIN_PASSWORD) try: - client.put('/platform/user/admin', json={ - 'rate_limit_duration': 1000000, - 'rate_limit_duration_type': 'second', - 'throttle_duration': 1000000, - 'throttle_duration_type': 'second', - 'throttle_queue_limit': 1000000, - 'throttle_wait_duration': 0, - 'throttle_wait_duration_type': 'second' - }) + client.put( + '/platform/user/admin', + json={ + 'rate_limit_duration': 1000000, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 1000000, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 1000000, + 'throttle_wait_duration': 0, + 'throttle_wait_duration_type': 'second', + }, + ) except Exception: pass + def pytest_addoption(parser): parser.addoption('--graph', action='store_true', default=False, help='Force GraphQL tests') parser.addoption('--grpc', action='store_true', default=False, help='Force gRPC tests') diff --git a/backend-services/live-tests/servers.py b/backend-services/live-tests/servers.py index abbcc81..58fc538 100644 --- a/backend-services/live-tests/servers.py +++ b/backend-services/live-tests/servers.py @@ -1,9 +1,11 @@ from __future__ import annotations -import threading -import socket + import json +import socket +import threading from http.server import BaseHTTPRequestHandler, HTTPServer + def _find_free_port() -> int: s = socket.socket() s.bind(('0.0.0.0', 0)) @@ -11,6 +13,7 @@ def _find_free_port() -> int: s.close() return port + def _get_host_from_container() -> str: """Get the hostname to use when referring to the host machine from a Docker container. @@ -38,6 +41,7 @@ def _get_host_from_container() -> str: # This is the most common development setup return '127.0.0.1' + class _ThreadedHTTPServer: def __init__(self, handler_cls, host='0.0.0.0', port=None): self.bind_host = host @@ -61,6 +65,7 @@ class _ThreadedHTTPServer: def url(self): return f'http://{self.host}:{self.port}' + def start_rest_echo_server(): class Handler(BaseHTTPRequestHandler): def _json(self, status=200, payload=None): @@ -76,7 +81,7 @@ def start_rest_echo_server(): 'method': 'GET', 'path': self.path, 'headers': {k: v for k, v in self.headers.items()}, - 'query': self.path.split('?', 1)[1] if '?' in self.path else '' + 'query': self.path.split('?', 1)[1] if '?' in self.path else '', } self._json(200, payload) @@ -91,7 +96,7 @@ def start_rest_echo_server(): 'method': 'POST', 'path': self.path, 'headers': {k: v for k, v in self.headers.items()}, - 'json': parsed + 'json': parsed, } self._json(200, payload) @@ -106,7 +111,7 @@ def start_rest_echo_server(): 'method': 'PUT', 'path': self.path, 'headers': {k: v for k, v in self.headers.items()}, - 'json': parsed + 'json': parsed, } self._json(200, payload) @@ -120,6 +125,7 @@ def start_rest_echo_server(): return _ThreadedHTTPServer(Handler).start() + def start_soap_echo_server(): class Handler(BaseHTTPRequestHandler): def _xml(self, status=200, content=''): @@ -134,12 +140,11 @@ def start_soap_echo_server(): content_length = int(self.headers.get('Content-Length', '0') or '0') _ = self.rfile.read(content_length) if content_length else b'' resp = ( - "" - "" - " ok" - "" + '' + '' + ' ok' + '' ) self._xml(200, resp) return _ThreadedHTTPServer(Handler).start() - diff --git a/backend-services/live-tests/test_00_health_auth.py b/backend-services/live-tests/test_00_health_auth.py index 6236983..5ad39bc 100644 --- a/backend-services/live-tests/test_00_health_auth.py +++ b/backend-services/live-tests/test_00_health_auth.py @@ -1,5 +1,6 @@ import pytest + def test_status_ok(client): r = client.get('/api/health') assert r.status_code == 200 @@ -12,6 +13,7 @@ def test_status_ok(client): else: assert 'error_code' in (j or {}) + def test_auth_status_me(client): r = client.get('/platform/authorization/status') assert r.status_code in (200, 204) @@ -21,5 +23,6 @@ def test_auth_status_me(client): me = r.json().get('response', r.json()) assert me.get('username') == 'admin' assert me.get('ui_access') is True -import pytest + + pytestmark = [pytest.mark.health, pytest.mark.auth] diff --git a/backend-services/live-tests/test_10_user_onboarding.py b/backend-services/live-tests/test_10_user_onboarding.py index 5ade1fa..3c6ec94 100644 --- a/backend-services/live-tests/test_10_user_onboarding.py +++ b/backend-services/live-tests/test_10_user_onboarding.py @@ -1,7 +1,7 @@ -import os -import time import random import string +import time + def _strong_password() -> str: upp = random.choice(string.ascii_uppercase) @@ -12,9 +12,10 @@ def _strong_password() -> str: raw = upp + low + dig + spc + tail return ''.join(random.sample(raw, len(raw))) + def test_user_onboarding_lifecycle(client): - username = f"user_{int(time.time())}_{random.randint(1000,9999)}" - email = f"{username}@example.com" + username = f'user_{int(time.time())}_{random.randint(1000, 9999)}' + email = f'{username}@example.com' pwd = _strong_password() payload = { @@ -23,7 +24,7 @@ def test_user_onboarding_lifecycle(client): 'password': pwd, 'role': 'developer', 'groups': ['ALL'], - 'ui_access': False + 'ui_access': False, } r = client.post('/platform/user', json=payload) assert r.status_code in (200, 201), r.text @@ -38,13 +39,14 @@ def test_user_onboarding_lifecycle(client): assert r.status_code in (200, 204), r.text new_pwd = _strong_password() - r = client.put(f'/platform/user/{username}/update-password', json={ - 'old_password': pwd, - 'new_password': new_pwd - }) + r = client.put( + f'/platform/user/{username}/update-password', + json={'old_password': pwd, 'new_password': new_pwd}, + ) assert r.status_code in (200, 204, 400), r.text from client import LiveClient + user_client = LiveClient(client.base_url) auth = user_client.login(email, new_pwd if r.status_code in (200, 204) else pwd) assert 'access_token' in auth.get('response', auth) @@ -56,5 +58,8 @@ def test_user_onboarding_lifecycle(client): r = client.delete(f'/platform/user/{username}') assert r.status_code in (200, 204), r.text + + import pytest + pytestmark = [pytest.mark.users, pytest.mark.auth] diff --git a/backend-services/live-tests/test_20_credit_defs.py b/backend-services/live-tests/test_20_credit_defs.py index af2e403..0184c0b 100644 --- a/backend-services/live-tests/test_20_credit_defs.py +++ b/backend-services/live-tests/test_20_credit_defs.py @@ -1,29 +1,31 @@ import time + def test_credit_def_create_and_assign(client): - group = f"credits_{int(time.time())}" + group = f'credits_{int(time.time())}' api_key = 'TEST_API_KEY_123456789' payload = { 'api_credit_group': group, 'api_key': api_key, 'api_key_header': 'x-api-key', 'credit_tiers': [ - { 'tier_name': 'default', 'credits': 5, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly' } - ] + { + 'tier_name': 'default', + 'credits': 5, + 'input_limit': 0, + 'output_limit': 0, + 'reset_frequency': 'monthly', + } + ], } r = client.post('/platform/credit', json=payload) assert r.status_code in (200, 201), r.text payload2 = { 'username': 'admin', - 'users_credits': { - group: { - 'tier_name': 'default', - 'available_credits': 5 - } - } + 'users_credits': {group: {'tier_name': 'default', 'available_credits': 5}}, } - r = client.post(f'/platform/credit/admin', json=payload2) + r = client.post('/platform/credit/admin', json=payload2) assert r.status_code in (200, 201), r.text r = client.get(f'/platform/credit/defs/{group}') @@ -32,5 +34,8 @@ def test_credit_def_create_and_assign(client): assert r.status_code == 200 data = r.json().get('response', r.json()) assert group in (data.get('users_credits') or {}) + + import pytest + pytestmark = [pytest.mark.credits] diff --git a/backend-services/live-tests/test_21_subscription_flows.py b/backend-services/live-tests/test_21_subscription_flows.py index 67dd2ca..ccde063 100644 --- a/backend-services/live-tests/test_21_subscription_flows.py +++ b/backend-services/live-tests/test_21_subscription_flows.py @@ -1,28 +1,39 @@ import time + import pytest pytestmark = [pytest.mark.auth] + def test_subscribe_list_unsubscribe(client): - api_name = f"subs-{int(time.time())}" + api_name = f'subs-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'subs', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:9'], - 'api_type': 'REST', - 'active': True - }) - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'subs', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + }, + ) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201), r.text r = client.get('/platform/subscription/subscriptions') assert r.status_code == 200 payload = r.json().get('response', r.json()) apis = payload.get('apis') or [] - assert any(f"{api_name}/{api_version}" == a for a in apis) - r = client.post('/platform/subscription/unsubscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + assert any(f'{api_name}/{api_version}' == a for a in apis) + r = client.post( + '/platform/subscription/unsubscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201) client.delete(f'/platform/api/{api_name}/{api_version}') diff --git a/backend-services/live-tests/test_30_rest_gateway.py b/backend-services/live-tests/test_30_rest_gateway.py index 239a3fb..38b9ded 100644 --- a/backend-services/live-tests/test_30_rest_gateway.py +++ b/backend-services/live-tests/test_30_rest_gateway.py @@ -1,6 +1,8 @@ import time + from servers import start_rest_echo_server + def test_rest_gateway_basic_flow(client): srv = start_rest_echo_server() try: @@ -17,7 +19,7 @@ def test_rest_gateway_basic_flow(client): 'api_type': 'REST', 'api_allowed_retry_count': 0, 'active': True, - 'api_cors_allow_origins': ['*'] + 'api_cors_allow_origins': ['*'], } r = client.post('/platform/api', json=api_payload) assert r.status_code in (200, 201), r.text @@ -27,7 +29,7 @@ def test_rest_gateway_basic_flow(client): 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/status', - 'endpoint_description': 'status' + 'endpoint_description': 'status', } r = client.post('/platform/endpoint', json=ep_payload) assert r.status_code in (200, 201), r.text @@ -49,6 +51,7 @@ def test_rest_gateway_basic_flow(client): finally: srv.stop() + def test_rest_gateway_with_credits_and_header_injection(client): srv = start_rest_echo_server() try: @@ -58,45 +61,68 @@ def test_rest_gateway_with_credits_and_header_injection(client): credit_group = f'cg-{ts}' api_key_val = 'DUMMY_API_KEY_ABC' - r = client.post('/platform/credit', json={ - 'api_credit_group': credit_group, - 'api_key': api_key_val, - 'api_key_header': 'x-api-key', - 'credit_tiers': [{ 'tier_name': 'default', 'credits': 2, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly' }] - }) + r = client.post( + '/platform/credit', + json={ + 'api_credit_group': credit_group, + 'api_key': api_key_val, + 'api_key_header': 'x-api-key', + 'credit_tiers': [ + { + 'tier_name': 'default', + 'credits': 2, + 'input_limit': 0, + 'output_limit': 0, + 'reset_frequency': 'monthly', + } + ], + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/credit/admin', json={ - 'username': 'admin', - 'users_credits': { credit_group: { 'tier_name': 'default', 'available_credits': 2 } } - }) + r = client.post( + '/platform/credit/admin', + json={ + 'username': 'admin', + 'users_credits': {credit_group: {'tier_name': 'default', 'available_credits': 2}}, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'REST with credits', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'active': True, - 'api_credits_enabled': True, - 'api_credit_group': credit_group - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'REST with credits', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'active': True, + 'api_credits_enabled': True, + 'api_credit_group': credit_group, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/echo', - 'endpoint_description': 'echo with header' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/echo', + 'endpoint_description': 'echo with header', + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201), r.text r = client.post(f'/api/rest/{api_name}/{api_version}/echo', json={'ping': 'pong'}) @@ -120,5 +146,8 @@ def test_rest_gateway_with_credits_and_header_injection(client): except Exception: pass srv.stop() + + import pytest + pytestmark = [pytest.mark.rest, pytest.mark.gateway] diff --git a/backend-services/live-tests/test_31_endpoints_crud.py b/backend-services/live-tests/test_31_endpoints_crud.py index 39ee0b7..88ffc82 100644 --- a/backend-services/live-tests/test_31_endpoints_crud.py +++ b/backend-services/live-tests/test_31_endpoints_crud.py @@ -1,19 +1,39 @@ import time + import pytest pytestmark = [pytest.mark.rest] + def test_endpoints_update_list_delete(client): - api_name = f"epcrud-{int(time.time())}" + api_name = f'epcrud-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'ep', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/z', 'endpoint_description': 'z' - }) - r = client.put(f'/platform/endpoint/GET/{api_name}/{api_version}/z', json={'endpoint_description': 'zzz'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'ep', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/z', + 'endpoint_description': 'z', + }, + ) + r = client.put( + f'/platform/endpoint/GET/{api_name}/{api_version}/z', json={'endpoint_description': 'zzz'} + ) assert r.status_code in (200, 204) r = client.get(f'/platform/endpoint/{api_name}/{api_version}') assert r.status_code == 200 diff --git a/backend-services/live-tests/test_32_user_credit_override.py b/backend-services/live-tests/test_32_user_credit_override.py index b7f5c70..dd4d8fc 100644 --- a/backend-services/live-tests/test_32_user_credit_override.py +++ b/backend-services/live-tests/test_32_user_credit_override.py @@ -1,6 +1,8 @@ import time + from servers import start_rest_echo_server + def test_user_specific_credit_api_key_overrides_group_key(client): srv = start_rest_echo_server() try: @@ -11,42 +13,71 @@ def test_user_specific_credit_api_key_overrides_group_key(client): group_key = 'GROUP_KEY_ABC' user_key = 'USER_KEY_DEF' - r = client.post('/platform/credit', json={ - 'api_credit_group': group, - 'api_key': group_key, - 'api_key_header': 'x-api-key', - 'credit_tiers': [{ 'tier_name': 'default', 'credits': 3, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly' }] - }) + r = client.post( + '/platform/credit', + json={ + 'api_credit_group': group, + 'api_key': group_key, + 'api_key_header': 'x-api-key', + 'credit_tiers': [ + { + 'tier_name': 'default', + 'credits': 3, + 'input_limit': 0, + 'output_limit': 0, + 'reset_frequency': 'monthly', + } + ], + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/credit/admin', json={ - 'username': 'admin', - 'users_credits': { group: { 'tier_name': 'default', 'available_credits': 3, 'user_api_key': user_key } } - }) + r = client.post( + '/platform/credit/admin', + json={ + 'username': 'admin', + 'users_credits': { + group: { + 'tier_name': 'default', + 'available_credits': 3, + 'user_api_key': user_key, + } + }, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'credit user override', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True, - 'api_credits_enabled': True, - 'api_credit_group': group - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'credit user override', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + 'api_credits_enabled': True, + 'api_credit_group': group, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/whoami', - 'endpoint_description': 'whoami' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/whoami', + 'endpoint_description': 'whoami', + }, + ) assert r.status_code in (200, 201), r.text - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) r = client.get(f'/api/rest/{api_name}/{api_version}/whoami') assert r.status_code == 200 diff --git a/backend-services/live-tests/test_33_rate_limit_and_throttle.py b/backend-services/live-tests/test_33_rate_limit_and_throttle.py index 80b75d2..301e84a 100644 --- a/backend-services/live-tests/test_33_rate_limit_and_throttle.py +++ b/backend-services/live-tests/test_33_rate_limit_and_throttle.py @@ -1,39 +1,53 @@ import time + from servers import start_rest_echo_server + def test_rate_limiting_blocks_excess_requests(client): srv = start_rest_echo_server() try: api_name = f'rl-{int(time.time())}' api_version = 'v1' - client.put('/platform/user/admin', json={ - 'rate_limit_duration': 1, - 'rate_limit_duration_type': 'second', - 'throttle_duration': 999, - 'throttle_duration_type': 'second', - 'throttle_queue_limit': 999, - 'throttle_wait_duration': 0, - 'throttle_wait_duration_type': 'second' - }) + client.put( + '/platform/user/admin', + json={ + 'rate_limit_duration': 1, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 999, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 999, + 'throttle_wait_duration': 0, + 'throttle_wait_duration_type': 'second', + }, + ) - client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'rl test', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/hit', - 'endpoint_description': 'hit' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'rl test', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/hit', + 'endpoint_description': 'hit', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) time.sleep(1.1) @@ -55,14 +69,17 @@ def test_rate_limiting_blocks_excess_requests(client): pass srv.stop() try: - client.put('/platform/user/admin', json={ - 'rate_limit_duration': 1000000, - 'rate_limit_duration_type': 'second', - 'throttle_duration': 1000000, - 'throttle_duration_type': 'second', - 'throttle_queue_limit': 1000000, - 'throttle_wait_duration': 0, - 'throttle_wait_duration_type': 'second' - }) + client.put( + '/platform/user/admin', + json={ + 'rate_limit_duration': 1000000, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 1000000, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 1000000, + 'throttle_wait_duration': 0, + 'throttle_wait_duration_type': 'second', + }, + ) except Exception: pass diff --git a/backend-services/live-tests/test_35_public_bulk_onboarding.py b/backend-services/live-tests/test_35_public_bulk_onboarding.py index 423207a..aadb886 100644 --- a/backend-services/live-tests/test_35_public_bulk_onboarding.py +++ b/backend-services/live-tests/test_35_public_bulk_onboarding.py @@ -1,13 +1,14 @@ -import time -import threading -import socket -import requests -import pytest import os import platform +import socket +import threading +import time +import pytest +import requests +from config import ENABLE_GRAPHQL from servers import start_rest_echo_server, start_soap_echo_server -from config import ENABLE_GRAPHQL, ENABLE_GRPC + def _find_port(): s = socket.socket() @@ -16,6 +17,7 @@ def _find_port(): s.close() return p + def _get_host_from_container(): """Get the hostname to use when referring to the host machine from a Docker container.""" docker_env = os.getenv('DOORMAN_IN_DOCKER', '').lower() @@ -27,6 +29,7 @@ def _get_host_from_container(): return '172.17.0.1' return '127.0.0.1' + def test_bulk_public_rest_crud(client): srv = start_rest_echo_server() try: @@ -35,29 +38,40 @@ def test_bulk_public_rest_crud(client): for i in range(3): api_name = f'pub-rest-{ts}-{i}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'public rest', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True, - 'api_public': True - }) - assert r.status_code in (200, 201), r.text - for m, uri in [('GET', '/items'), ('POST', '/items'), ('PUT', '/items'), ('DELETE', '/items')]: - r = client.post('/platform/endpoint', json={ + r = client.post( + '/platform/api', + json={ 'api_name': api_name, 'api_version': api_version, - 'endpoint_method': m, - 'endpoint_uri': uri, - 'endpoint_description': f'{m} {uri}' - }) + 'api_description': 'public rest', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + 'api_public': True, + }, + ) + assert r.status_code in (200, 201), r.text + for m, uri in [ + ('GET', '/items'), + ('POST', '/items'), + ('PUT', '/items'), + ('DELETE', '/items'), + ]: + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': m, + 'endpoint_uri': uri, + 'endpoint_description': f'{m} {uri}', + }, + ) assert r.status_code in (200, 201), r.text s = requests.Session() - url = f"{base}/api/rest/{api_name}/{api_version}/items" + url = f'{base}/api/rest/{api_name}/{api_version}/items' assert s.get(url).status_code == 200 assert s.post(url, json={'name': 'x'}).status_code == 200 assert s.put(url, json={'name': 'y'}).status_code == 200 @@ -65,69 +79,78 @@ def test_bulk_public_rest_crud(client): finally: srv.stop() + def test_bulk_public_soap_crud(client): srv = start_soap_echo_server() try: base = client.base_url.rstrip('/') ts = int(time.time()) envelope = ( - "" - "" - " " - "" + '' + '' + ' ' + '' ) for i in range(3): api_name = f'pub-soap-{ts}-{i}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'public soap', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True, - 'api_public': True - }) - assert r.status_code in (200, 201), r.text - for uri in ['/create', '/read', '/update', '/delete']: - r = client.post('/platform/endpoint', json={ + r = client.post( + '/platform/api', + json={ 'api_name': api_name, 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': uri, - 'endpoint_description': f'SOAP {uri}' - }) + 'api_description': 'public soap', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + 'api_public': True, + }, + ) + assert r.status_code in (200, 201), r.text + for uri in ['/create', '/read', '/update', '/delete']: + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': uri, + 'endpoint_description': f'SOAP {uri}', + }, + ) assert r.status_code in (200, 201), r.text s = requests.Session() headers = {'Content-Type': 'text/xml'} for uri in ['create', 'read', 'update', 'delete']: - url = f"{base}/api/soap/{api_name}/{api_version}/{uri}" + url = f'{base}/api/soap/{api_name}/{api_version}/{uri}' resp = s.post(url, data=envelope, headers=headers) assert resp.status_code == 200 finally: srv.stop() + @pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL disabled') def test_bulk_public_graphql_crud(client): try: import uvicorn - from ariadne import gql as _gql, make_executable_schema, MutationType, QueryType + from ariadne import MutationType, QueryType, make_executable_schema + from ariadne import gql as _gql from ariadne.asgi import GraphQL except Exception as e: pytest.skip(f'Missing GraphQL deps: {e}') def start_gql_server(): data_store = {'items': {}, 'seq': 0} - type_defs = _gql(''' + type_defs = _gql(""" type Query { read(id: Int!): String! } type Mutation { create(name: String!): String! update(id: Int!, name: String!): String! delete(id: Int!): Boolean! } - ''') + """) query = QueryType() mutation = MutationType() @@ -168,30 +191,53 @@ def test_bulk_public_graphql_crud(client): api_name = f'pub-gql-{ts}-{i}' api_version = 'v1' try: - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'public gql', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [f'http://{host}:{port}'], - 'api_type': 'REST', - 'active': True, - 'api_public': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'public gql', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [f'http://{host}:{port}'], + 'api_type': 'REST', + 'active': True, + 'api_public': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/graphql', 'endpoint_description': 'graphql'}) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'graphql', + }, + ) assert r.status_code in (200, 201), r.text s = requests.Session() - url = f"{base}/api/graphql/{api_name}" + url = f'{base}/api/graphql/{api_name}' q_create = {'query': 'mutation { create(name:"A") }'} - assert s.post(url, json=q_create, headers={'X-API-Version': api_version}).status_code == 200 + assert ( + s.post(url, json=q_create, headers={'X-API-Version': api_version}).status_code + == 200 + ) q_update = {'query': 'mutation { update(id:1, name:"B") }'} - assert s.post(url, json=q_update, headers={'X-API-Version': api_version}).status_code == 200 + assert ( + s.post(url, json=q_update, headers={'X-API-Version': api_version}).status_code + == 200 + ) q_read = {'query': '{ read(id:1) }'} - assert s.post(url, json=q_read, headers={'X-API-Version': api_version}).status_code == 200 + assert ( + s.post(url, json=q_read, headers={'X-API-Version': api_version}).status_code == 200 + ) q_delete = {'query': 'mutation { delete(id:1) }'} - assert s.post(url, json=q_delete, headers={'X-API-Version': api_version}).status_code == 200 + assert ( + s.post(url, json=q_delete, headers={'X-API-Version': api_version}).status_code + == 200 + ) finally: try: client.delete(f'/platform/endpoint/POST/{api_name}/{api_version}/graphql') @@ -206,19 +252,29 @@ def test_bulk_public_graphql_crud(client): except Exception: pass + import os as _os -_RUN_LIVE = _os.getenv('DOORMAN_RUN_LIVE', '0') in ('1','true','True') -@pytest.mark.skipif(not _RUN_LIVE, reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + +_RUN_LIVE = _os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') + + +@pytest.mark.skipif( + not _RUN_LIVE, reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' +) def test_bulk_public_grpc_crud(client): try: + import importlib + import pathlib + import sys + import tempfile + from concurrent import futures + import grpc from grpc_tools import protoc - from concurrent import futures - import tempfile, pathlib, importlib, sys except Exception as e: pytest.skip(f'Missing gRPC deps: {e}') - PROTO = ''' + PROTO = """ syntax = "proto3"; package {pkg}; service Resource { @@ -235,9 +291,9 @@ message UpdateRequest { int32 id = 1; string name = 2; } message UpdateReply { string message = 1; } message DeleteRequest { int32 id = 1; } message DeleteReply { bool ok = 1; } -''' +""" - base = client.base_url.rstrip('/') + client.base_url.rstrip('/') ts = int(time.time()) for i in range(0): api_name = f'pub-grpc-{ts}-{i}' @@ -248,7 +304,15 @@ message DeleteReply { bool ok = 1; } (tmp / 'svc.proto').write_text(PROTO.replace('{pkg}', pkg)) out = tmp / 'gen' out.mkdir() - code = protoc.main(['protoc', f'--proto_path={td}', f'--python_out={out}', f'--grpc_python_out={out}', str(tmp / 'svc.proto')]) + code = protoc.main( + [ + 'protoc', + f'--proto_path={td}', + f'--python_out={out}', + f'--grpc_python_out={out}', + str(tmp / 'svc.proto'), + ] + ) assert code == 0 (out / '__init__.py').write_text('') sys.path.insert(0, str(out)) @@ -258,35 +322,51 @@ message DeleteReply { bool ok = 1; } class Resource(pb2_grpc.ResourceServicer): def Create(self, request, context): return pb2.CreateReply(message=f'created {request.name}') + def Read(self, request, context): return pb2.ReadReply(message=f'read {request.id}') + def Update(self, request, context): return pb2.UpdateReply(message=f'updated {request.id}:{request.name}') + def Delete(self, request, context): return pb2.DeleteReply(ok=True) server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) pb2_grpc.add_ResourceServicer_to_server(Resource(), server) - s = socket.socket(); s.bind(('127.0.0.1', 0)); port = s.getsockname()[1]; s.close() + s = socket.socket() + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() server.add_insecure_port(f'127.0.0.1:{port}') server.start() try: - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'public grpc', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [f'grpc://127.0.0.1:{port}'], - 'api_type': 'REST', - 'active': True, - 'api_public': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'public grpc', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [f'grpc://127.0.0.1:{port}'], + 'api_type': 'REST', + 'active': True, + 'api_public': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/grpc', 'endpoint_description': 'grpc'}) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r.status_code in (200, 201), r.text - url = f"{base}/api/grpc/{api_name}" - hdr = {'X-API-Version': api_version} pass finally: try: @@ -299,4 +379,5 @@ message DeleteReply { bool ok = 1; } pass server.stop(0) + pytestmark = [pytest.mark.gateway] diff --git a/backend-services/live-tests/test_36_public_grpc_proto.py b/backend-services/live-tests/test_36_public_grpc_proto.py index 0e8169a..be26988 100644 --- a/backend-services/live-tests/test_36_public_grpc_proto.py +++ b/backend-services/live-tests/test_36_public_grpc_proto.py @@ -1,10 +1,11 @@ -import time import socket -import requests -import pytest +import time +import pytest +import requests from config import ENABLE_GRPC + def _find_port() -> int: s = socket.socket() s.bind(('0.0.0.0', 0)) @@ -12,6 +13,7 @@ def _find_port() -> int: s.close() return p + def _get_host_from_container() -> str: """Get the hostname to use when referring to the host machine from a Docker container. @@ -39,17 +41,22 @@ def _get_host_from_container() -> str: # This is the most common development setup return '127.0.0.1' + @pytest.mark.skipif(not ENABLE_GRPC, reason='gRPC disabled') def test_public_grpc_with_proto_upload(client): try: + import importlib + import pathlib + import sys + import tempfile + from concurrent import futures + import grpc from grpc_tools import protoc - from concurrent import futures - import tempfile, pathlib, importlib, sys except Exception as e: pytest.skip(f'Missing gRPC deps: {e}') - PROTO = ''' + PROTO = """ syntax = "proto3"; package {pkg}; service Resource { @@ -66,7 +73,7 @@ message UpdateRequest { int32 id = 1; string name = 2; } message UpdateReply { string message = 1; } message DeleteRequest { int32 id = 1; } message DeleteReply { bool ok = 1; } -''' +""" base = client.base_url.rstrip('/') ts = int(time.time()) @@ -79,7 +86,15 @@ message DeleteReply { bool ok = 1; } (tmp / 'svc.proto').write_text(PROTO.replace('{pkg}', pkg)) out = tmp / 'gen' out.mkdir() - code = protoc.main(['protoc', f'--proto_path={td}', f'--python_out={out}', f'--grpc_python_out={out}', str(tmp / 'svc.proto')]) + code = protoc.main( + [ + 'protoc', + f'--proto_path={td}', + f'--python_out={out}', + f'--grpc_python_out={out}', + str(tmp / 'svc.proto'), + ] + ) assert code == 0 (out / '__init__.py').write_text('') sys.path.insert(0, str(out)) @@ -112,35 +127,63 @@ message DeleteReply { bool ok = 1; } r_up = client.post(f'/platform/proto/{api_name}/{api_version}', files=files) assert r_up.status_code in (200, 201), r_up.text - r_api = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'public grpc with uploaded proto', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [f'grpc://{host_ref}:{port}'], - 'api_type': 'REST', - 'active': True, - 'api_public': True, - 'api_grpc_package': pkg - }) + r_api = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'public grpc with uploaded proto', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [f'grpc://{host_ref}:{port}'], + 'api_type': 'REST', + 'active': True, + 'api_public': True, + 'api_grpc_package': pkg, + }, + ) assert r_api.status_code in (200, 201), r_api.text - r_ep = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + r_ep = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r_ep.status_code in (200, 201), r_ep.text - url = f"{base}/api/grpc/{api_name}" + url = f'{base}/api/grpc/{api_name}' hdr = {'X-API-Version': api_version} - assert requests.post(url, json={'method': 'Resource.Create', 'message': {'name': 'A'}}, headers=hdr).status_code == 200 - assert requests.post(url, json={'method': 'Resource.Read', 'message': {'id': 1}}, headers=hdr).status_code == 200 - assert requests.post(url, json={'method': 'Resource.Update', 'message': {'id': 1, 'name': 'B'}}, headers=hdr).status_code == 200 - assert requests.post(url, json={'method': 'Resource.Delete', 'message': {'id': 1}}, headers=hdr).status_code == 200 + assert ( + requests.post( + url, json={'method': 'Resource.Create', 'message': {'name': 'A'}}, headers=hdr + ).status_code + == 200 + ) + assert ( + requests.post( + url, json={'method': 'Resource.Read', 'message': {'id': 1}}, headers=hdr + ).status_code + == 200 + ) + assert ( + requests.post( + url, + json={'method': 'Resource.Update', 'message': {'id': 1, 'name': 'B'}}, + headers=hdr, + ).status_code + == 200 + ) + assert ( + requests.post( + url, json={'method': 'Resource.Delete', 'message': {'id': 1}}, headers=hdr + ).status_code + == 200 + ) finally: try: client.delete(f'/platform/endpoint/POST/{api_name}/{api_version}/grpc') @@ -151,4 +194,3 @@ message DeleteReply { bool ok = 1; } except Exception: pass server.stop(0) - diff --git a/backend-services/live-tests/test_40_soap_gateway.py b/backend-services/live-tests/test_40_soap_gateway.py index ce71308..be04e45 100644 --- a/backend-services/live-tests/test_40_soap_gateway.py +++ b/backend-services/live-tests/test_40_soap_gateway.py @@ -1,35 +1,46 @@ import time + from servers import start_soap_echo_server + def test_soap_gateway_basic_flow(client): srv = start_soap_echo_server() try: api_name = f'soap-demo-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'SOAP demo', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'active': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'SOAP demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'active': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/soap', - 'endpoint_description': 'soap' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/soap', + 'endpoint_description': 'soap', + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201) body = """ @@ -40,7 +51,11 @@ def test_soap_gateway_basic_flow(client): """.strip() - r = client.post(f'/api/soap/{api_name}/{api_version}/soap', data=body, headers={'Content-Type': 'text/xml'}) + r = client.post( + f'/api/soap/{api_name}/{api_version}/soap', + data=body, + headers={'Content-Type': 'text/xml'}, + ) assert r.status_code == 200, r.text finally: try: @@ -52,5 +67,8 @@ def test_soap_gateway_basic_flow(client): except Exception: pass srv.stop() + + import pytest + pytestmark = [pytest.mark.soap, pytest.mark.gateway] diff --git a/backend-services/live-tests/test_41_soap_validation.py b/backend-services/live-tests/test_41_soap_validation.py index c0d910c..73b36a0 100644 --- a/backend-services/live-tests/test_41_soap_validation.py +++ b/backend-services/live-tests/test_41_soap_validation.py @@ -1,40 +1,54 @@ import time + from servers import start_soap_echo_server + def test_soap_validation_blocks_missing_field(client): srv = start_soap_echo_server() try: api_name = f'soapval-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'soap val', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/soap', - 'endpoint_description': 'soap' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'soap val', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/soap', + 'endpoint_description': 'soap', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/soap') - ep = r.json().get('response', r.json()); endpoint_id = ep.get('endpoint_id'); assert endpoint_id - schema = { - 'validation_schema': { - 'message': { 'required': True, 'type': 'string', 'min': 2 } - } - } - r = client.post('/platform/endpoint/endpoint/validation', json={ - 'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema - }) + ep = r.json().get('response', r.json()) + endpoint_id = ep.get('endpoint_id') + assert endpoint_id + schema = {'validation_schema': {'message': {'required': True, 'type': 'string', 'min': 2}}} + r = client.post( + '/platform/endpoint/endpoint/validation', + json={ + 'endpoint_id': endpoint_id, + 'validation_enabled': True, + 'validation_schema': schema, + }, + ) assert r.status_code in (200, 201) xml = """ @@ -45,11 +59,19 @@ def test_soap_validation_blocks_missing_field(client): """.strip() - r = client.post(f'/api/soap/{api_name}/{api_version}/soap', data=xml, headers={'Content-Type': 'text/xml'}) + r = client.post( + f'/api/soap/{api_name}/{api_version}/soap', + data=xml, + headers={'Content-Type': 'text/xml'}, + ) assert r.status_code == 400 xml2 = xml.replace('>A<', '>AB<') - r = client.post(f'/api/soap/{api_name}/{api_version}/soap', data=xml2, headers={'Content-Type': 'text/xml'}) + r = client.post( + f'/api/soap/{api_name}/{api_version}/soap', + data=xml2, + headers={'Content-Type': 'text/xml'}, + ) assert r.status_code == 200 finally: try: diff --git a/backend-services/live-tests/test_45_soap_preflight.py b/backend-services/live-tests/test_45_soap_preflight.py index 7f95d03..0886c81 100644 --- a/backend-services/live-tests/test_45_soap_preflight.py +++ b/backend-services/live-tests/test_45_soap_preflight.py @@ -1,22 +1,50 @@ import time + import pytest pytestmark = [pytest.mark.soap] + def test_soap_cors_preflight(client): api_name = f'soap-pre-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'soap pre', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True, - 'api_cors_allow_origins': ['http://example.com'], 'api_cors_allow_methods': ['POST'], 'api_cors_allow_headers': ['Content-Type'] - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/soap', 'endpoint_description': 's' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'soap pre', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + 'api_cors_allow_origins': ['http://example.com'], + 'api_cors_allow_methods': ['POST'], + 'api_cors_allow_headers': ['Content-Type'], + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/soap', + 'endpoint_description': 's', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) - r = client.options(f'/api/soap/{api_name}/{api_version}/soap', headers={ - 'Origin': 'http://example.com', 'Access-Control-Request-Method': 'POST', 'Access-Control-Request-Headers': 'Content-Type' - }) + r = client.options( + f'/api/soap/{api_name}/{api_version}/soap', + headers={ + 'Origin': 'http://example.com', + 'Access-Control-Request-Method': 'POST', + 'Access-Control-Request-Headers': 'Content-Type', + }, + ) assert r.status_code in (200, 204) diff --git a/backend-services/live-tests/test_46_gateway_errors.py b/backend-services/live-tests/test_46_gateway_errors.py index a5b46a5..44b62fa 100644 --- a/backend-services/live-tests/test_46_gateway_errors.py +++ b/backend-services/live-tests/test_46_gateway_errors.py @@ -1,16 +1,30 @@ import time + import pytest pytestmark = [pytest.mark.rest] + def test_nonexistent_endpoint_returns_gw_error(client): api_name = f'gwerr-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'gw', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'gw', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) r = client.get(f'/api/rest/{api_name}/{api_version}/nope') assert r.status_code in (404, 400, 500) data = r.json() diff --git a/backend-services/live-tests/test_50_graphql_gateway.py b/backend-services/live-tests/test_50_graphql_gateway.py index 839062e..893733d 100644 --- a/backend-services/live-tests/test_50_graphql_gateway.py +++ b/backend-services/live-tests/test_50_graphql_gateway.py @@ -1,35 +1,38 @@ import os import time -import pytest +import pytest from config import ENABLE_GRAPHQL -pytestmark = pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL test disabled (set DOORMAN_TEST_GRAPHQL=1 to enable)') +pytestmark = pytest.mark.skipif( + not ENABLE_GRAPHQL, reason='GraphQL test disabled (set DOORMAN_TEST_GRAPHQL=1 to enable)' +) + def test_graphql_gateway_basic_flow(client): try: import uvicorn - from ariadne import gql, make_executable_schema, QueryType + from ariadne import QueryType, gql, make_executable_schema from ariadne.asgi import GraphQL except Exception as e: pytest.skip(f'Missing GraphQL deps: {e}') - type_defs = gql(''' + type_defs = gql(""" type Query { hello(name: String): String! } - ''') + """) query = QueryType() @query.field('hello') def resolve_hello(*_, name=None): - return f"Hello, {name or 'world'}!" + return f'Hello, {name or "world"}!' schema = make_executable_schema(type_defs, query) app = GraphQL(schema, debug=True) - import threading import socket + import threading def _find_port(): s = socket.socket() @@ -41,6 +44,7 @@ def test_graphql_gateway_basic_flow(client): def _get_host_from_container(): """Get the hostname to use when referring to the host machine from a Docker container.""" import platform + docker_env = os.getenv('DOORMAN_IN_DOCKER', '').lower() if docker_env in ('1', 'true', 'yes'): system = platform.system() @@ -62,29 +66,38 @@ def test_graphql_gateway_basic_flow(client): api_name = f'gql-demo-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'GraphQL demo', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [f'http://{host}:{port}'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'active': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'GraphQL demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [f'http://{host}:{port}'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'active': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/graphql', - 'endpoint_description': 'graphql' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'graphql', + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201), r.text q = {'query': '{ hello(name:"Doorman") }'} @@ -97,5 +110,8 @@ def test_graphql_gateway_basic_flow(client): client.delete(f'/platform/endpoint/POST/{api_name}/{api_version}/graphql') client.delete(f'/platform/api/{api_name}/{api_version}') + + import pytest + pytestmark = [pytest.mark.graphql, pytest.mark.gateway] diff --git a/backend-services/live-tests/test_52_graphql_validation.py b/backend-services/live-tests/test_52_graphql_validation.py index 0d20f0e..272d48e 100644 --- a/backend-services/live-tests/test_52_graphql_validation.py +++ b/backend-services/live-tests/test_52_graphql_validation.py @@ -1,33 +1,47 @@ import time -import pytest +import pytest from config import ENABLE_GRAPHQL -pytestmark = pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL validation test disabled (set DOORMAN_TEST_GRAPHQL=1)') +pytestmark = pytest.mark.skipif( + not ENABLE_GRAPHQL, reason='GraphQL validation test disabled (set DOORMAN_TEST_GRAPHQL=1)' +) + def test_graphql_validation_blocks_invalid_variables(client): try: import uvicorn - from ariadne import gql, make_executable_schema, QueryType + from ariadne import QueryType, gql, make_executable_schema from ariadne.asgi import GraphQL except Exception as e: pytest.skip(f'Missing GraphQL deps: {e}') - type_defs = gql(''' + type_defs = gql(""" type Query { hello(name: String!): String! } - ''') + """) query = QueryType() @query.field('hello') def resolve_hello(*_, name): - return f"Hello, {name}!" + return f'Hello, {name}!' schema = make_executable_schema(type_defs, query) - import threading, socket, uvicorn, platform + import platform + import socket + import threading + + import uvicorn + def _free_port(): - s = socket.socket(); s.bind(('0.0.0.0', 0)); p = s.getsockname()[1]; s.close(); return p + s = socket.socket() + s.bind(('0.0.0.0', 0)) + p = s.getsockname()[1] + s.close() + return p + def _get_host_from_container(): import os + docker_env = os.getenv('DOORMAN_IN_DOCKER', '').lower() if docker_env in ('1', 'true', 'yes'): system = platform.system() @@ -36,36 +50,56 @@ def test_graphql_validation_blocks_invalid_variables(client): else: return '172.17.0.1' return '127.0.0.1' + port = _free_port() host = _get_host_from_container() - server = uvicorn.Server(uvicorn.Config(GraphQL(schema), host='0.0.0.0', port=port, log_level='warning')) - t = threading.Thread(target=server.run, daemon=True); t.start() + server = uvicorn.Server( + uvicorn.Config(GraphQL(schema), host='0.0.0.0', port=port, log_level='warning') + ) + t = threading.Thread(target=server.run, daemon=True) + t.start() time.sleep(0.4) api_name = f'gqlval-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, - 'api_description': 'gql val', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': [f'http://{host}:{port}'], 'api_type': 'REST', 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, - 'endpoint_method': 'POST','endpoint_uri': '/graphql','endpoint_description': 'gql' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'gql val', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [f'http://{host}:{port}'], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'gql', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/graphql') - ep = r.json().get('response', r.json()); endpoint_id = ep.get('endpoint_id'); assert endpoint_id + ep = r.json().get('response', r.json()) + endpoint_id = ep.get('endpoint_id') + assert endpoint_id - schema = { - 'validation_schema': { - 'HelloOp.x': { 'required': True, 'type': 'string', 'min': 2 } - } - } - r = client.post('/platform/endpoint/endpoint/validation', json={ - 'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema - }) + schema = {'validation_schema': {'HelloOp.x': {'required': True, 'type': 'string', 'min': 2}}} + r = client.post( + '/platform/endpoint/endpoint/validation', + json={'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema}, + ) assert r.status_code in (200, 201) q = {'query': 'query HelloOp($x: String!) { hello(name: $x) }', 'variables': {'x': 'A'}} diff --git a/backend-services/live-tests/test_59_grpc_invalid_method.py b/backend-services/live-tests/test_59_grpc_invalid_method.py index 16d1684..a5cb89b 100644 --- a/backend-services/live-tests/test_59_grpc_invalid_method.py +++ b/backend-services/live-tests/test_59_grpc_invalid_method.py @@ -1,14 +1,16 @@ import time + import pytest from config import ENABLE_GRPC pytestmark = [pytest.mark.grpc] + def test_grpc_invalid_method_returns_error(client): if not ENABLE_GRPC: pytest.skip('gRPC disabled') try: - import grpc_tools + pass except Exception as e: pytest.skip(f'Missing gRPC deps: {e}') @@ -19,15 +21,41 @@ syntax = "proto3"; package {pkg}; service Greeter {} """.replace('{pkg}', f'{api_name}_{api_version}') - r = client.post(f'/platform/proto/{api_name}/{api_version}', files={'file': ('s.proto', proto.encode('utf-8'))}) + r = client.post( + f'/platform/proto/{api_name}/{api_version}', + files={'file': ('s.proto', proto.encode('utf-8'))}, + ) assert r.status_code == 200 - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'g', 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], 'api_servers': ['grpc://127.0.0.1:9'], 'api_type': 'REST', 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/grpc', 'endpoint_description': 'g' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) - r = client.post(f'/api/grpc/{api_name}', json={'method': 'Nope.Do', 'message': {}}, headers={'X-API-Version': api_version}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'g', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['grpc://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'g', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) + r = client.post( + f'/api/grpc/{api_name}', + json={'method': 'Nope.Do', 'message': {}}, + headers={'X-API-Version': api_version}, + ) assert r.status_code in (404, 500) diff --git a/backend-services/live-tests/test_60_grpc_gateway.py b/backend-services/live-tests/test_60_grpc_gateway.py index 736e92a..969d6da 100644 --- a/backend-services/live-tests/test_60_grpc_gateway.py +++ b/backend-services/live-tests/test_60_grpc_gateway.py @@ -1,11 +1,11 @@ -import io -import os import time -import pytest +import pytest from config import ENABLE_GRPC -pytestmark = pytest.mark.skipif(not ENABLE_GRPC, reason='gRPC test disabled (set DOORMAN_TEST_GRPC=1 to enable)') +pytestmark = pytest.mark.skipif( + not ENABLE_GRPC, reason='gRPC test disabled (set DOORMAN_TEST_GRPC=1 to enable)' +) PROTO_TEMPLATE = """ syntax = "proto3"; @@ -25,22 +25,23 @@ message HelloReply { } """ + def _start_grpc_server(port: int): - import grpc from concurrent import futures + import grpc + class GreeterServicer: def Hello(self, request, context): - from google.protobuf import struct_pb2 pass server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) return server + def test_grpc_gateway_basic_flow(client): try: import grpc - import grpc_tools except Exception as e: pytest.skip(f'Missing gRPC deps: {e}') @@ -54,26 +55,38 @@ def test_grpc_gateway_basic_flow(client): assert r.status_code == 200, r.text try: + import importlib + import pathlib + import tempfile + from grpc_tools import protoc - import tempfile, pathlib, importlib + with tempfile.TemporaryDirectory() as td: tmp = pathlib.Path(td) (tmp / 'svc.proto').write_text(proto) out = tmp / 'gen' out.mkdir() - code = protoc.main([ - 'protoc', f'--proto_path={td}', f'--python_out={out}', f'--grpc_python_out={out}', str(tmp / 'svc.proto') - ]) + code = protoc.main( + [ + 'protoc', + f'--proto_path={td}', + f'--python_out={out}', + f'--grpc_python_out={out}', + str(tmp / 'svc.proto'), + ] + ) assert code == 0 (out / '__init__.py').write_text('') import sys + sys.path.insert(0, str(out)) pb2 = importlib.import_module('svc_pb2') pb2_grpc = importlib.import_module('svc_pb2_grpc') - import grpc from concurrent import futures + import grpc + class Greeter(pb2_grpc.GreeterServicer): def Hello(self, request, context): return pb2.HelloReply(message=f'Hello, {request.name}!') @@ -81,41 +94,53 @@ def test_grpc_gateway_basic_flow(client): server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) import socket - s = socket.socket(); s.bind(('127.0.0.1', 0)); port = s.getsockname()[1]; s.close() + + s = socket.socket() + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() server.add_insecure_port(f'127.0.0.1:{port}') server.start() try: - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'gRPC demo', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [f'grpc://127.0.0.1:{port}'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'active': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'gRPC demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [f'grpc://127.0.0.1:{port}'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'active': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201) - body = { - 'method': 'Greeter.Hello', - 'message': {'name': 'Doorman'} - } - r = client.post(f'/api/grpc/{api_name}', json=body, headers={'X-API-Version': api_version}) + body = {'method': 'Greeter.Hello', 'message': {'name': 'Doorman'}} + r = client.post( + f'/api/grpc/{api_name}', json=body, headers={'X-API-Version': api_version} + ) assert r.status_code == 200, r.text data = r.json().get('response', r.json()) assert data.get('message') == 'Hello, Doorman!' @@ -131,5 +156,8 @@ def test_grpc_gateway_basic_flow(client): server.stop(0) except Exception as e: pytest.skip(f'Skipping gRPC due to setup failure: {e}') + + import pytest + pytestmark = [pytest.mark.grpc, pytest.mark.gateway] diff --git a/backend-services/live-tests/test_61_ip_policy.py b/backend-services/live-tests/test_61_ip_policy.py index 4c1fe25..fb2a4b1 100644 --- a/backend-services/live-tests/test_61_ip_policy.py +++ b/backend-services/live-tests/test_61_ip_policy.py @@ -1,31 +1,36 @@ -import pytest - from types import SimpleNamespace -from utils.ip_policy_util import _ip_in_list, _get_client_ip, enforce_api_ip_policy +import pytest + +from utils.ip_policy_util import _get_client_ip, _ip_in_list, enforce_api_ip_policy + @pytest.fixture(autouse=True, scope='session') def ensure_session_and_relaxed_limits(): yield + def make_request(host: str | None = None, headers: dict | None = None): client = SimpleNamespace(host=host, port=None) return SimpleNamespace(client=client, headers=headers or {}, url=SimpleNamespace(path='/')) + def test_ip_in_list_ipv4_exact_and_cidr(): assert _ip_in_list('192.168.1.10', ['192.168.1.10']) assert _ip_in_list('10.1.2.3', ['10.0.0.0/8']) assert not _ip_in_list('11.1.2.3', ['10.0.0.0/8']) + def test_ip_in_list_ipv6_exact_and_cidr(): assert _ip_in_list('2001:db8::1', ['2001:db8::1']) assert _ip_in_list('2001:db8::abcd', ['2001:db8::/32']) assert not _ip_in_list('2001:db9::1', ['2001:db8::/32']) + def test_get_client_ip_trusted_proxy(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'xff_trusted_proxies': ['10.0.0.0/8'] - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', lambda: {'xff_trusted_proxies': ['10.0.0.0/8']} + ) req1 = make_request('10.1.2.3', {'X-Forwarded-For': '1.2.3.4, 10.1.2.3'}) assert _get_client_ip(req1, True) == '1.2.3.4' @@ -33,17 +38,21 @@ def test_get_client_ip_trusted_proxy(monkeypatch): req2 = make_request('8.8.8.8', {'X-Forwarded-For': '1.2.3.4'}) assert _get_client_ip(req2, True) == '8.8.8.8' + def test_enforce_api_policy_never_blocks_localhost(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'trust_x_forwarded_for': False, - 'xff_trusted_proxies': [], - 'allow_localhost_bypass': True, - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: { + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + 'allow_localhost_bypass': True, + }, + ) api = { 'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24'], - 'api_ip_blacklist': ['0.0.0.0/0'] + 'api_ip_blacklist': ['0.0.0.0/0'], } req_local_v4 = make_request('127.0.0.1', {}) @@ -52,38 +61,46 @@ def test_enforce_api_policy_never_blocks_localhost(monkeypatch): req_local_v6 = make_request('::1', {}) enforce_api_ip_policy(req_local_v6, api) + def test_get_client_ip_secure_default_no_trust_when_empty_list(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'trust_x_forwarded_for': True, - 'xff_trusted_proxies': [] - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: {'trust_x_forwarded_for': True, 'xff_trusted_proxies': []}, + ) req = make_request('10.0.0.5', {'X-Forwarded-For': '203.0.113.9'}) assert _get_client_ip(req, True) == '10.0.0.5' + def test_get_client_ip_x_real_ip_and_cf_connecting(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'trust_x_forwarded_for': True, - 'xff_trusted_proxies': ['10.0.0.0/8'] - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: {'trust_x_forwarded_for': True, 'xff_trusted_proxies': ['10.0.0.0/8']}, + ) req1 = make_request('10.2.3.4', {'X-Real-IP': '198.51.100.7'}) assert _get_client_ip(req1, True) == '198.51.100.7' req2 = make_request('10.2.3.4', {'CF-Connecting-IP': '2001:db8::2'}) assert _get_client_ip(req2, True) == '2001:db8::2' + def test_get_client_ip_ignores_headers_when_trust_disabled(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'trust_x_forwarded_for': False, - 'xff_trusted_proxies': ['10.0.0.0/8'] - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: {'trust_x_forwarded_for': False, 'xff_trusted_proxies': ['10.0.0.0/8']}, + ) req = make_request('10.2.3.4', {'X-Forwarded-For': '198.51.100.7'}) assert _get_client_ip(req, False) == '10.2.3.4' + def test_enforce_api_policy_whitelist_and_blacklist(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'trust_x_forwarded_for': False, - 'xff_trusted_proxies': [] - }) - api = {'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24'], 'api_ip_blacklist': []} + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: {'trust_x_forwarded_for': False, 'xff_trusted_proxies': []}, + ) + api = { + 'api_ip_mode': 'whitelist', + 'api_ip_whitelist': ['203.0.113.0/24'], + 'api_ip_blacklist': [], + } req = make_request('198.51.100.10', {}) raised = False try: @@ -92,7 +109,11 @@ def test_enforce_api_policy_whitelist_and_blacklist(monkeypatch): raised = True assert raised - api2 = {'api_ip_mode': 'allow_all', 'api_ip_whitelist': [], 'api_ip_blacklist': ['198.51.100.0/24']} + api2 = { + 'api_ip_mode': 'allow_all', + 'api_ip_whitelist': [], + 'api_ip_blacklist': ['198.51.100.0/24'], + } req2 = make_request('198.51.100.10', {}) raised2 = False try: @@ -101,12 +122,16 @@ def test_enforce_api_policy_whitelist_and_blacklist(monkeypatch): raised2 = True assert raised2 + def test_localhost_bypass_requires_no_forwarding_headers(monkeypatch): - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'allow_localhost_bypass': True, - 'trust_x_forwarded_for': False, - 'xff_trusted_proxies': [] - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: { + 'allow_localhost_bypass': True, + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + }, + ) api = {'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24']} req = make_request('::1', {'X-Forwarded-For': '1.2.3.4'}) raised = False @@ -116,13 +141,17 @@ def test_localhost_bypass_requires_no_forwarding_headers(monkeypatch): raised = True assert raised, 'Expected enforcement when forwarding header present' + def test_env_overrides_localhost_bypass(monkeypatch): monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'true') - monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: { - 'allow_localhost_bypass': False, - 'trust_x_forwarded_for': False, - 'xff_trusted_proxies': [] - }) + monkeypatch.setattr( + 'utils.ip_policy_util.get_cached_settings', + lambda: { + 'allow_localhost_bypass': False, + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + }, + ) api = {'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24']} req = make_request('127.0.0.1', {}) enforce_api_ip_policy(req, api) diff --git a/backend-services/live-tests/test_66_graphql_missing_version_header.py b/backend-services/live-tests/test_66_graphql_missing_version_header.py index fb6c956..5b27b2d 100644 --- a/backend-services/live-tests/test_66_graphql_missing_version_header.py +++ b/backend-services/live-tests/test_66_graphql_missing_version_header.py @@ -1,41 +1,74 @@ import time + import pytest from config import ENABLE_GRAPHQL pytestmark = [pytest.mark.graphql] + def test_graphql_missing_version_header_returns_400(client): if not ENABLE_GRAPHQL: pytest.skip('GraphQL disabled') try: - from ariadne import gql, make_executable_schema, QueryType - from ariadne.asgi import GraphQL import uvicorn + from ariadne import QueryType, gql, make_executable_schema + from ariadne.asgi import GraphQL except Exception as e: pytest.skip(f'Missing deps: {e}') type_defs = gql('type Query { ok: String! }') query = QueryType() + @query.field('ok') def resolve_ok(*_): return 'ok' + schema = make_executable_schema(type_defs, query) - import threading, socket - s = socket.socket(); s.bind(('127.0.0.1', 0)); port = s.getsockname()[1]; s.close() - server = uvicorn.Server(uvicorn.Config(GraphQL(schema), host='127.0.0.1', port=port, log_level='warning')) - t = threading.Thread(target=server.run, daemon=True); t.start() - import time as _t; _t.sleep(0.4) + import socket + import threading + + s = socket.socket() + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() + server = uvicorn.Server( + uvicorn.Config(GraphQL(schema), host='127.0.0.1', port=port, log_level='warning') + ) + t = threading.Thread(target=server.run, daemon=True) + t.start() + import time as _t + + _t.sleep(0.4) api_name = f'gql-novh-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'gql', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': [f'http://127.0.0.1:{port}'], 'api_type': 'REST', 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/graphql', 'endpoint_description': 'gql' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'gql', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [f'http://127.0.0.1:{port}'], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'gql', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) r = client.post(f'/api/graphql/{api_name}', json={'query': '{ ok }'}) assert r.status_code == 400 diff --git a/backend-services/live-tests/test_70_combinations.py b/backend-services/live-tests/test_70_combinations.py index dfbd104..75e8090 100644 --- a/backend-services/live-tests/test_70_combinations.py +++ b/backend-services/live-tests/test_70_combinations.py @@ -1,6 +1,8 @@ import time + from servers import start_rest_echo_server + def test_endpoint_level_servers_override(client): srv_api = start_rest_echo_server() srv_ep = start_rest_echo_server() @@ -8,29 +10,38 @@ def test_endpoint_level_servers_override(client): api_name = f'combo-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'combo demo', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv_api.url], - 'api_type': 'REST', - 'active': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'combo demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv_api.url], + 'api_type': 'REST', + 'active': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/who', - 'endpoint_description': 'who am i', - 'endpoint_servers': [srv_ep.url] - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/who', + 'endpoint_description': 'who am i', + 'endpoint_servers': [srv_ep.url], + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201) r = client.get(f'/api/rest/{api_name}/{api_version}/who') @@ -51,5 +62,8 @@ def test_endpoint_level_servers_override(client): pass srv_api.stop() srv_ep.stop() + + import pytest + pytestmark = [pytest.mark.gateway, pytest.mark.routing] diff --git a/backend-services/live-tests/test_80_roles_groups_permissions.py b/backend-services/live-tests/test_80_roles_groups_permissions.py index 4f9dbbe..c959fc9 100644 --- a/backend-services/live-tests/test_80_roles_groups_permissions.py +++ b/backend-services/live-tests/test_80_roles_groups_permissions.py @@ -1,11 +1,12 @@ -import time import random import string +import time + def _rand_user() -> tuple[str, str, str]: ts = int(time.time()) - uname = f"usr_{ts}_{random.randint(1000,9999)}" - email = f"{uname}@example.com" + uname = f'usr_{ts}_{random.randint(1000, 9999)}' + email = f'{uname}@example.com' upp = random.choice(string.ascii_uppercase) low = ''.join(random.choices(string.ascii_lowercase, k=8)) dig = ''.join(random.choices(string.digits, k=4)) @@ -14,40 +15,47 @@ def _rand_user() -> tuple[str, str, str]: pwd = ''.join(random.sample(upp + low + dig + spc + tail, len(upp + low + dig + spc + tail))) return uname, email, pwd + def test_role_permission_blocks_api_management(client): - role_name = f"viewer_{int(time.time())}" - r = client.post('/platform/role', json={ - 'role_name': role_name, - 'role_description': 'temporary viewer', - 'view_logs': True - }) + role_name = f'viewer_{int(time.time())}' + r = client.post( + '/platform/role', + json={'role_name': role_name, 'role_description': 'temporary viewer', 'view_logs': True}, + ) assert r.status_code in (200, 201), r.text uname, email, pwd = _rand_user() - r = client.post('/platform/user', json={ - 'username': uname, - 'email': email, - 'password': pwd, - 'role': role_name, - 'groups': ['ALL'], - 'ui_access': True - }) + r = client.post( + '/platform/user', + json={ + 'username': uname, + 'email': email, + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) assert r.status_code in (200, 201), r.text from client import LiveClient + user_client = LiveClient(client.base_url) user_client.login(email, pwd) - api_name = f"nope-{int(time.time())}" - r = user_client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': 'v1', - 'api_description': 'should be blocked', - 'api_allowed_roles': [role_name], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:1'], - 'api_type': 'REST' - }) + api_name = f'nope-{int(time.time())}' + r = user_client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': 'v1', + 'api_description': 'should be blocked', + 'api_allowed_roles': [role_name], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:1'], + 'api_type': 'REST', + }, + ) assert r.status_code == 403 body = r.json() data = body.get('response', body) @@ -55,5 +63,8 @@ def test_role_permission_blocks_api_management(client): client.delete(f'/platform/user/{uname}') client.delete(f'/platform/role/{role_name}') + + import pytest + pytestmark = [pytest.mark.security, pytest.mark.roles] diff --git a/backend-services/live-tests/test_81_negative_permissions.py b/backend-services/live-tests/test_81_negative_permissions.py index 9a1dcef..2a70ddb 100644 --- a/backend-services/live-tests/test_81_negative_permissions.py +++ b/backend-services/live-tests/test_81_negative_permissions.py @@ -1,14 +1,15 @@ -import time import random -import string +import time + import pytest pytestmark = [pytest.mark.security] + def _mk_user_payload(role_name: str) -> tuple[str, str, str, dict]: ts = int(time.time()) - uname = f"min_{ts}_{random.randint(1000,9999)}" - email = f"{uname}@example.com" + uname = f'min_{ts}_{random.randint(1000, 9999)}' + email = f'{uname}@example.com' pwd = 'Strong!Passw0rd1234' payload = { 'username': uname, @@ -16,16 +17,14 @@ def _mk_user_payload(role_name: str) -> tuple[str, str, str, dict]: 'password': pwd, 'role': role_name, 'groups': ['ALL'], - 'ui_access': True + 'ui_access': True, } return uname, email, pwd, payload + def test_negative_permissions_for_logs_and_config(client): - role_name = f"minrole_{int(time.time())}" - r = client.post('/platform/role', json={ - 'role_name': role_name, - 'role_description': 'minimal' - }) + role_name = f'minrole_{int(time.time())}' + r = client.post('/platform/role', json={'role_name': role_name, 'role_description': 'minimal'}) assert r.status_code in (200, 201) uname, email, pwd, payload = _mk_user_payload(role_name) @@ -33,6 +32,7 @@ def test_negative_permissions_for_logs_and_config(client): assert r.status_code in (200, 201) from client import LiveClient + u = LiveClient(client.base_url) u.login(email, pwd) diff --git a/backend-services/live-tests/test_82_role_matrix_negative.py b/backend-services/live-tests/test_82_role_matrix_negative.py index 28d820d..ce2898a 100644 --- a/backend-services/live-tests/test_82_role_matrix_negative.py +++ b/backend-services/live-tests/test_82_role_matrix_negative.py @@ -1,69 +1,110 @@ import time + import pytest pytestmark = [pytest.mark.security, pytest.mark.roles] + def _mk_user(client, role_name: str): ts = int(time.time()) - uname = f"perm_{ts}" - email = f"{uname}@example.com" + uname = f'perm_{ts}' + email = f'{uname}@example.com' pwd = 'Strong!Passw0rd1234' - r = client.post('/platform/user', json={ - 'username': uname, - 'email': email, - 'password': pwd, - 'role': role_name, - 'groups': ['ALL'], - 'ui_access': True, - 'rate_limit_duration': 1000000, - 'rate_limit_duration_type': 'second', - 'throttle_duration': 1000000, - 'throttle_duration_type': 'second', - 'throttle_queue_limit': 1000000, - 'throttle_wait_duration': 0, - 'throttle_wait_duration_type': 'second' - }) + r = client.post( + '/platform/user', + json={ + 'username': uname, + 'email': email, + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + 'rate_limit_duration': 1000000, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 1000000, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 1000000, + 'throttle_wait_duration': 0, + 'throttle_wait_duration_type': 'second', + }, + ) assert r.status_code in (200, 201), r.text return uname, email, pwd + def _login(base_client, email, pwd): from client import LiveClient + c = LiveClient(base_client.base_url) c.login(email, pwd) return c + def test_permission_matrix_block_then_allow(client): api_name = f'permapi-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'perm', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True - }) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'perm', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + }, + ) def try_manage_apis(c): - return c.post('/platform/api', json={ - 'api_name': f'pa-{int(time.time())}', 'api_version': 'v1', 'api_description': 'x', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST' - }) + return c.post( + '/platform/api', + json={ + 'api_name': f'pa-{int(time.time())}', + 'api_version': 'v1', + 'api_description': 'x', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + }, + ) def try_manage_endpoints(c): - return c.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, - 'endpoint_method': 'GET', 'endpoint_uri': f'/p{int(time.time())}', 'endpoint_description': 'x' - }) + return c.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': f'/p{int(time.time())}', + 'endpoint_description': 'x', + }, + ) def try_manage_users(c): - return c.post('/platform/user', json={ - 'username': f'u{int(time.time())}', 'email': f'u{int(time.time())}@ex.com', - 'password': 'Strong!Passw0rd1234', 'role': 'viewer', 'groups': ['ALL'], 'ui_access': False - }) + return c.post( + '/platform/user', + json={ + 'username': f'u{int(time.time())}', + 'email': f'u{int(time.time())}@ex.com', + 'password': 'Strong!Passw0rd1234', + 'role': 'viewer', + 'groups': ['ALL'], + 'ui_access': False, + }, + ) def try_manage_groups(c): - return c.post('/platform/group', json={'group_name': f'g{int(time.time())}', 'group_description': 'x'}) + return c.post( + '/platform/group', json={'group_name': f'g{int(time.time())}', 'group_description': 'x'} + ) def try_manage_roles(c): - return c.post('/platform/role', json={'role_name': f'r{int(time.time())}', 'role_description': 'x'}) + return c.post( + '/platform/role', json={'role_name': f'r{int(time.time())}', 'role_description': 'x'} + ) matrix = [ ('manage_apis', try_manage_apis, 'API007'), @@ -74,19 +115,23 @@ def test_permission_matrix_block_then_allow(client): ] for perm_field, attempt, expected_code in matrix: - role_name = f"role_{perm_field}_{int(time.time())}" - r = client.post('/platform/role', json={'role_name': role_name, 'role_description': 'matrix', perm_field: False}) + role_name = f'role_{perm_field}_{int(time.time())}' + r = client.post( + '/platform/role', + json={'role_name': role_name, 'role_description': 'matrix', perm_field: False}, + ) assert r.status_code in (200, 201), r.text uname, email, pwd = _mk_user(client, role_name) uc = _login(client, email, pwd) resp = attempt(uc) - assert resp.status_code == 403, f"{perm_field} should be blocked: {resp.text}" - data = resp.json(); code = data.get('error_code') or (data.get('response') or {}).get('error_code') + assert resp.status_code == 403, f'{perm_field} should be blocked: {resp.text}' + data = resp.json() + code = data.get('error_code') or (data.get('response') or {}).get('error_code') assert code == expected_code client.put(f'/platform/role/{role_name}', json={perm_field: True}) resp2 = attempt(uc) - assert resp2.status_code != 403, f"{perm_field} still blocked after enable: {resp2.text}" + assert resp2.status_code != 403, f'{perm_field} still blocked after enable: {resp2.text}' client.delete(f'/platform/user/{uname}') client.delete(f'/platform/role/{role_name}') diff --git a/backend-services/live-tests/test_83_roles_groups_crud.py b/backend-services/live-tests/test_83_roles_groups_crud.py index 8f0b7cd..09b3b38 100644 --- a/backend-services/live-tests/test_83_roles_groups_crud.py +++ b/backend-services/live-tests/test_83_roles_groups_crud.py @@ -1,12 +1,16 @@ import time + import pytest pytestmark = [pytest.mark.security, pytest.mark.roles] + def test_roles_groups_crud_and_list(client): - role = f"rolex-{int(time.time())}" - group = f"groupx-{int(time.time())}" - r = client.post('/platform/role', json={'role_name': role, 'role_description': 'x', 'manage_users': True}) + role = f'rolex-{int(time.time())}' + group = f'groupx-{int(time.time())}' + r = client.post( + '/platform/role', json={'role_name': role, 'role_description': 'x', 'manage_users': True} + ) assert r.status_code in (200, 201) r = client.post('/platform/group', json={'group_name': group, 'group_description': 'x'}) assert r.status_code in (200, 201) diff --git a/backend-services/live-tests/test_85_endpoint_validation.py b/backend-services/live-tests/test_85_endpoint_validation.py index c2429e8..04d0197 100644 --- a/backend-services/live-tests/test_85_endpoint_validation.py +++ b/backend-services/live-tests/test_85_endpoint_validation.py @@ -1,30 +1,38 @@ import time + from servers import start_rest_echo_server + def test_rest_endpoint_validation_blocks_invalid_payload(client): srv = start_rest_echo_server() try: api_name = f'val-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'validation test', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'validation test', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'POST', - 'endpoint_uri': '/create', - 'endpoint_description': 'create' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/create', + 'endpoint_description': 'create', + }, + ) assert r.status_code in (200, 201) r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/create') @@ -36,17 +44,23 @@ def test_rest_endpoint_validation_blocks_invalid_payload(client): schema = { 'validation_schema': { 'user.name': {'required': True, 'type': 'string', 'min': 2}, - 'user.age': {'required': True, 'type': 'number', 'min': 1} + 'user.age': {'required': True, 'type': 'number', 'min': 1}, } } - r = client.post('/platform/endpoint/endpoint/validation', json={ - 'endpoint_id': endpoint_id, - 'validation_enabled': True, - 'validation_schema': schema - }) + r = client.post( + '/platform/endpoint/endpoint/validation', + json={ + 'endpoint_id': endpoint_id, + 'validation_enabled': True, + 'validation_schema': schema, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + r = client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) assert r.status_code in (200, 201) r = client.post(f'/api/rest/{api_name}/{api_version}/create', json={'user': {'name': 'A'}}) @@ -55,7 +69,9 @@ def test_rest_endpoint_validation_blocks_invalid_payload(client): err = body.get('error_code') or body.get('response', {}).get('error_code') assert err == 'GTW011' or body.get('error_message') - r = client.post(f'/api/rest/{api_name}/{api_version}/create', json={'user': {'name': 'Alan', 'age': 33}}) + r = client.post( + f'/api/rest/{api_name}/{api_version}/create', json={'user': {'name': 'Alan', 'age': 33}} + ) assert r.status_code == 200 finally: try: @@ -67,5 +83,8 @@ def test_rest_endpoint_validation_blocks_invalid_payload(client): except Exception: pass srv.stop() + + import pytest + pytestmark = [pytest.mark.validation, pytest.mark.rest] diff --git a/backend-services/live-tests/test_86_validation_edge_cases.py b/backend-services/live-tests/test_86_validation_edge_cases.py index 21f7f09..1fb22ec 100644 --- a/backend-services/live-tests/test_86_validation_edge_cases.py +++ b/backend-services/live-tests/test_86_validation_edge_cases.py @@ -1,61 +1,88 @@ import time -from servers import start_rest_echo_server + import pytest +from servers import start_rest_echo_server pytestmark = [pytest.mark.validation] + def test_nested_array_and_format_validations(client): srv = start_rest_echo_server() try: api_name = f'valedge-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, - 'api_description': 'edge validations', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], 'api_type': 'REST', 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, - 'endpoint_method': 'POST', 'endpoint_uri': '/submit', 'endpoint_description': 'submit' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'edge validations', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'POST', + 'endpoint_uri': '/submit', + 'endpoint_description': 'submit', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/submit') - ep = r.json().get('response', r.json()); endpoint_id = ep.get('endpoint_id'); assert endpoint_id + ep = r.json().get('response', r.json()) + endpoint_id = ep.get('endpoint_id') + assert endpoint_id schema = { 'validation_schema': { 'user.email': {'required': True, 'type': 'string', 'format': 'email'}, 'items': { - 'required': True, 'type': 'array', 'min': 1, + 'required': True, + 'type': 'array', + 'min': 1, 'array_items': { 'type': 'object', 'nested_schema': { 'id': {'required': True, 'type': 'string', 'format': 'uuid'}, - 'quantity': {'required': True, 'type': 'number', 'min': 1} - } - } - } + 'quantity': {'required': True, 'type': 'number', 'min': 1}, + }, + }, + }, } } - r = client.post('/platform/endpoint/endpoint/validation', json={ - 'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema - }) + r = client.post( + '/platform/endpoint/endpoint/validation', + json={ + 'endpoint_id': endpoint_id, + 'validation_enabled': True, + 'validation_schema': schema, + }, + ) if r.status_code == 422: import pytest + pytest.skip('Validation schema shape not accepted by server (422)') assert r.status_code in (200, 201) - bad = { - 'user': {'email': 'not-an-email'}, - 'items': [{'id': '123', 'quantity': 0}] - } + bad = {'user': {'email': 'not-an-email'}, 'items': [{'id': '123', 'quantity': 0}]} r = client.post(f'/api/rest/{api_name}/{api_version}/submit', json=bad) assert r.status_code == 400 import uuid + ok = { 'user': {'email': 'u@example.com'}, - 'items': [{'id': str(uuid.uuid4()), 'quantity': 2}] + 'items': [{'id': str(uuid.uuid4()), 'quantity': 2}], } r = client.post(f'/api/rest/{api_name}/{api_version}/submit', json=ok) assert r.status_code == 200 diff --git a/backend-services/live-tests/test_90_security_tools_logging.py b/backend-services/live-tests/test_90_security_tools_logging.py index dd19477..9487d0d 100644 --- a/backend-services/live-tests/test_90_security_tools_logging.py +++ b/backend-services/live-tests/test_90_security_tools_logging.py @@ -10,21 +10,29 @@ def test_security_settings_get_put(client): updated = r.json().get('response', r.json()) assert bool(updated.get('enable_auto_save') or False) == desired + def test_tools_cors_check(client): - r = client.post('/platform/tools/cors/check', json={ - 'origin': 'http://localhost:3000', - 'method': 'GET', - 'request_headers': ['Content-Type'] - }) + r = client.post( + '/platform/tools/cors/check', + json={ + 'origin': 'http://localhost:3000', + 'method': 'GET', + 'request_headers': ['Content-Type'], + }, + ) assert r.status_code == 200 payload = r.json().get('response', r.json()) assert 'config' in payload and 'preflight' in payload + def test_clear_all_caches(client): r = client.delete('/api/caches') assert r.status_code == 200 body = r.json().get('response', r.json()) - assert 'All caches cleared' in (body.get('message') or body.get('error_message') or 'All caches cleared') + assert 'All caches cleared' in ( + body.get('message') or body.get('error_message') or 'All caches cleared' + ) + def test_logging_endpoints(client): r = client.get('/platform/logging/logs?limit=10') @@ -36,5 +44,8 @@ def test_logging_endpoints(client): assert r.status_code == 200 files = r.json().get('response', r.json()) assert 'count' in files + + import pytest + pytestmark = [pytest.mark.security, pytest.mark.tools, pytest.mark.logging] diff --git a/backend-services/live-tests/test_91_memory_dump_restore.py b/backend-services/live-tests/test_91_memory_dump_restore.py index 84a052e..9fe3b6a 100644 --- a/backend-services/live-tests/test_91_memory_dump_restore.py +++ b/backend-services/live-tests/test_91_memory_dump_restore.py @@ -1,10 +1,14 @@ import os + import pytest pytestmark = [pytest.mark.security] + def test_memory_dump_restore_conditionally(client): - mem_mode = os.environ.get('MEM_OR_EXTERNAL', os.environ.get('MEM_OR_REDIS', 'MEM')).upper() == 'MEM' + mem_mode = ( + os.environ.get('MEM_OR_EXTERNAL', os.environ.get('MEM_OR_REDIS', 'MEM')).upper() == 'MEM' + ) key = os.environ.get('MEM_ENCRYPTION_KEY') if not mem_mode or not key: pytest.skip('Memory dump/restore only in memory mode with MEM_ENCRYPTION_KEY set') diff --git a/backend-services/live-tests/test_92_authorization_flows.py b/backend-services/live-tests/test_92_authorization_flows.py index 187d2ab..19e571c 100644 --- a/backend-services/live-tests/test_92_authorization_flows.py +++ b/backend-services/live-tests/test_92_authorization_flows.py @@ -10,24 +10,32 @@ def test_token_refresh_and_invalidate(client): assert r.status_code == 401 from config import ADMIN_EMAIL, ADMIN_PASSWORD + client.login(ADMIN_EMAIL, ADMIN_PASSWORD) + def test_admin_revoke_tokens_for_user(client): - import time, random, string + import random + import time + ts = int(time.time()) - uname = f'revoke_{ts}_{random.randint(1000,9999)}' + uname = f'revoke_{ts}_{random.randint(1000, 9999)}' email = f'{uname}@example.com' pwd = 'Strong!Passw0rd1234' - r = client.post('/platform/user', json={ - 'username': uname, - 'email': email, - 'password': pwd, - 'role': 'admin', - 'groups': ['ALL'], - 'ui_access': True - }) + r = client.post( + '/platform/user', + json={ + 'username': uname, + 'email': email, + 'password': pwd, + 'role': 'admin', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) assert r.status_code in (200, 201), r.text from client import LiveClient + user_client = LiveClient(client.base_url) user_client.login(email, pwd) r = client.post(f'/platform/authorization/admin/revoke/{uname}', json={}) @@ -35,5 +43,8 @@ def test_admin_revoke_tokens_for_user(client): r = user_client.get('/platform/user/me') assert r.status_code == 401 client.delete(f'/platform/user/{uname}') + + import pytest + pytestmark = [pytest.mark.auth] diff --git a/backend-services/live-tests/test_93_public_and_auth_optional.py b/backend-services/live-tests/test_93_public_and_auth_optional.py index 9327a41..6a296c6 100644 --- a/backend-services/live-tests/test_93_public_and_auth_optional.py +++ b/backend-services/live-tests/test_93_public_and_auth_optional.py @@ -1,31 +1,39 @@ import time -from servers import start_rest_echo_server + import requests +from servers import start_rest_echo_server + def test_public_api_no_auth_required(client): srv = start_rest_echo_server() try: api_name = f'public-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'public', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True, - 'api_public': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'public', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + 'api_public': True, + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/status', - 'endpoint_description': 'status' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/status', + 'endpoint_description': 'status', + }, + ) assert r.status_code in (200, 201) s = requests.Session() url = client.base_url.rstrip('/') + f'/api/rest/{api_name}/{api_version}/status' @@ -42,33 +50,41 @@ def test_public_api_no_auth_required(client): pass srv.stop() + def test_auth_not_required_but_not_public_allows_unauthenticated(client): srv = start_rest_echo_server() try: api_name = f'authopt-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'auth optional', - 'api_allowed_roles': [], - 'api_allowed_groups': [], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True, - 'api_public': False, - 'api_auth_required': False - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'auth optional', + 'api_allowed_roles': [], + 'api_allowed_groups': [], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + 'api_public': False, + 'api_auth_required': False, + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/ping', - 'endpoint_description': 'ping' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/ping', + 'endpoint_description': 'ping', + }, + ) assert r.status_code in (200, 201) import requests + s = requests.Session() url = client.base_url.rstrip('/') + f'/api/rest/{api_name}/{api_version}/ping' r = s.get(url) @@ -83,5 +99,8 @@ def test_auth_not_required_but_not_public_allows_unauthenticated(client): except Exception: pass srv.stop() + + import pytest + pytestmark = [pytest.mark.rest, pytest.mark.auth] diff --git a/backend-services/live-tests/test_94_routing_and_header_swap.py b/backend-services/live-tests/test_94_routing_and_header_swap.py index c17b4dc..f950dff 100644 --- a/backend-services/live-tests/test_94_routing_and_header_swap.py +++ b/backend-services/live-tests/test_94_routing_and_header_swap.py @@ -1,6 +1,8 @@ import time + from servers import start_rest_echo_server + def test_client_routing_overrides_api_servers(client): srv_a = start_rest_echo_server() srv_b = start_rest_echo_server() @@ -9,38 +11,52 @@ def test_client_routing_overrides_api_servers(client): api_version = 'v1' client_key = f'ck-{int(time.time())}' - r = client.post('/platform/routing', json={ - 'routing_name': 'test-routing', - 'routing_servers': [srv_b.url], - 'routing_description': 'test', - 'client_key': client_key, - 'server_index': 0 - }) + r = client.post( + '/platform/routing', + json={ + 'routing_name': 'test-routing', + 'routing_servers': [srv_b.url], + 'routing_description': 'test', + 'client_key': client_key, + 'server_index': 0, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'routing demo', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv_a.url], - 'api_type': 'REST', - 'active': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'routing demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv_a.url], + 'api_type': 'REST', + 'active': True, + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/where', - 'endpoint_description': 'where' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/where', + 'endpoint_description': 'where', + }, + ) assert r.status_code in (200, 201) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) - r = client.get(f'/api/rest/{api_name}/{api_version}/where', headers={'client-key': client_key}) + r = client.get( + f'/api/rest/{api_name}/{api_version}/where', headers={'client-key': client_key} + ) assert r.status_code == 200 data = r.json().get('response', r.json()) hdrs = {k.lower(): v for k, v in (data.get('headers') or {}).items()} @@ -55,10 +71,12 @@ def test_client_routing_overrides_api_servers(client): except Exception: pass try: - client.delete(f"/platform/routing/{client_key}") + client.delete(f'/platform/routing/{client_key}') except Exception: pass - srv_a.stop(); srv_b.stop() + srv_a.stop() + srv_b.stop() + def test_authorization_field_swap_sets_auth_header(client): srv = start_rest_echo_server() @@ -67,30 +85,41 @@ def test_authorization_field_swap_sets_auth_header(client): api_version = 'v1' swap_header = 'x-up-auth' token_value = 'Bearer SHHH_TOKEN' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'auth swap', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': [srv.url], - 'api_type': 'REST', - 'active': True, - 'api_allowed_headers': [swap_header], - 'api_authorization_field_swap': swap_header - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'auth swap', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': [srv.url], + 'api_type': 'REST', + 'active': True, + 'api_allowed_headers': [swap_header], + 'api_authorization_field_swap': swap_header, + }, + ) assert r.status_code in (200, 201) - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/secure', - 'endpoint_description': 'secure' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/secure', + 'endpoint_description': 'secure', + }, + ) assert r.status_code in (200, 201) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) - r = client.get(f'/api/rest/{api_name}/{api_version}/secure', headers={swap_header: token_value}) + r = client.get( + f'/api/rest/{api_name}/{api_version}/secure', headers={swap_header: token_value} + ) assert r.status_code == 200 data = r.json().get('response', r.json()) hdrs = {k.lower(): v for k, v in (data.get('headers') or {}).items()} @@ -105,5 +134,8 @@ def test_authorization_field_swap_sets_auth_header(client): except Exception: pass srv.stop() + + import pytest + pytestmark = [pytest.mark.routing, pytest.mark.gateway] diff --git a/backend-services/live-tests/test_95_api_export_import_roundtrip.py b/backend-services/live-tests/test_95_api_export_import_roundtrip.py index 141c226..362f9a5 100644 --- a/backend-services/live-tests/test_95_api_export_import_roundtrip.py +++ b/backend-services/live-tests/test_95_api_export_import_roundtrip.py @@ -1,30 +1,38 @@ import time + def test_single_api_export_import_roundtrip(client): api_name = f'cfg-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'cfg demo', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:9'], - 'api_type': 'REST', - 'active': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/x', - 'endpoint_description': 'x' - }) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'cfg demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/x', + 'endpoint_description': 'x', + }, + ) r = client.get(f'/platform/config/export/apis?api_name={api_name}&api_version={api_version}') assert r.status_code == 200 payload = r.json().get('response', r.json()) - exported_api = payload.get('api'); exported_eps = payload.get('endpoints') + exported_api = payload.get('api') + exported_eps = payload.get('endpoints') assert exported_api and exported_api.get('api_name') == api_name assert any(ep.get('endpoint_uri') == '/x' for ep in (exported_eps or [])) @@ -39,5 +47,8 @@ def test_single_api_export_import_roundtrip(client): assert r.status_code == 200 r = client.get(f'/platform/endpoint/GET/{api_name}/{api_version}/x') assert r.status_code == 200 + + import pytest + pytestmark = [pytest.mark.config] diff --git a/backend-services/live-tests/test_95_chaos_backends.py b/backend-services/live-tests/test_95_chaos_backends.py index 4d7e9a2..0ad7e88 100644 --- a/backend-services/live-tests/test_95_chaos_backends.py +++ b/backend-services/live-tests/test_95_chaos_backends.py @@ -1,12 +1,17 @@ import time + import pytest + @pytest.mark.order(-10) def test_redis_outage_during_requests(client): r = client.get('/platform/authorization/status') assert r.status_code in (200, 204) - r = client.post('/platform/tools/chaos/toggle', json={'backend': 'redis', 'enabled': True, 'duration_ms': 1500}) + r = client.post( + '/platform/tools/chaos/toggle', + json={'backend': 'redis', 'enabled': True, 'duration_ms': 1500}, + ) assert r.status_code == 200 t0 = time.time() @@ -25,13 +30,17 @@ def test_redis_outage_during_requests(client): data = js.get('response', js) assert isinstance(data.get('error_budget_burn'), int) + @pytest.mark.order(-9) def test_mongo_outage_during_requests(client): - t0 = time.time() + time.time() r0 = client.get('/platform/user/me') assert r0.status_code in (200, 204) - r = client.post('/platform/tools/chaos/toggle', json={'backend': 'mongo', 'enabled': True, 'duration_ms': 1500}) + r = client.post( + '/platform/tools/chaos/toggle', + json={'backend': 'mongo', 'enabled': True, 'duration_ms': 1500}, + ) assert r.status_code == 200 t1 = time.time() @@ -49,4 +58,3 @@ def test_mongo_outage_during_requests(client): js = s.json() data = js.get('response', js) assert isinstance(data.get('error_budget_burn'), int) - diff --git a/backend-services/live-tests/test_96_config_import_export.py b/backend-services/live-tests/test_96_config_import_export.py index 7cf43cd..57da1f0 100644 --- a/backend-services/live-tests/test_96_config_import_export.py +++ b/backend-services/live-tests/test_96_config_import_export.py @@ -7,5 +7,8 @@ def test_config_export_import_roundtrip(client): assert r.status_code == 200 data = r.json().get('response', r.json()) assert 'imported' in data + + import pytest + pytestmark = [pytest.mark.config] diff --git a/backend-services/live-tests/test_97_monitor_and_readiness.py b/backend-services/live-tests/test_97_monitor_and_readiness.py index 6d5b980..d4a3e95 100644 --- a/backend-services/live-tests/test_97_monitor_and_readiness.py +++ b/backend-services/live-tests/test_97_monitor_and_readiness.py @@ -12,5 +12,8 @@ def test_monitor_endpoints(client): assert r.status_code == 200 metrics = r.json().get('response', r.json()) assert isinstance(metrics, dict) + + import pytest + pytestmark = [pytest.mark.monitor] diff --git a/backend-services/live-tests/test_98_cors_preflight.py b/backend-services/live-tests/test_98_cors_preflight.py index 8ae4ebf..f7a2a8e 100644 --- a/backend-services/live-tests/test_98_cors_preflight.py +++ b/backend-services/live-tests/test_98_cors_preflight.py @@ -1,38 +1,51 @@ def test_api_cors_preflight_and_response_headers(client): import time + api_name = f'cors-{int(time.time())}' api_version = 'v1' - r = client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': api_version, - 'api_description': 'cors test', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:9'], - 'api_type': 'REST', - 'active': True, - 'api_cors_allow_origins': ['http://example.com'], - 'api_cors_allow_methods': ['GET','POST'], - 'api_cors_allow_headers': ['Content-Type','X-CSRF-Token'], - 'api_cors_allow_credentials': True - }) + r = client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'cors test', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + 'api_cors_allow_origins': ['http://example.com'], + 'api_cors_allow_methods': ['GET', 'POST'], + 'api_cors_allow_headers': ['Content-Type', 'X-CSRF-Token'], + 'api_cors_allow_credentials': True, + }, + ) assert r.status_code in (200, 201), r.text - r = client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/ok', - 'endpoint_description': 'ok' - }) + r = client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/ok', + 'endpoint_description': 'ok', + }, + ) assert r.status_code in (200, 201) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) path = f'/api/rest/{api_name}/{api_version}/ok' - r = client.options(path, headers={ - 'Origin': 'http://example.com', - 'Access-Control-Request-Method': 'GET', - 'Access-Control-Request-Headers': 'Content-Type' - }) + r = client.options( + path, + headers={ + 'Origin': 'http://example.com', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'Content-Type', + }, + ) assert r.status_code in (200, 204) acao = r.headers.get('Access-Control-Allow-Origin') assert acao in (None, 'http://example.com') or True @@ -43,5 +56,8 @@ def test_api_cors_preflight_and_response_headers(client): client.delete(f'/platform/endpoint/GET/{api_name}/{api_version}/ok') client.delete(f'/platform/api/{api_name}/{api_version}') + + import pytest + pytestmark = [pytest.mark.cors] diff --git a/backend-services/live-tests/test_99_cors_matrices.py b/backend-services/live-tests/test_99_cors_matrices.py index 1b68afe..3aa69f8 100644 --- a/backend-services/live-tests/test_99_cors_matrices.py +++ b/backend-services/live-tests/test_99_cors_matrices.py @@ -1,58 +1,116 @@ import time + import pytest pytestmark = [pytest.mark.cors] + def test_cors_wildcard_with_credentials_true_sets_origin(client): api_name = f'corsw-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'cors wild', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True, - 'api_cors_allow_origins': ['*'], 'api_cors_allow_methods': ['GET','OPTIONS'], 'api_cors_allow_headers': ['Content-Type'], - 'api_cors_allow_credentials': True - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/c', 'endpoint_description': 'c' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'cors wild', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + 'api_cors_allow_origins': ['*'], + 'api_cors_allow_methods': ['GET', 'OPTIONS'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_cors_allow_credentials': True, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/c', + 'endpoint_description': 'c', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) path = f'/api/rest/{api_name}/{api_version}/c' - r = client.options(path, headers={ - 'Origin': 'http://foo.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'Content-Type' - }) + r = client.options( + path, + headers={ + 'Origin': 'http://foo.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'Content-Type', + }, + ) assert r.status_code in (200, 204) assert r.headers.get('Access-Control-Allow-Origin') in (None, 'http://foo.example') or True client.delete(f'/platform/endpoint/GET/{api_name}/{api_version}/c') client.delete(f'/platform/api/{api_name}/{api_version}') + def test_cors_specific_origin_and_headers(client): api_name = f'corss-{int(time.time())}' api_version = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, 'api_version': api_version, 'api_description': 'cors spec', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True, - 'api_cors_allow_origins': ['http://ok.example'], 'api_cors_allow_methods': ['GET','POST','OPTIONS'], - 'api_cors_allow_headers': ['Content-Type','X-CSRF-Token'], 'api_cors_allow_credentials': False - }) - client.post('/platform/endpoint', json={ - 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/d', 'endpoint_description': 'd' - }) - client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'api_description': 'cors spec', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://127.0.0.1:9'], + 'api_type': 'REST', + 'active': True, + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET', 'POST', 'OPTIONS'], + 'api_cors_allow_headers': ['Content-Type', 'X-CSRF-Token'], + 'api_cors_allow_credentials': False, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/d', + 'endpoint_description': 'd', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'}, + ) path = f'/api/rest/{api_name}/{api_version}/d' - r = client.options(path, headers={ - 'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token' - }) + r = client.options( + path, + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-CSRF-Token', + }, + ) assert r.status_code in (200, 204) assert r.headers.get('Access-Control-Allow-Origin') in (None, 'http://ok.example') or True - r = client.options(path, headers={ - 'Origin': 'http://bad.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token' - }) + r = client.options( + path, + headers={ + 'Origin': 'http://bad.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-CSRF-Token', + }, + ) assert r.status_code in (200, 204) client.delete(f'/platform/endpoint/GET/{api_name}/{api_version}/d') diff --git a/backend-services/live-tests/test_api_cors_headers_matrix_live.py b/backend-services/live-tests/test_api_cors_headers_matrix_live.py index afe2c37..5032069 100644 --- a/backend-services/live-tests/test_api_cors_headers_matrix_live.py +++ b/backend-services/live-tests/test_api_cors_headers_matrix_live.py @@ -1,32 +1,59 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + def test_api_cors_allow_origins_allow_methods_headers_credentials_expose_live(client): import time + api_name = f'corslive-{int(time.time())}' ver = 'v1' - client.post('/platform/api', json={ - 'api_name': api_name, - 'api_version': ver, - 'api_description': 'cors live', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://upstream.example'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET','POST'], - 'api_cors_allow_headers': ['Content-Type','X-CSRF-Token'], - 'api_cors_allow_credentials': True, - 'api_cors_expose_headers': ['X-Resp-Id'], - }) - client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': ver, 'endpoint_method': 'GET', 'endpoint_uri': '/q', 'endpoint_description': 'q'}) - client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': api_name, 'api_version': ver}) - r = client.options(f'/api/rest/{api_name}/{ver}/q', headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token'}) + client.post( + '/platform/api', + json={ + 'api_name': api_name, + 'api_version': ver, + 'api_description': 'cors live', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream.example'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET', 'POST'], + 'api_cors_allow_headers': ['Content-Type', 'X-CSRF-Token'], + 'api_cors_allow_credentials': True, + 'api_cors_expose_headers': ['X-Resp-Id'], + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/q', + 'endpoint_description': 'q', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'username': 'admin', 'api_name': api_name, 'api_version': ver}, + ) + r = client.options( + f'/api/rest/{api_name}/{ver}/q', + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-CSRF-Token', + }, + ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example' assert 'GET' in (r.headers.get('Access-Control-Allow-Methods') or '') diff --git a/backend-services/live-tests/test_bandwidth_limit_live.py b/backend-services/live-tests/test_bandwidth_limit_live.py index 4c31883..2d91c44 100644 --- a/backend-services/live-tests/test_bandwidth_limit_live.py +++ b/backend-services/live-tests/test_bandwidth_limit_live.py @@ -1,30 +1,57 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + def test_bandwidth_limit_enforced_and_window_resets_live(client): name, ver = 'bwlive', 'v1' - client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'bw live', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up.example'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - client.post('/platform/endpoint', json={'api_name': name, 'api_version': ver, 'endpoint_method': 'GET', 'endpoint_uri': '/p', 'endpoint_description': 'p'}) - client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver}) - client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 1, 'bandwidth_limit_window': 'second', 'bandwidth_limit_enabled': True}) + client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'bw live', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.example'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/p', + 'endpoint_description': 'p', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'username': 'admin', 'api_name': name, 'api_version': ver}, + ) + client.put( + '/platform/user/admin', + json={ + 'bandwidth_limit_bytes': 1, + 'bandwidth_limit_window': 'second', + 'bandwidth_limit_enabled': True, + }, + ) client.delete('/api/caches') r1 = client.get(f'/api/rest/{name}/{ver}/p') r2 = client.get(f'/api/rest/{name}/{ver}/p') assert r1.status_code == 200 and r2.status_code == 429 import time + time.sleep(1.1) r3 = client.get(f'/api/rest/{name}/{ver}/p') assert r3.status_code == 200 diff --git a/backend-services/live-tests/test_graphql_fallback_live.py b/backend-services/live-tests/test_graphql_fallback_live.py index 8beac52..24df5bb 100644 --- a/backend-services/live-tests/test_graphql_fallback_live.py +++ b/backend-services/live-tests/test_graphql_fallback_live.py @@ -1,81 +1,123 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + async def _setup(client, name='gllive', ver='v1'): - await client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://gql.up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/graphql', - 'endpoint_description': 'gql' - }) + await client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://gql.up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'gql', + }, + ) return name, ver + @pytest.mark.asyncio async def test_graphql_client_fallback_to_httpx_live(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = await _setup(authed_client, name='gll1') + class Dummy: pass + class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p + class H: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'ok': True}) + monkeypatch.setattr(gs, 'Client', Dummy) monkeypatch.setattr(gs.httpx, 'AsyncClient', H) - r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ ping }', 'variables': {}}) + r = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ ping }', 'variables': {}}, + ) assert r.status_code == 200 and r.json().get('ok') is True + @pytest.mark.asyncio async def test_graphql_errors_live_strict_and_loose(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = await _setup(authed_client, name='gll2') + class Dummy: pass + class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p + class H: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'errors': [{'message': 'boom'}]}) + monkeypatch.setattr(gs, 'Client', Dummy) monkeypatch.setattr(gs.httpx, 'AsyncClient', H) monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False) - r1 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ err }', 'variables': {}}) + r1 = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ err }', 'variables': {}}, + ) assert r1.status_code == 200 and isinstance(r1.json().get('errors'), list) monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') - r2 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ err }', 'variables': {}}) + r2 = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ err }', 'variables': {}}, + ) assert r2.status_code == 200 and r2.json().get('status_code') == 200 diff --git a/backend-services/live-tests/test_grpc_pkg_override_live.py b/backend-services/live-tests/test_grpc_pkg_override_live.py index c7813ef..278dbc7 100644 --- a/backend-services/live-tests/test_grpc_pkg_override_live.py +++ b/backend-services/live-tests/test_grpc_pkg_override_live.py @@ -1,24 +1,33 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + def _fake_pb2_module(method_name='M'): class Req: pass + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() + def __init__(self, ok=True): self.ok = ok + @staticmethod def FromString(b): return Reply(True) - setattr(Req, '__name__', f'{method_name}Request') - setattr(Reply, '__name__', f'{method_name}Reply') + + Req.__name__ = f'{method_name}Request' + Reply.__name__ = f'{method_name}Reply' return Req, Reply + def _make_import_module_recorder(record, pb2_map): def _imp(name): record.append(name) @@ -27,31 +36,47 @@ def _make_import_module_recorder(record, pb2_map): mapping = pb2_map.get(name) if mapping is None: req_cls, rep_cls = _fake_pb2_module('M') - setattr(mod, 'MRequest', req_cls) - setattr(mod, 'MReply', rep_cls) + mod.MRequest = req_cls + mod.MReply = rep_cls else: req_cls, rep_cls = mapping if req_cls: - setattr(mod, 'MRequest', req_cls) + mod.MRequest = req_cls if rep_cls: - setattr(mod, 'MReply', rep_cls) + mod.MReply = rep_cls return mod if name.endswith('_pb2_grpc'): + class Stub: def __init__(self, ch): self._ch = ch + async def M(self, req): - return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})() + return type( + 'R', + (), + { + 'DESCRIPTOR': type( + 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]} + )(), + 'ok': True, + }, + )() + mod = type('SVC', (), {'SvcStub': Stub}) return mod raise ImportError(name) + return _imp + def _make_fake_grpc_unary(sequence_codes, grpc_mod): counter = {'i': 0} + class AioChan: async def channel_ready(self): return True + class Chan(AioChan): def unary_unary(self, method, request_serializer=None, response_deserializer=None): async def _call(req): @@ -59,143 +84,214 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod): code = sequence_codes[idx] counter['i'] += 1 if code is None: - return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})() + return type( + 'R', + (), + { + 'DESCRIPTOR': type( + 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]} + )(), + 'ok': True, + }, + )() + class E(Exception): def code(self): return code + def details(self): return 'err' + raise E() + return _call + class aio: @staticmethod def insecure_channel(url): return Chan() + fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception}) return fake + @pytest.mark.asyncio async def test_grpc_with_api_grpc_package_config(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gplive1', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'g', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['grpc://127.0.0.1:9'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_grpc_package': 'api.pkg' - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'g', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['grpc://127.0.0.1:9'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_grpc_package': 'api.pkg', + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) record = [] req_cls, rep_cls = _fake_pb2_module('M') pb2_map = {'api.pkg_pb2': (req_cls, rep_cls)} - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs.os.path, 'exists', lambda p: True) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}, + ) assert r.status_code == 200 assert any(n == 'api.pkg_pb2' for n in record) + @pytest.mark.asyncio async def test_grpc_with_request_package_override(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gplive2', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'g', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['grpc://127.0.0.1:9'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'g', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['grpc://127.0.0.1:9'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) record = [] req_cls, rep_cls = _fake_pb2_module('M') pb2_map = {'req.pkg_pb2': (req_cls, rep_cls)} - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs.os.path, 'exists', lambda p: True) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}, + ) assert r.status_code == 200 assert any(n == 'req.pkg_pb2' for n in record) + @pytest.mark.asyncio async def test_grpc_without_package_server_uses_fallback_path(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gplive3', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'g', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['grpc://127.0.0.1:9'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'g', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['grpc://127.0.0.1:9'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) record = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' pb2_map = {default_pkg: (req_cls, rep_cls)} - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs.os.path, 'exists', lambda p: True) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, + ) assert r.status_code == 200 assert any(n.endswith(default_pkg) for n in record) + @pytest.mark.asyncio async def test_grpc_unavailable_then_success_with_retry_live(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gplive4', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'g', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['grpc://127.0.0.1:9'], - 'api_type': 'REST', - 'api_allowed_retry_count': 1, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'g', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['grpc://127.0.0.1:9'], + 'api_type': 'REST', + 'api_allowed_retry_count': 1, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) record = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' pb2_map = {default_pkg: (req_cls, rep_cls)} - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, None], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake_grpc) - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, + ) assert r.status_code == 200 diff --git a/backend-services/live-tests/test_memory_dump_sigusr1_live.py b/backend-services/live-tests/test_memory_dump_sigusr1_live.py index f992f3c..716a8bb 100644 --- a/backend-services/live-tests/test_memory_dump_sigusr1_live.py +++ b/backend-services/live-tests/test_memory_dump_sigusr1_live.py @@ -1,16 +1,22 @@ -import pytest import os import platform +import pytest + _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + @pytest.mark.skipif(platform.system() == 'Windows', reason='SIGUSR1 not available on Windows') def test_sigusr1_dump_in_memory_mode_live(client, monkeypatch, tmp_path): monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'live-secret-xyz') monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'live' / 'memory_dump.bin')) - import signal, time + import signal + import time + os.kill(os.getpid(), signal.SIGUSR1) time.sleep(0.5) assert True diff --git a/backend-services/live-tests/test_platform_cors_env_edges_live.py b/backend-services/live-tests/test_platform_cors_env_edges_live.py index 454b6e0..78497df 100644 --- a/backend-services/live-tests/test_platform_cors_env_edges_live.py +++ b/backend-services/live-tests/test_platform_cors_env_edges_live.py @@ -1,22 +1,41 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + def test_platform_cors_strict_wildcard_credentials_edges_live(client, monkeypatch): monkeypatch.setenv('ALLOWED_ORIGINS', '*') monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') monkeypatch.setenv('CORS_STRICT', 'true') - r = client.options('/platform/api', headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'}) + r = client.options( + '/platform/api', + headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'}, + ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Origin') in (None, '') + def test_platform_cors_methods_headers_defaults_live(client, monkeypatch): monkeypatch.setenv('ALLOW_METHODS', '') monkeypatch.setenv('ALLOW_HEADERS', '*') - r = client.options('/platform/api', headers={'Origin': 'http://localhost:3000', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-Rand'}) + r = client.options( + '/platform/api', + headers={ + 'Origin': 'http://localhost:3000', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-Rand', + }, + ) assert r.status_code == 204 - methods = [m.strip() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()] - assert set(methods) == {'GET','POST','PUT','DELETE','OPTIONS','PATCH','HEAD'} + methods = [ + m.strip() + for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') + if m.strip() + ] + assert set(methods) == {'GET', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH', 'HEAD'} diff --git a/backend-services/live-tests/test_rest_header_forwarding_live.py b/backend-services/live-tests/test_rest_header_forwarding_live.py index 81e55b9..bf52a07 100644 --- a/backend-services/live-tests/test_rest_header_forwarding_live.py +++ b/backend-services/live-tests/test_rest_header_forwarding_live.py @@ -1,14 +1,20 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + @pytest.mark.asyncio async def test_forward_allowed_headers_only(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + from conftest import create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'hforw', 'v1' payload = { 'api_name': name, @@ -19,7 +25,7 @@ async def test_forward_allowed_headers_only(monkeypatch, authed_client): 'api_servers': ['http://up'], 'api_type': 'REST', 'api_allowed_retry_count': 0, - 'api_allowed_headers': ['x-allowed', 'content-type'] + 'api_allowed_headers': ['x-allowed', 'content-type'], } await authed_client.post('/platform/api', json=payload) await create_endpoint(authed_client, name, ver, 'GET', '/p') @@ -31,26 +37,37 @@ async def test_forward_allowed_headers_only(monkeypatch, authed_client): self._p = {'ok': True} self.headers = {'Content-Type': 'application/json'} self.text = '' + def json(self): return self._p + captured = {} + class CapClient: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def get(self, url, params=None, headers=None): captured['headers'] = headers or {} return Resp() + monkeypatch.setattr(gs.httpx, 'AsyncClient', CapClient) - await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={'X-Allowed': 'yes', 'X-Blocked': 'no'}) + await authed_client.get( + f'/api/rest/{name}/{ver}/p', headers={'X-Allowed': 'yes', 'X-Blocked': 'no'} + ) ch = {k.lower(): v for k, v in (captured.get('headers') or {}).items()} assert 'x-allowed' in ch and 'x-blocked' not in ch + @pytest.mark.asyncio async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + from conftest import create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'hresp', 'v1' payload = { 'api_name': name, @@ -61,7 +78,7 @@ async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client 'api_servers': ['http://up'], 'api_type': 'REST', 'api_allowed_retry_count': 0, - 'api_allowed_headers': ['x-upstream'] + 'api_allowed_headers': ['x-upstream'], } await authed_client.post('/platform/api', json=payload) await create_endpoint(authed_client, name, ver, 'GET', '/p') @@ -71,17 +88,26 @@ async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client def __init__(self): self.status_code = 200 self._p = {'ok': True} - self.headers = {'Content-Type': 'application/json', 'X-Upstream': 'yes', 'X-Secret': 'no'} + self.headers = { + 'Content-Type': 'application/json', + 'X-Upstream': 'yes', + 'X-Secret': 'no', + } self.text = '' + def json(self): return self._p + class HC: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def get(self, url, params=None, headers=None): return Resp() + monkeypatch.setattr(gs.httpx, 'AsyncClient', HC) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 diff --git a/backend-services/live-tests/test_rest_retries_live.py b/backend-services/live-tests/test_rest_retries_live.py index d390a0b..18f135a 100644 --- a/backend-services/live-tests/test_rest_retries_live.py +++ b/backend-services/live-tests/test_rest_retries_live.py @@ -1,23 +1,30 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) -from tests.test_gateway_routing_limits import _FakeAsyncClient @pytest.mark.asyncio async def test_rest_retries_on_500_then_success(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'rlive500', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/r') await subscribe_self(authed_client, name, ver) from utils.database import api_collection - api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}}) + + api_collection.update_one( + {'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}} + ) await authed_client.delete('/api/caches') class Resp: @@ -26,30 +33,42 @@ async def test_rest_retries_on_500_then_success(monkeypatch, authed_client): self._json = body or {} self.text = '' self.headers = headers or {'Content-Type': 'application/json'} + def json(self): return self._json + seq = [Resp(500), Resp(200, {'ok': True})] + class SeqClient: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def get(self, url, params=None, headers=None): return seq.pop(0) + monkeypatch.setattr(gs.httpx, 'AsyncClient', SeqClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/r') assert r.status_code == 200 and r.json().get('ok') is True + @pytest.mark.asyncio async def test_rest_retries_on_503_then_success(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'rlive503', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/r') await subscribe_self(authed_client, name, ver) from utils.database import api_collection - api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}}) + + api_collection.update_one( + {'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}} + ) await authed_client.delete('/api/caches') class Resp: @@ -57,43 +76,58 @@ async def test_rest_retries_on_503_then_success(monkeypatch, authed_client): self.status_code = status self.headers = {'Content-Type': 'application/json'} self.text = '' + def json(self): return {} + seq = [Resp(503), Resp(200)] + class SeqClient: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def get(self, url, params=None, headers=None): return seq.pop(0) + monkeypatch.setattr(gs.httpx, 'AsyncClient', SeqClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/r') assert r.status_code == 200 + @pytest.mark.asyncio async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'rlivez0', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/r') await subscribe_self(authed_client, name, ver) await authed_client.delete('/api/caches') + class Resp: def __init__(self, status): self.status_code = status self.headers = {'Content-Type': 'application/json'} self.text = '' + def json(self): return {} + class OneClient: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def get(self, url, params=None, headers=None): return Resp(500) + monkeypatch.setattr(gs.httpx, 'AsyncClient', OneClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/r') assert r.status_code == 500 diff --git a/backend-services/live-tests/test_soap_content_type_and_retries_live.py b/backend-services/live-tests/test_soap_content_type_and_retries_live.py index ba97b98..2c387b3 100644 --- a/backend-services/live-tests/test_soap_content_type_and_retries_live.py +++ b/backend-services/live-tests/test_soap_content_type_and_retries_live.py @@ -1,14 +1,20 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + @pytest.mark.asyncio async def test_soap_content_types_matrix(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'soapct', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/s') @@ -19,30 +25,43 @@ async def test_soap_content_types_matrix(monkeypatch, authed_client): self.status_code = 200 self.headers = {'Content-Type': 'application/xml'} self.text = '' + def json(self): return {'ok': True} + class HC: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, params=None, headers=None, content=None): return Resp() + monkeypatch.setattr(gs.httpx, 'AsyncClient', HC) for ct in ['application/xml', 'text/xml']: - r = await authed_client.post(f'/api/soap/{name}/{ver}/s', headers={'Content-Type': ct}, content='') + r = await authed_client.post( + f'/api/soap/{name}/{ver}/s', headers={'Content-Type': ct}, content='' + ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_soap_retries_then_success(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + import services.gateway_service as gs + name, ver = 'soaprt', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/s') await subscribe_self(authed_client, name, ver) from utils.database import api_collection - api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}}) + + api_collection.update_one( + {'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}} + ) await authed_client.delete('/api/caches') class Resp: @@ -50,16 +69,24 @@ async def test_soap_retries_then_success(monkeypatch, authed_client): self.status_code = status self.headers = {'Content-Type': 'application/xml'} self.text = '' + def json(self): return {'ok': True} + seq = [Resp(503), Resp(200)] + class HC: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, params=None, headers=None, content=None): return seq.pop(0) + monkeypatch.setattr(gs.httpx, 'AsyncClient', HC) - r = await authed_client.post(f'/api/soap/{name}/{ver}/s', headers={'Content-Type': 'application/xml'}, content='') + r = await authed_client.post( + f'/api/soap/{name}/{ver}/s', headers={'Content-Type': 'application/xml'}, content='' + ) assert r.status_code == 200 diff --git a/backend-services/live-tests/test_throttle_queue_and_wait_live.py b/backend-services/live-tests/test_throttle_queue_and_wait_live.py index 3d4454c..64bd7e3 100644 --- a/backend-services/live-tests/test_throttle_queue_and_wait_live.py +++ b/backend-services/live-tests/test_throttle_queue_and_wait_live.py @@ -1,68 +1,94 @@ import os + import pytest _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True') if not _RUN_LIVE: - pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable') + pytestmark = pytest.mark.skip( + reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable' + ) + def test_throttle_queue_limit_exceeded_429_live(client): - from config import ADMIN_EMAIL name, ver = 'throtq', 'v1' - client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'live throttle', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up.example'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/t', - 'endpoint_description': 't' - }) - client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver}) + client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'live throttle', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.example'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/t', + 'endpoint_description': 't', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'username': 'admin', 'api_name': name, 'api_version': ver}, + ) client.put('/platform/user/admin', json={'throttle_queue_limit': 1}) client.delete('/api/caches') - r1 = client.get(f'/api/rest/{name}/{ver}/t') + client.get(f'/api/rest/{name}/{ver}/t') r2 = client.get(f'/api/rest/{name}/{ver}/t') assert r2.status_code == 429 + def test_throttle_dynamic_wait_live(client): name, ver = 'throtw', 'v1' - client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'live throttle wait', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up.example'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/w', - 'endpoint_description': 'w' - }) - client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver}) - client.put('/platform/user/admin', json={ - 'throttle_duration': 1, - 'throttle_duration_type': 'second', - 'throttle_queue_limit': 10, - 'throttle_wait_duration': 0.1, - 'throttle_wait_duration_type': 'second', - 'rate_limit_duration': 1000, - 'rate_limit_duration_type': 'second', - }) + client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'live throttle wait', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.example'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/w', + 'endpoint_description': 'w', + }, + ) + client.post( + '/platform/subscription/subscribe', + json={'username': 'admin', 'api_name': name, 'api_version': ver}, + ) + client.put( + '/platform/user/admin', + json={ + 'throttle_duration': 1, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 10, + 'throttle_wait_duration': 0.1, + 'throttle_wait_duration_type': 'second', + 'rate_limit_duration': 1000, + 'rate_limit_duration_type': 'second', + }, + ) client.delete('/api/caches') import time + t0 = time.perf_counter() r1 = client.get(f'/api/rest/{name}/{ver}/w') t1 = time.perf_counter() diff --git a/backend-services/middleware/analytics_middleware.py b/backend-services/middleware/analytics_middleware.py index 734e884..e95798d 100644 --- a/backend-services/middleware/analytics_middleware.py +++ b/backend-services/middleware/analytics_middleware.py @@ -5,12 +5,13 @@ Automatically records detailed metrics for every request passing through the gateway, including per-endpoint tracking and full performance data. """ -import time import logging +import time +from collections.abc import Callable + from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp -from typing import Callable from utils.enhanced_metrics_util import enhanced_metrics_store @@ -20,7 +21,7 @@ logger = logging.getLogger('doorman.analytics') class AnalyticsMiddleware(BaseHTTPMiddleware): """ Middleware to capture comprehensive request/response metrics. - + Records: - Response time - Status code @@ -29,21 +30,21 @@ class AnalyticsMiddleware(BaseHTTPMiddleware): - Endpoint URI and method - Request/response sizes """ - + def __init__(self, app: ASGIApp): super().__init__(app) - + async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process request and record metrics. """ # Start timing start_time = time.time() - + # Extract request metadata method = request.method path = str(request.url.path) - + # Estimate request size (headers + body) request_size = 0 try: @@ -54,16 +55,16 @@ class AnalyticsMiddleware(BaseHTTPMiddleware): request_size += int(request.headers['content-length']) except Exception: pass - + # Process request response = await call_next(request) - + # Calculate duration duration_ms = (time.time() - start_time) * 1000 - + # Extract response metadata status_code = response.status_code - + # Estimate response size response_size = 0 try: @@ -74,18 +75,20 @@ class AnalyticsMiddleware(BaseHTTPMiddleware): response_size += int(response.headers['content-length']) except Exception: pass - + # Extract user from request state (set by auth middleware) username = None try: if hasattr(request.state, 'user'): - username = request.state.user.get('sub') if isinstance(request.state.user, dict) else None + username = ( + request.state.user.get('sub') if isinstance(request.state.user, dict) else None + ) except Exception: pass - + # Parse API and endpoint from path api_key, endpoint_uri = self._parse_api_endpoint(path) - + # Record metrics try: enhanced_metrics_store.record( @@ -96,17 +99,17 @@ class AnalyticsMiddleware(BaseHTTPMiddleware): endpoint_uri=endpoint_uri, method=method, bytes_in=request_size, - bytes_out=response_size + bytes_out=response_size, ) except Exception as e: - logger.error(f"Failed to record analytics: {str(e)}") - + logger.error(f'Failed to record analytics: {str(e)}') + return response - + def _parse_api_endpoint(self, path: str) -> tuple[str | None, str | None]: """ Parse API key and endpoint URI from request path. - + Examples: - /api/rest/customer/v1/users -> ("rest:customer", "/customer/v1/users") - /platform/analytics/overview -> (None, "/platform/analytics/overview") @@ -118,34 +121,34 @@ class AnalyticsMiddleware(BaseHTTPMiddleware): parts = path.split('/') if len(parts) >= 5: api_name = parts[3] - api_version = parts[4] + parts[4] endpoint_uri = '/' + '/'.join(parts[3:]) - return f"rest:{api_name}", endpoint_uri - + return f'rest:{api_name}', endpoint_uri + elif path.startswith('/api/graphql/'): # GraphQL API parts = path.split('/') if len(parts) >= 4: api_name = parts[3] - return f"graphql:{api_name}", path - + return f'graphql:{api_name}', path + elif path.startswith('/api/soap/'): # SOAP API parts = path.split('/') if len(parts) >= 4: api_name = parts[3] - return f"soap:{api_name}", path - + return f'soap:{api_name}', path + elif path.startswith('/api/grpc/'): # gRPC API parts = path.split('/') if len(parts) >= 4: api_name = parts[3] - return f"grpc:{api_name}", path - + return f'grpc:{api_name}', path + # Platform endpoints (not API requests) return None, path - + except Exception: return None, path @@ -153,8 +156,8 @@ class AnalyticsMiddleware(BaseHTTPMiddleware): def setup_analytics_middleware(app): """ Add analytics middleware to FastAPI app. - + Should be called during app initialization. """ app.add_middleware(AnalyticsMiddleware) - logger.info("Analytics middleware initialized") + logger.info('Analytics middleware initialized') diff --git a/backend-services/middleware/rate_limit_middleware.py b/backend-services/middleware/rate_limit_middleware.py index 2d9668d..6202a6a 100644 --- a/backend-services/middleware/rate_limit_middleware.py +++ b/backend-services/middleware/rate_limit_middleware.py @@ -6,16 +6,16 @@ Checks rate limits and quotas, adds headers, and returns 429 when exceeded. """ import logging -import time -from typing import Optional, List, Callable +from collections.abc import Callable + from fastapi import Request, Response, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp -from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow, TierLimits -from utils.rate_limiter import get_rate_limiter, RateLimiter -from utils.quota_tracker import get_quota_tracker, QuotaTracker, QuotaType +from models.rate_limit_models import RateLimitRule, RuleType, TierLimits, TimeWindow +from utils.quota_tracker import QuotaTracker, QuotaType, get_quota_tracker +from utils.rate_limiter import RateLimiter, get_rate_limiter logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class RateLimitMiddleware(BaseHTTPMiddleware): """ Middleware for rate limiting requests - + Features: - Applies rate limit rules based on user, API, endpoint, IP - Checks quotas (monthly, daily) @@ -31,18 +31,18 @@ class RateLimitMiddleware(BaseHTTPMiddleware): - Returns 429 Too Many Requests when limits exceeded - Supports tier-based limits """ - + def __init__( self, app: ASGIApp, - rate_limiter: Optional[RateLimiter] = None, - quota_tracker: Optional[QuotaTracker] = None, - get_rules_func: Optional[Callable] = None, - get_user_tier_func: Optional[Callable] = None + rate_limiter: RateLimiter | None = None, + quota_tracker: QuotaTracker | None = None, + get_rules_func: Callable | None = None, + get_user_tier_func: Callable | None = None, ): """ Initialize rate limit middleware - + Args: app: FastAPI application rate_limiter: Rate limiter instance @@ -55,134 +55,134 @@ class RateLimitMiddleware(BaseHTTPMiddleware): self.quota_tracker = quota_tracker or get_quota_tracker() self.get_rules_func = get_rules_func or self._default_get_rules self.get_user_tier_func = get_user_tier_func or self._default_get_user_tier - + async def dispatch(self, request: Request, call_next): """ Process request through rate limiting - + Args: request: Incoming request call_next: Next middleware/handler - + Returns: Response (possibly 429 if rate limited) """ # Skip rate limiting for certain paths if self._should_skip(request): return await call_next(request) - + # Extract identifiers user_id = self._get_user_id(request) api_name = self._get_api_name(request) endpoint_uri = str(request.url.path) ip_address = self._get_client_ip(request) - + # Get applicable rules rules = await self.get_rules_func(request, user_id, api_name, endpoint_uri, ip_address) - + # Check rate limits for rule in rules: identifier = self._get_identifier(rule, user_id, api_name, endpoint_uri, ip_address) - + if identifier: result = self.rate_limiter.check_rate_limit(rule, identifier) - + if not result.allowed: # Rate limit exceeded return self._create_rate_limit_response(result, rule) - + # Check quotas if user identified if user_id: tier_limits = await self.get_user_tier_func(user_id) - + if tier_limits: quota_result = await self._check_quotas(user_id, tier_limits) - + if not quota_result.allowed: return self._create_quota_exceeded_response(quota_result) - + # Process request response = await call_next(request) - + # Add rate limit headers if rules: # Use first rule for headers (highest priority) rule = rules[0] identifier = self._get_identifier(rule, user_id, api_name, endpoint_uri, ip_address) - + if identifier: usage = self.rate_limiter.get_current_usage(rule, identifier) self._add_rate_limit_headers(response, usage.limit, usage.remaining, usage.reset_at) - + # Increment quota (async, don't block response) if user_id: try: self.quota_tracker.increment_quota(user_id, QuotaType.REQUESTS, 1, 'month') except Exception as e: - logger.error(f"Error incrementing quota: {e}") - + logger.error(f'Error incrementing quota: {e}') + return response - + def _should_skip(self, request: Request) -> bool: """ Check if rate limiting should be skipped for this request - + Args: request: Incoming request - + Returns: True if should skip """ # Skip health checks, metrics, etc. skip_paths = ['/health', '/metrics', '/docs', '/redoc', '/openapi.json'] - + return any(request.url.path.startswith(path) for path in skip_paths) - - def _get_user_id(self, request: Request) -> Optional[str]: + + def _get_user_id(self, request: Request) -> str | None: """ Extract user ID from request - + Args: request: Incoming request - + Returns: User ID or None """ # Try to get from request state (set by auth middleware) if hasattr(request.state, 'user'): return getattr(request.state.user, 'username', None) - + # Try to get from headers return request.headers.get('X-User-ID') - - def _get_api_name(self, request: Request) -> Optional[str]: + + def _get_api_name(self, request: Request) -> str | None: """ Extract API name from request - + Args: request: Incoming request - + Returns: API name or None """ # Try to get from request state (set by routing) if hasattr(request.state, 'api_name'): return request.state.api_name - + # Try to extract from path path_parts = request.url.path.strip('/').split('/') if len(path_parts) > 0: return path_parts[0] - + return None - + def _get_client_ip(self, request: Request) -> str: """ Extract client IP address from request - + Args: request: Incoming request - + Returns: IP address """ @@ -191,36 +191,36 @@ class RateLimitMiddleware(BaseHTTPMiddleware): if forwarded_for: # Take first IP (original client) return forwarded_for.split(',')[0].strip() - + # Check X-Real-IP header real_ip = request.headers.get('X-Real-IP') if real_ip: return real_ip - + # Fall back to direct connection if request.client: return request.client.host - + return 'unknown' - + def _get_identifier( self, rule: RateLimitRule, - user_id: Optional[str], - api_name: Optional[str], + user_id: str | None, + api_name: str | None, endpoint_uri: str, - ip_address: str - ) -> Optional[str]: + ip_address: str, + ) -> str | None: """ Get identifier for rate limit rule - + Args: rule: Rate limit rule user_id: User ID api_name: API name endpoint_uri: Endpoint URI ip_address: IP address - + Returns: Identifier string or None """ @@ -233,188 +233,183 @@ class RateLimitMiddleware(BaseHTTPMiddleware): elif rule.rule_type == RuleType.PER_IP: return ip_address elif rule.rule_type == RuleType.PER_USER_API: - return f"{user_id}:{api_name}" if user_id and api_name else None + return f'{user_id}:{api_name}' if user_id and api_name else None elif rule.rule_type == RuleType.PER_USER_ENDPOINT: - return f"{user_id}:{endpoint_uri}" if user_id else None + return f'{user_id}:{endpoint_uri}' if user_id else None elif rule.rule_type == RuleType.GLOBAL: return 'global' - + return None - + async def _default_get_rules( self, request: Request, - user_id: Optional[str], - api_name: Optional[str], + user_id: str | None, + api_name: str | None, endpoint_uri: str, - ip_address: str - ) -> List[RateLimitRule]: + ip_address: str, + ) -> list[RateLimitRule]: """ Default function to get applicable rules - + Priority order: 1. If user has tier assigned → Use tier limits ONLY 2. If user has NO tier → Use per-user rate limit rules 3. Fall back to global rules - + In production, this should query MongoDB for rules. This is a placeholder that returns default rules. - + Args: request: Incoming request user_id: User ID api_name: API name endpoint_uri: Endpoint URI ip_address: IP address - + Returns: List of applicable rules """ # TODO: Query MongoDB for rules # For now, return default rules - + rules = [] - + # Check if user has a tier assigned user_tier = None if user_id: user_tier = await self.get_user_tier_func(user_id) - + if user_tier: # User has tier → Use tier limits ONLY (priority) # Convert tier limits to rate limit rules if user_tier.requests_per_minute: - rules.append(RateLimitRule( - rule_id=f'tier_{user_id}', - rule_type=RuleType.PER_USER, - time_window=TimeWindow.MINUTE, - limit=user_tier.requests_per_minute, - burst_allowance=user_tier.burst_allowance or 0, - priority=100, # Highest priority - enabled=True, - description=f"Tier-based limit for {user_id}" - )) + rules.append( + RateLimitRule( + rule_id=f'tier_{user_id}', + rule_type=RuleType.PER_USER, + time_window=TimeWindow.MINUTE, + limit=user_tier.requests_per_minute, + burst_allowance=user_tier.burst_allowance or 0, + priority=100, # Highest priority + enabled=True, + description=f'Tier-based limit for {user_id}', + ) + ) else: # User has NO tier → Use per-user rate limit rules if user_id: # TODO: Query MongoDB for per-user rules # For now, use default per-user rule - rules.append(RateLimitRule( - rule_id='default_per_user', - rule_type=RuleType.PER_USER, - time_window=TimeWindow.MINUTE, - limit=100, - burst_allowance=20, - priority=10, - enabled=True, - description="Default per-user limit" - )) - + rules.append( + RateLimitRule( + rule_id='default_per_user', + rule_type=RuleType.PER_USER, + time_window=TimeWindow.MINUTE, + limit=100, + burst_allowance=20, + priority=10, + enabled=True, + description='Default per-user limit', + ) + ) + # Always add global rule as fallback - rules.append(RateLimitRule( - rule_id='default_global', - rule_type=RuleType.GLOBAL, - time_window=TimeWindow.MINUTE, - limit=1000, - priority=0, - enabled=True, - description="Global rate limit" - )) - + rules.append( + RateLimitRule( + rule_id='default_global', + rule_type=RuleType.GLOBAL, + time_window=TimeWindow.MINUTE, + limit=1000, + priority=0, + enabled=True, + description='Global rate limit', + ) + ) + # Sort by priority (highest first) rules.sort(key=lambda r: r.priority, reverse=True) - + return rules - - async def _default_get_user_tier(self, user_id: str) -> Optional[TierLimits]: + + async def _default_get_user_tier(self, user_id: str) -> TierLimits | None: """ Get user's tier limits from TierService - + Args: user_id: User ID - + Returns: TierLimits or None """ try: - from services.tier_service import TierService, get_tier_service + from services.tier_service import get_tier_service from utils.database_async import async_database - + tier_service = get_tier_service(async_database.db) limits = await tier_service.get_user_limits(user_id) - + return limits except Exception as e: - logger.error(f"Error fetching user tier limits: {e}") + logger.error(f'Error fetching user tier limits: {e}') return None - - async def _check_quotas( - self, - user_id: str, - tier_limits: TierLimits - ) -> 'QuotaCheckResult': + + async def _check_quotas(self, user_id: str, tier_limits: TierLimits) -> 'QuotaCheckResult': """ Check user's quotas - + Args: user_id: User ID tier_limits: User's tier limits - + Returns: QuotaCheckResult """ # Check monthly quota if tier_limits.monthly_request_quota: result = self.quota_tracker.check_quota( - user_id, - QuotaType.REQUESTS, - tier_limits.monthly_request_quota, - 'month' + user_id, QuotaType.REQUESTS, tier_limits.monthly_request_quota, 'month' ) - + if not result.allowed: return result - + # Check daily quota if tier_limits.daily_request_quota: result = self.quota_tracker.check_quota( - user_id, - QuotaType.REQUESTS, - tier_limits.daily_request_quota, - 'day' + user_id, QuotaType.REQUESTS, tier_limits.daily_request_quota, 'day' ) - + if not result.allowed: return result - + # All quotas OK from utils.quota_tracker import QuotaCheckResult + return QuotaCheckResult( allowed=True, current_usage=0, limit=tier_limits.monthly_request_quota or 0, remaining=tier_limits.monthly_request_quota or 0, reset_at=self.quota_tracker._get_next_reset('month'), - percentage_used=0.0 + percentage_used=0.0, ) - + def _create_rate_limit_response( - self, - result: 'RateLimitResult', - rule: RateLimitRule + self, result: 'RateLimitResult', rule: RateLimitRule ) -> JSONResponse: """ Create 429 response for rate limit exceeded - + Args: result: Rate limit result rule: Rule that was exceeded - + Returns: JSONResponse with 429 status """ info = result.to_info() - + return JSONResponse( status_code=status.HTTP_429_TOO_MANY_REQUESTS, content={ @@ -423,21 +418,18 @@ class RateLimitMiddleware(BaseHTTPMiddleware): 'limit': result.limit, 'remaining': result.remaining, 'reset_at': result.reset_at, - 'retry_after': result.retry_after + 'retry_after': result.retry_after, }, - headers=info.to_headers() + headers=info.to_headers(), ) - - def _create_quota_exceeded_response( - self, - result: 'QuotaCheckResult' - ) -> JSONResponse: + + def _create_quota_exceeded_response(self, result: 'QuotaCheckResult') -> JSONResponse: """ Create 429 response for quota exceeded - + Args: result: Quota check result - + Returns: JSONResponse with 429 status """ @@ -450,26 +442,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware): 'limit': result.limit, 'remaining': result.remaining, 'reset_at': result.reset_at.isoformat(), - 'percentage_used': result.percentage_used + 'percentage_used': result.percentage_used, }, headers={ 'X-RateLimit-Limit': str(result.limit), 'X-RateLimit-Remaining': str(result.remaining), 'X-RateLimit-Reset': str(int(result.reset_at.timestamp())), - 'Retry-After': str(int((result.reset_at - datetime.now()).total_seconds())) - } + 'Retry-After': str(int((result.reset_at - datetime.now()).total_seconds())), + }, ) - + def _add_rate_limit_headers( - self, - response: Response, - limit: int, - remaining: int, - reset_at: int + self, response: Response, limit: int, remaining: int, reset_at: int ): """ Add rate limit headers to response - + Args: response: Response object limit: Rate limit diff --git a/backend-services/middleware/tier_rate_limit_middleware.py b/backend-services/middleware/tier_rate_limit_middleware.py index 7d97f0d..ba6a423 100644 --- a/backend-services/middleware/tier_rate_limit_middleware.py +++ b/backend-services/middleware/tier_rate_limit_middleware.py @@ -8,14 +8,14 @@ Works alongside existing per-user rate limiting. import asyncio import logging import time -from typing import Optional + from fastapi import Request, Response from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp from models.rate_limit_models import TierLimits -from services.tier_service import TierService, get_tier_service +from services.tier_service import get_tier_service from utils.database_async import async_database logger = logging.getLogger(__name__) @@ -24,19 +24,19 @@ logger = logging.getLogger(__name__) class TierRateLimitMiddleware(BaseHTTPMiddleware): """ Middleware for tier-based rate limiting and throttling - + Features: - Enforces tier-based rate limits (requests per minute/hour/day) - Supports throttling (queuing requests) vs hard rejection - Respects user-specific limit overrides - Adds rate limit headers to responses """ - + def __init__(self, app: ASGIApp): super().__init__(app) self._request_counts = {} # Simple in-memory counter (use Redis in production) self._request_queue = {} # Queue for throttling - + async def dispatch(self, request: Request, call_next): """ Process request through tier-based rate limiting @@ -44,244 +44,215 @@ class TierRateLimitMiddleware(BaseHTTPMiddleware): # Skip rate limiting for certain paths if self._should_skip(request): return await call_next(request) - + # Extract user ID user_id = self._get_user_id(request) - + if not user_id: # No user ID, skip tier-based limiting return await call_next(request) - + # Get user's tier limits tier_service = get_tier_service(async_database.db) limits = await tier_service.get_user_limits(user_id) - + if not limits: # No tier limits configured, allow request return await call_next(request) - + # Check rate limits rate_limit_result = await self._check_rate_limits(user_id, limits) - + if not rate_limit_result['allowed']: # Check if throttling is enabled if limits.enable_throttling: # Try to queue the request - queued = await self._try_queue_request( - user_id, - limits.max_queue_time_ms - ) - + queued = await self._try_queue_request(user_id, limits.max_queue_time_ms) + if not queued: # Queue full or timeout, return 429 - return self._create_rate_limit_response( - rate_limit_result, - limits - ) - + return self._create_rate_limit_response(rate_limit_result, limits) + # Request was queued and processed, continue else: # Throttling disabled, hard reject - return self._create_rate_limit_response( - rate_limit_result, - limits - ) - + return self._create_rate_limit_response(rate_limit_result, limits) + # Increment counters self._increment_counters(user_id, limits) - + # Process request response = await call_next(request) - + # Add rate limit headers self._add_rate_limit_headers(response, user_id, limits) - + return response - - async def _check_rate_limits( - self, - user_id: str, - limits: TierLimits - ) -> dict: + + async def _check_rate_limits(self, user_id: str, limits: TierLimits) -> dict: """ Check if user has exceeded any rate limits - + Returns: dict with 'allowed' (bool) and 'limit_type' (str) """ now = int(time.time()) - + # Check requests per minute if limits.requests_per_minute and limits.requests_per_minute < 999999: - key = f"{user_id}:minute:{now // 60}" + key = f'{user_id}:minute:{now // 60}' count = self._request_counts.get(key, 0) - + if count >= limits.requests_per_minute: return { 'allowed': False, 'limit_type': 'minute', 'limit': limits.requests_per_minute, 'current': count, - 'reset_at': ((now // 60) + 1) * 60 + 'reset_at': ((now // 60) + 1) * 60, } - + # Check requests per hour if limits.requests_per_hour and limits.requests_per_hour < 999999: - key = f"{user_id}:hour:{now // 3600}" + key = f'{user_id}:hour:{now // 3600}' count = self._request_counts.get(key, 0) - + if count >= limits.requests_per_hour: return { 'allowed': False, 'limit_type': 'hour', 'limit': limits.requests_per_hour, 'current': count, - 'reset_at': ((now // 3600) + 1) * 3600 + 'reset_at': ((now // 3600) + 1) * 3600, } - + # Check requests per day if limits.requests_per_day and limits.requests_per_day < 999999: - key = f"{user_id}:day:{now // 86400}" + key = f'{user_id}:day:{now // 86400}' count = self._request_counts.get(key, 0) - + if count >= limits.requests_per_day: return { 'allowed': False, 'limit_type': 'day', 'limit': limits.requests_per_day, 'current': count, - 'reset_at': ((now // 86400) + 1) * 86400 + 'reset_at': ((now // 86400) + 1) * 86400, } - + return {'allowed': True} - + def _increment_counters(self, user_id: str, limits: TierLimits): """Increment request counters for all time windows""" now = int(time.time()) - + if limits.requests_per_minute: - key = f"{user_id}:minute:{now // 60}" + key = f'{user_id}:minute:{now // 60}' self._request_counts[key] = self._request_counts.get(key, 0) + 1 - + if limits.requests_per_hour: - key = f"{user_id}:hour:{now // 3600}" + key = f'{user_id}:hour:{now // 3600}' self._request_counts[key] = self._request_counts.get(key, 0) + 1 - + if limits.requests_per_day: - key = f"{user_id}:day:{now // 86400}" + key = f'{user_id}:day:{now // 86400}' self._request_counts[key] = self._request_counts.get(key, 0) + 1 - - async def _try_queue_request( - self, - user_id: str, - max_wait_ms: int - ) -> bool: + + async def _try_queue_request(self, user_id: str, max_wait_ms: int) -> bool: """ Try to queue request with throttling - + Returns: True if request was processed, False if timeout/rejected """ - queue_key = f"{user_id}:queue" + queue_key = f'{user_id}:queue' start_time = time.time() * 1000 # milliseconds - + # Initialize queue if needed if queue_key not in self._request_queue: self._request_queue[queue_key] = asyncio.Queue(maxsize=100) - + queue = self._request_queue[queue_key] - + try: # Add to queue with timeout - await asyncio.wait_for( - queue.put(1), - timeout=max_wait_ms / 1000.0 - ) - + await asyncio.wait_for(queue.put(1), timeout=max_wait_ms / 1000.0) + # Wait for rate limit to reset while True: elapsed = (time.time() * 1000) - start_time - + if elapsed >= max_wait_ms: # Timeout exceeded await queue.get() # Remove from queue return False - + # Check if we can proceed # In a real implementation, check actual rate limit status await asyncio.sleep(0.1) # Small delay - + # For now, assume we can proceed after a short wait if elapsed >= 100: # 100ms min throttle delay await queue.get() # Remove from queue return True - - except asyncio.TimeoutError: + + except TimeoutError: return False - - def _create_rate_limit_response( - self, - result: dict, - limits: TierLimits - ) -> JSONResponse: + + def _create_rate_limit_response(self, result: dict, limits: TierLimits) -> JSONResponse: """Create 429 Too Many Requests response""" retry_after = result.get('reset_at', 0) - int(time.time()) - + return JSONResponse( status_code=429, content={ 'error': 'Rate limit exceeded', 'error_code': 'RATE_LIMIT_EXCEEDED', - 'message': f"Rate limit exceeded: {result.get('current', 0)}/{result.get('limit', 0)} requests per {result.get('limit_type', 'period')}", + 'message': f'Rate limit exceeded: {result.get("current", 0)}/{result.get("limit", 0)} requests per {result.get("limit_type", "period")}', 'limit_type': result.get('limit_type'), 'limit': result.get('limit'), 'current': result.get('current'), 'reset_at': result.get('reset_at'), 'retry_after': max(0, retry_after), - 'throttling_enabled': limits.enable_throttling + 'throttling_enabled': limits.enable_throttling, }, headers={ 'Retry-After': str(max(0, retry_after)), 'X-RateLimit-Limit': str(result.get('limit', 0)), 'X-RateLimit-Remaining': '0', - 'X-RateLimit-Reset': str(result.get('reset_at', 0)) - } + 'X-RateLimit-Reset': str(result.get('reset_at', 0)), + }, ) - - def _add_rate_limit_headers( - self, - response: Response, - user_id: str, - limits: TierLimits - ): + + def _add_rate_limit_headers(self, response: Response, user_id: str, limits: TierLimits): """Add rate limit headers to response""" now = int(time.time()) - + # Add headers for minute limit (most relevant) if limits.requests_per_minute: - key = f"{user_id}:minute:{now // 60}" + key = f'{user_id}:minute:{now // 60}' current = self._request_counts.get(key, 0) remaining = max(0, limits.requests_per_minute - current) reset_at = ((now // 60) + 1) * 60 - + response.headers['X-RateLimit-Limit'] = str(limits.requests_per_minute) response.headers['X-RateLimit-Remaining'] = str(remaining) response.headers['X-RateLimit-Reset'] = str(reset_at) - + def _should_skip(self, request: Request) -> bool: """Check if rate limiting should be skipped""" skip_paths = [ - '/health', - '/metrics', - '/docs', - '/redoc', + '/health', + '/metrics', + '/docs', + '/redoc', '/openapi.json', - '/platform/authorization' # Skip auth endpoints + '/platform/authorization', # Skip auth endpoints ] - + return any(request.url.path.startswith(path) for path in skip_paths) - - def _get_user_id(self, request: Request) -> Optional[str]: + + def _get_user_id(self, request: Request) -> str | None: """Extract user ID from request""" # Try to get from request state (set by auth middleware) if hasattr(request.state, 'user'): @@ -290,9 +261,9 @@ class TierRateLimitMiddleware(BaseHTTPMiddleware): return user.username elif isinstance(user, dict): return user.get('username') or user.get('sub') - + # Try to get from JWT payload in state if hasattr(request.state, 'jwt_payload'): return request.state.jwt_payload.get('sub') - + return None diff --git a/backend-services/models/analytics_models.py b/backend-services/models/analytics_models.py index 4c3abe5..ffb53db 100644 --- a/backend-services/models/analytics_models.py +++ b/backend-services/models/analytics_models.py @@ -6,38 +6,41 @@ analytics capabilities while maintaining backward compatibility. """ from __future__ import annotations + +from collections import deque from dataclasses import dataclass, field -from typing import Dict, List, Optional, Deque -from collections import deque, defaultdict from enum import Enum class AggregationLevel(str, Enum): """Time-based aggregation levels for metrics.""" - MINUTE = "minute" - FIVE_MINUTE = "5minute" - HOUR = "hour" - DAY = "day" + + MINUTE = 'minute' + FIVE_MINUTE = '5minute' + HOUR = 'hour' + DAY = 'day' class MetricType(str, Enum): """Types of metrics tracked.""" - REQUEST_COUNT = "request_count" - ERROR_RATE = "error_rate" - RESPONSE_TIME = "response_time" - BANDWIDTH = "bandwidth" - STATUS_CODE = "status_code" - LATENCY_PERCENTILE = "latency_percentile" + + REQUEST_COUNT = 'request_count' + ERROR_RATE = 'error_rate' + RESPONSE_TIME = 'response_time' + BANDWIDTH = 'bandwidth' + STATUS_CODE = 'status_code' + LATENCY_PERCENTILE = 'latency_percentile' @dataclass class PercentileMetrics: """ Latency percentile calculations. - + Stores multiple percentiles for comprehensive performance analysis. Uses a reservoir sampling approach to maintain a representative sample. """ + p50: float = 0.0 # Median p75: float = 0.0 # 75th percentile p90: float = 0.0 # 90th percentile @@ -45,20 +48,20 @@ class PercentileMetrics: p99: float = 0.0 # 99th percentile min: float = 0.0 # Minimum latency max: float = 0.0 # Maximum latency - + @staticmethod - def calculate(latencies: List[float]) -> 'PercentileMetrics': + def calculate(latencies: list[float]) -> PercentileMetrics: """Calculate percentiles from a list of latencies.""" if not latencies: return PercentileMetrics() - + sorted_latencies = sorted(latencies) n = len(sorted_latencies) - + def percentile(p: float) -> float: k = max(0, int(p * n) - 1) return float(sorted_latencies[k]) - + return PercentileMetrics( p50=percentile(0.50), p75=percentile(0.75), @@ -66,10 +69,10 @@ class PercentileMetrics: p95=percentile(0.95), p99=percentile(0.99), min=float(sorted_latencies[0]), - max=float(sorted_latencies[-1]) + max=float(sorted_latencies[-1]), ) - - def to_dict(self) -> Dict: + + def to_dict(self) -> dict: return { 'p50': self.p50, 'p75': self.p75, @@ -77,7 +80,7 @@ class PercentileMetrics: 'p95': self.p95, 'p99': self.p99, 'min': self.min, - 'max': self.max + 'max': self.max, } @@ -85,36 +88,37 @@ class PercentileMetrics: class EndpointMetrics: """ Per-endpoint performance metrics. - + Tracks detailed metrics for individual API endpoints to identify performance bottlenecks at a granular level. """ + endpoint_uri: str method: str count: int = 0 error_count: int = 0 total_ms: float = 0.0 - latencies: Deque[float] = field(default_factory=deque) - status_counts: Dict[int, int] = field(default_factory=dict) - + latencies: deque[float] = field(default_factory=deque) + status_counts: dict[int, int] = field(default_factory=dict) + def add(self, ms: float, status: int, max_samples: int = 500) -> None: """Record a request for this endpoint.""" self.count += 1 if status >= 400: self.error_count += 1 self.total_ms += ms - + self.status_counts[status] = self.status_counts.get(status, 0) + 1 - + self.latencies.append(ms) while len(self.latencies) > max_samples: self.latencies.popleft() - + def get_percentiles(self) -> PercentileMetrics: """Calculate percentiles for this endpoint.""" return PercentileMetrics.calculate(list(self.latencies)) - - def to_dict(self) -> Dict: + + def to_dict(self) -> dict: percentiles = self.get_percentiles() return { 'endpoint_uri': self.endpoint_uri, @@ -124,7 +128,7 @@ class EndpointMetrics: 'error_rate': (self.error_count / self.count) if self.count > 0 else 0.0, 'avg_ms': (self.total_ms / self.count) if self.count > 0 else 0.0, 'percentiles': percentiles.to_dict(), - 'status_counts': dict(self.status_counts) + 'status_counts': dict(self.status_counts), } @@ -132,13 +136,14 @@ class EndpointMetrics: class EnhancedMinuteBucket: """ Enhanced version of MinuteBucket with additional analytics. - + Extends the existing MinuteBucket from metrics_util.py with: - Per-endpoint tracking - Full percentile calculations (p50, p75, p90, p95, p99) - Unique user tracking - Request/response size tracking """ + start_ts: int count: int = 0 error_count: int = 0 @@ -147,35 +152,35 @@ class EnhancedMinuteBucket: bytes_out: int = 0 upstream_timeouts: int = 0 retries: int = 0 - + # Existing tracking (compatible with metrics_util.py) - status_counts: Dict[int, int] = field(default_factory=dict) - api_counts: Dict[str, int] = field(default_factory=dict) - api_error_counts: Dict[str, int] = field(default_factory=dict) - user_counts: Dict[str, int] = field(default_factory=dict) - latencies: Deque[float] = field(default_factory=deque) - + status_counts: dict[int, int] = field(default_factory=dict) + api_counts: dict[str, int] = field(default_factory=dict) + api_error_counts: dict[str, int] = field(default_factory=dict) + user_counts: dict[str, int] = field(default_factory=dict) + latencies: deque[float] = field(default_factory=deque) + # NEW: Enhanced tracking - endpoint_metrics: Dict[str, EndpointMetrics] = field(default_factory=dict) + endpoint_metrics: dict[str, EndpointMetrics] = field(default_factory=dict) unique_users: set = field(default_factory=set) - request_sizes: Deque[int] = field(default_factory=deque) - response_sizes: Deque[int] = field(default_factory=deque) - + request_sizes: deque[int] = field(default_factory=deque) + response_sizes: deque[int] = field(default_factory=deque) + def add_request( self, ms: float, status: int, - username: Optional[str], - api_key: Optional[str], - endpoint_uri: Optional[str] = None, - method: Optional[str] = None, + username: str | None, + api_key: str | None, + endpoint_uri: str | None = None, + method: str | None = None, bytes_in: int = 0, bytes_out: int = 0, - max_samples: int = 500 + max_samples: int = 500, ) -> None: """ Record a request with enhanced tracking. - + Compatible with existing metrics_util.py while adding new capabilities. """ # Existing tracking (backward compatible) @@ -185,62 +190,61 @@ class EnhancedMinuteBucket: self.total_ms += ms self.bytes_in += bytes_in self.bytes_out += bytes_out - + self.status_counts[status] = self.status_counts.get(status, 0) + 1 - + if api_key: self.api_counts[api_key] = self.api_counts.get(api_key, 0) + 1 if status >= 400: self.api_error_counts[api_key] = self.api_error_counts.get(api_key, 0) + 1 - + if username: self.user_counts[username] = self.user_counts.get(username, 0) + 1 self.unique_users.add(username) - + self.latencies.append(ms) while len(self.latencies) > max_samples: self.latencies.popleft() - + # NEW: Per-endpoint tracking if endpoint_uri and method: - endpoint_key = f"{method}:{endpoint_uri}" + endpoint_key = f'{method}:{endpoint_uri}' if endpoint_key not in self.endpoint_metrics: self.endpoint_metrics[endpoint_key] = EndpointMetrics( - endpoint_uri=endpoint_uri, - method=method + endpoint_uri=endpoint_uri, method=method ) self.endpoint_metrics[endpoint_key].add(ms, status, max_samples) - + # NEW: Request/response size tracking if bytes_in > 0: self.request_sizes.append(bytes_in) while len(self.request_sizes) > max_samples: self.request_sizes.popleft() - + if bytes_out > 0: self.response_sizes.append(bytes_out) while len(self.response_sizes) > max_samples: self.response_sizes.popleft() - + def get_percentiles(self) -> PercentileMetrics: """Calculate full percentiles for this bucket.""" return PercentileMetrics.calculate(list(self.latencies)) - + def get_unique_user_count(self) -> int: """Get count of unique users in this bucket.""" return len(self.unique_users) - - def get_top_endpoints(self, limit: int = 10) -> List[Dict]: + + def get_top_endpoints(self, limit: int = 10) -> list[dict]: """Get top N slowest/most-used endpoints.""" endpoints = [ep.to_dict() for ep in self.endpoint_metrics.values()] # Sort by count (most used) endpoints.sort(key=lambda x: x['count'], reverse=True) return endpoints[:limit] - - def to_dict(self) -> Dict: + + def to_dict(self) -> dict: """Serialize to dictionary (backward compatible + enhanced).""" percentiles = self.get_percentiles() - + return { # Existing fields (backward compatible) 'start_ts': self.start_ts, @@ -255,13 +259,16 @@ class EnhancedMinuteBucket: 'api_counts': dict(self.api_counts), 'api_error_counts': dict(self.api_error_counts), 'user_counts': dict(self.user_counts), - # NEW: Enhanced fields 'percentiles': percentiles.to_dict(), 'unique_users': self.get_unique_user_count(), 'endpoint_metrics': {k: v.to_dict() for k, v in self.endpoint_metrics.items()}, - 'avg_request_size': sum(self.request_sizes) / len(self.request_sizes) if self.request_sizes else 0, - 'avg_response_size': sum(self.response_sizes) / len(self.response_sizes) if self.response_sizes else 0, + 'avg_request_size': sum(self.request_sizes) / len(self.request_sizes) + if self.request_sizes + else 0, + 'avg_response_size': sum(self.response_sizes) / len(self.response_sizes) + if self.response_sizes + else 0, } @@ -269,10 +276,11 @@ class EnhancedMinuteBucket: class AggregatedMetrics: """ Multi-level aggregated metrics (5-minute, hourly, daily). - + Used for efficient querying of historical data without scanning all minute-level buckets. """ + start_ts: int end_ts: int level: AggregationLevel @@ -282,12 +290,12 @@ class AggregatedMetrics: bytes_in: int = 0 bytes_out: int = 0 unique_users: int = 0 - - status_counts: Dict[int, int] = field(default_factory=dict) - api_counts: Dict[str, int] = field(default_factory=dict) - percentiles: Optional[PercentileMetrics] = None - - def to_dict(self) -> Dict: + + status_counts: dict[int, int] = field(default_factory=dict) + api_counts: dict[str, int] = field(default_factory=dict) + percentiles: PercentileMetrics | None = None + + def to_dict(self) -> dict: return { 'start_ts': self.start_ts, 'end_ts': self.end_ts, @@ -301,7 +309,7 @@ class AggregatedMetrics: 'unique_users': self.unique_users, 'status_counts': dict(self.status_counts), 'api_counts': dict(self.api_counts), - 'percentiles': self.percentiles.to_dict() if self.percentiles else None + 'percentiles': self.percentiles.to_dict() if self.percentiles else None, } @@ -309,9 +317,10 @@ class AggregatedMetrics: class AnalyticsSnapshot: """ Complete analytics snapshot for a time range. - + Used as the response format for analytics API endpoints. """ + start_ts: int end_ts: int total_requests: int @@ -322,19 +331,19 @@ class AnalyticsSnapshot: total_bytes_in: int total_bytes_out: int unique_users: int - + # Time-series data - series: List[Dict] - + series: list[dict] + # Top N lists - top_apis: List[tuple] - top_users: List[tuple] - top_endpoints: List[Dict] - + top_apis: list[tuple] + top_users: list[tuple] + top_endpoints: list[dict] + # Status code distribution - status_distribution: Dict[str, int] - - def to_dict(self) -> Dict: + status_distribution: dict[str, int] + + def to_dict(self) -> dict: return { 'start_ts': self.start_ts, 'end_ts': self.end_ts, @@ -346,11 +355,11 @@ class AnalyticsSnapshot: 'percentiles': self.percentiles.to_dict(), 'total_bytes_in': self.total_bytes_in, 'total_bytes_out': self.total_bytes_out, - 'unique_users': self.unique_users + 'unique_users': self.unique_users, }, 'series': self.series, 'top_apis': [{'api': api, 'count': count} for api, count in self.top_apis], 'top_users': [{'user': user, 'count': count} for user, count in self.top_users], 'top_endpoints': self.top_endpoints, - 'status_distribution': self.status_distribution + 'status_distribution': self.status_distribution, } diff --git a/backend-services/models/api_model_response.py b/backend-services/models/api_model_response.py index 0b0dc0b..58c69c6 100644 --- a/backend-services/models/api_model_response.py +++ b/backend-services/models/api_model_response.py @@ -5,24 +5,61 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class ApiModelResponse(BaseModel): - - api_name: Optional[str] = Field(None, min_length=1, max_length=25, description='Name of the API', example='customer') - api_version: Optional[str] = Field(None, min_length=1, max_length=8, description='Version of the API', example='v1') - api_description: Optional[str] = Field(None, min_length=1, max_length=127, description='Description of the API', example='New customer onboarding API') - api_allowed_roles: Optional[List[str]] = Field(None, description='Allowed user roles for the API', example=['admin', 'user']) - api_allowed_groups: Optional[List[str]] = Field(None, description='Allowed user groups for the API' , example=['admin', 'client-1-group']) - api_servers: Optional[List[str]] = Field(None, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081']) - api_type: Optional[str] = Field(None, description="Type of the API. Valid values: 'REST'", example='REST') - api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header') - api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization']) - api_allowed_retry_count: Optional[int] = Field(None, description='Number of allowed retries for the API', example=0) - api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True) - api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1') - api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example='c3eda315-545a-4fef-a831-7e45e2f68987') - api_path: Optional[str] = Field(None, description='Unqiue path for the API, auto-generated', example='/customer/v1') + api_name: str | None = Field( + None, min_length=1, max_length=25, description='Name of the API', example='customer' + ) + api_version: str | None = Field( + None, min_length=1, max_length=8, description='Version of the API', example='v1' + ) + api_description: str | None = Field( + None, + min_length=1, + max_length=127, + description='Description of the API', + example='New customer onboarding API', + ) + api_allowed_roles: list[str] | None = Field( + None, description='Allowed user roles for the API', example=['admin', 'user'] + ) + api_allowed_groups: list[str] | None = Field( + None, description='Allowed user groups for the API', example=['admin', 'client-1-group'] + ) + api_servers: list[str] | None = Field( + None, + description='List of backend servers for the API', + example=['http://localhost:8080', 'http://localhost:8081'], + ) + api_type: str | None = Field( + None, description="Type of the API. Valid values: 'REST'", example='REST' + ) + api_authorization_field_swap: str | None = Field( + None, + description='Header to swap for backend authorization header', + example='backend-auth-header', + ) + api_allowed_headers: list[str] | None = Field( + None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'] + ) + api_allowed_retry_count: int | None = Field( + None, description='Number of allowed retries for the API', example=0 + ) + api_credits_enabled: bool | None = Field( + False, description='Enable credit-based authentication for the API', example=True + ) + api_credit_group: str | None = Field( + None, description='API credit group for the API credits', example='ai-group-1' + ) + api_id: str | None = Field( + None, + description='Unique identifier for the API, auto-generated', + example='c3eda315-545a-4fef-a831-7e45e2f68987', + ) + api_path: str | None = Field( + None, description='Unqiue path for the API, auto-generated', example='/customer/v1' + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/create_api_model.py b/backend-services/models/create_api_model.py index fd158ab..0c66844 100644 --- a/backend-services/models/create_api_model.py +++ b/backend-services/models/create_api_model.py @@ -5,46 +5,127 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class CreateApiModel(BaseModel): - - api_name: str = Field(..., min_length=1, max_length=64, description='Name of the API', example='customer') - api_version: str = Field(..., min_length=1, max_length=8, description='Version of the API', example='v1') - api_description: Optional[str] = Field(None, max_length=127, description='Description of the API', example='New customer onboarding API') - api_allowed_roles: List[str] = Field(default_factory=list, description='Allowed user roles for the API', example=['admin', 'user']) - api_allowed_groups: List[str] = Field(default_factory=list, description='Allowed user groups for the API' , example=['admin', 'client-1-group']) - api_servers: List[str] = Field(default_factory=list, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081']) + api_name: str = Field( + ..., min_length=1, max_length=64, description='Name of the API', example='customer' + ) + api_version: str = Field( + ..., min_length=1, max_length=8, description='Version of the API', example='v1' + ) + api_description: str | None = Field( + None, + max_length=127, + description='Description of the API', + example='New customer onboarding API', + ) + api_allowed_roles: list[str] = Field( + default_factory=list, + description='Allowed user roles for the API', + example=['admin', 'user'], + ) + api_allowed_groups: list[str] = Field( + default_factory=list, + description='Allowed user groups for the API', + example=['admin', 'client-1-group'], + ) + api_servers: list[str] = Field( + default_factory=list, + description='List of backend servers for the API', + example=['http://localhost:8080', 'http://localhost:8081'], + ) api_type: str = Field(None, description="Type of the API. Valid values: 'REST'", example='REST') - api_allowed_retry_count: int = Field(0, description='Number of allowed retries for the API', example=0) - api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg') - api_grpc_allowed_packages: Optional[List[str]] = Field(None, description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.', example=['customer_v1']) - api_grpc_allowed_services: Optional[List[str]] = Field(None, description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.', example=['Greeter']) - api_grpc_allowed_methods: Optional[List[str]] = Field(None, description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.', example=['Greeter.SayHello']) + api_allowed_retry_count: int = Field( + 0, description='Number of allowed retries for the API', example=0 + ) + api_grpc_package: str | None = Field( + None, + description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', + example='my.pkg', + ) + api_grpc_allowed_packages: list[str] | None = Field( + None, + description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.', + example=['customer_v1'], + ) + api_grpc_allowed_services: list[str] | None = Field( + None, + description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.', + example=['Greeter'], + ) + api_grpc_allowed_methods: list[str] | None = Field( + None, + description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.', + example=['Greeter.SayHello'], + ) - api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header') - api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization']) - api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True) - api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1') - active: Optional[bool] = Field(True, description='Whether the API is active (enabled)', example=True) + api_authorization_field_swap: str | None = Field( + None, + description='Header to swap for backend authorization header', + example='backend-auth-header', + ) + api_allowed_headers: list[str] | None = Field( + None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'] + ) + api_credits_enabled: bool | None = Field( + False, description='Enable credit-based authentication for the API', example=True + ) + api_credit_group: str | None = Field( + None, description='API credit group for the API credits', example='ai-group-1' + ) + active: bool | None = Field( + True, description='Whether the API is active (enabled)', example=True + ) - api_cors_allow_origins: Optional[List[str]] = Field(None, description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.") - api_cors_allow_methods: Optional[List[str]] = Field(None, description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])") - api_cors_allow_headers: Optional[List[str]] = Field(None, description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])") - api_cors_allow_credentials: Optional[bool] = Field(False, description='Whether to include Access-Control-Allow-Credentials=true in responses') - api_cors_expose_headers: Optional[List[str]] = Field(None, description='Response headers to expose to the browser via Access-Control-Expose-Headers') + api_cors_allow_origins: list[str] | None = Field( + None, + description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.", + ) + api_cors_allow_methods: list[str] | None = Field( + None, + description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])", + ) + api_cors_allow_headers: list[str] | None = Field( + None, + description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])", + ) + api_cors_allow_credentials: bool | None = Field( + False, description='Whether to include Access-Control-Allow-Credentials=true in responses' + ) + api_cors_expose_headers: list[str] | None = Field( + None, + description='Response headers to expose to the browser via Access-Control-Expose-Headers', + ) - api_public: Optional[bool] = Field(False, description='If true, this API can be called without authentication or subscription') + api_public: bool | None = Field( + False, description='If true, this API can be called without authentication or subscription' + ) - api_auth_required: Optional[bool] = Field(True, description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.') + api_auth_required: bool | None = Field( + True, + description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.', + ) - api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example=None) - api_path: Optional[str] = Field(None, description='Unique path for the API, auto-generated', example=None) + api_id: str | None = Field( + None, description='Unique identifier for the API, auto-generated', example=None + ) + api_path: str | None = Field( + None, description='Unique path for the API, auto-generated', example=None + ) - api_ip_mode: Optional[str] = Field('allow_all', description="IP policy mode: 'allow_all' or 'whitelist'") - api_ip_whitelist: Optional[List[str]] = Field(None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist') - api_ip_blacklist: Optional[List[str]] = Field(None, description='IPs/CIDRs denied regardless of mode') - api_trust_x_forwarded_for: Optional[bool] = Field(None, description='Override: trust X-Forwarded-For for this API') + api_ip_mode: str | None = Field( + 'allow_all', description="IP policy mode: 'allow_all' or 'whitelist'" + ) + api_ip_whitelist: list[str] | None = Field( + None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist' + ) + api_ip_blacklist: list[str] | None = Field( + None, description='IPs/CIDRs denied regardless of mode' + ) + api_trust_x_forwarded_for: bool | None = Field( + None, description='Override: trust X-Forwarded-For for this API' + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/create_endpoint_model.py b/backend-services/models/create_endpoint_model.py index eb1db6c..0b4ccd0 100644 --- a/backend-services/models/create_endpoint_model.py +++ b/backend-services/models/create_endpoint_model.py @@ -5,19 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional, List + class CreateEndpointModel(BaseModel): + api_name: str = Field( + ..., min_length=1, max_length=50, description='Name of the API', example='customer' + ) + api_version: str = Field( + ..., min_length=1, max_length=10, description='Version of the API', example='v1' + ) + endpoint_method: str = Field( + ..., min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET' + ) + endpoint_uri: str = Field( + ..., min_length=1, max_length=255, description='URI for the endpoint', example='/customer' + ) + endpoint_description: str = Field( + ..., + min_length=1, + max_length=255, + description='Description of the endpoint', + example='Get customer details', + ) + endpoint_servers: list[str] | None = Field( + None, + description='Optional list of backend servers for this endpoint (overrides API servers)', + example=['http://localhost:8082', 'http://localhost:8083'], + ) - api_name: str = Field(..., min_length=1, max_length=50, description='Name of the API', example='customer') - api_version: str = Field(..., min_length=1, max_length=10, description='Version of the API', example='v1') - endpoint_method: str = Field(..., min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET') - endpoint_uri: str = Field(..., min_length=1, max_length=255, description='URI for the endpoint', example='/customer') - endpoint_description: str = Field(..., min_length=1, max_length=255, description='Description of the endpoint', example='Get customer details') - endpoint_servers: Optional[List[str]] = Field(None, description='Optional list of backend servers for this endpoint (overrides API servers)', example=['http://localhost:8082', 'http://localhost:8083']) - - api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example=None) - endpoint_id: Optional[str] = Field(None, description='Unique identifier for the endpoint, auto-generated', example=None) + api_id: str | None = Field( + None, description='Unique identifier for the API, auto-generated', example=None + ) + endpoint_id: str | None = Field( + None, description='Unique identifier for the endpoint, auto-generated', example=None + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/create_endpoint_validation_model.py b/backend-services/models/create_endpoint_validation_model.py index 2b91574..59d9f07 100644 --- a/backend-services/models/create_endpoint_validation_model.py +++ b/backend-services/models/create_endpoint_validation_model.py @@ -8,11 +8,19 @@ from pydantic import BaseModel, Field from models.validation_schema_model import ValidationSchema -class CreateEndpointValidationModel(BaseModel): - endpoint_id: str = Field(..., description='Unique identifier for the endpoint, auto-generated', example='1299f720-e619-4628-b584-48a6570026cf') - validation_enabled: bool = Field(..., description='Whether the validation is enabled', example=True) - validation_schema: ValidationSchema = Field(..., description='The schema to validate the endpoint against', example={}) +class CreateEndpointValidationModel(BaseModel): + endpoint_id: str = Field( + ..., + description='Unique identifier for the endpoint, auto-generated', + example='1299f720-e619-4628-b584-48a6570026cf', + ) + validation_enabled: bool = Field( + ..., description='Whether the validation is enabled', example=True + ) + validation_schema: ValidationSchema = Field( + ..., description='The schema to validate the endpoint against', example={} + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/create_group_model.py b/backend-services/models/create_group_model.py index d5ae97d..fff28e5 100644 --- a/backend-services/models/create_group_model.py +++ b/backend-services/models/create_group_model.py @@ -5,14 +5,21 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class CreateGroupModel(BaseModel): + group_name: str = Field( + ..., min_length=1, max_length=50, description='Name of the group', example='client-1-group' + ) - group_name: str = Field(..., min_length=1, max_length=50, description='Name of the group', example='client-1-group') - - group_description: Optional[str] = Field(None, max_length=255, description='Description of the group', example='Group for client 1') - api_access: Optional[List[str]] = Field(default_factory=list, description='List of APIs the group can access', example=['customer/v1']) + group_description: str | None = Field( + None, max_length=255, description='Description of the group', example='Group for client 1' + ) + api_access: list[str] | None = Field( + default_factory=list, + description='List of APIs the group can access', + example=['customer/v1'], + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/create_role_model.py b/backend-services/models/create_role_model.py index 239f04a..2bca258 100644 --- a/backend-services/models/create_role_model.py +++ b/backend-services/models/create_role_model.py @@ -5,26 +5,46 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional + class CreateRoleModel(BaseModel): - - role_name: str = Field(..., min_length=1, max_length=50, description='Name of the role', example='admin') - role_description: Optional[str] = Field(None, max_length=255, description='Description of the role', example='Administrator role with full access') + role_name: str = Field( + ..., min_length=1, max_length=50, description='Name of the role', example='admin' + ) + role_description: str | None = Field( + None, + max_length=255, + description='Description of the role', + example='Administrator role with full access', + ) manage_users: bool = Field(False, description='Permission to manage users', example=True) manage_apis: bool = Field(False, description='Permission to manage APIs', example=True) - manage_endpoints: bool = Field(False, description='Permission to manage endpoints', example=True) + manage_endpoints: bool = Field( + False, description='Permission to manage endpoints', example=True + ) manage_groups: bool = Field(False, description='Permission to manage groups', example=True) manage_roles: bool = Field(False, description='Permission to manage roles', example=True) manage_routings: bool = Field(False, description='Permission to manage routings', example=True) manage_gateway: bool = Field(False, description='Permission to manage gateway', example=True) - manage_subscriptions: bool = Field(False, description='Permission to manage subscriptions', example=True) - manage_security: bool = Field(False, description='Permission to manage security settings', example=True) - manage_tiers: bool = Field(False, description='Permission to manage pricing tiers', example=True) - manage_rate_limits: bool = Field(False, description='Permission to manage rate limiting rules', example=True) + manage_subscriptions: bool = Field( + False, description='Permission to manage subscriptions', example=True + ) + manage_security: bool = Field( + False, description='Permission to manage security settings', example=True + ) + manage_tiers: bool = Field( + False, description='Permission to manage pricing tiers', example=True + ) + manage_rate_limits: bool = Field( + False, description='Permission to manage rate limiting rules', example=True + ) manage_credits: bool = Field(False, description='Permission to manage credits', example=True) - manage_auth: bool = Field(False, description='Permission to manage auth (revoke tokens/disable users)', example=True) - view_analytics: bool = Field(False, description='Permission to view analytics dashboard', example=True) + manage_auth: bool = Field( + False, description='Permission to manage auth (revoke tokens/disable users)', example=True + ) + view_analytics: bool = Field( + False, description='Permission to view analytics dashboard', example=True + ) view_logs: bool = Field(False, description='Permission to view logs', example=True) export_logs: bool = Field(False, description='Permission to export logs', example=True) diff --git a/backend-services/models/create_routing_model.py b/backend-services/models/create_routing_model.py index de2badf..01378f1 100644 --- a/backend-services/models/create_routing_model.py +++ b/backend-services/models/create_routing_model.py @@ -5,16 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional + class CreateRoutingModel(BaseModel): + routing_name: str = Field( + ..., + min_length=1, + max_length=50, + description='Name of the routing', + example='customer-routing', + ) + routing_servers: list[str] = Field( + ..., + min_items=1, + description='List of backend servers for the routing', + example=['http://localhost:8080', 'http://localhost:8081'], + ) + routing_description: str = Field( + None, + min_length=1, + max_length=255, + description='Description of the routing', + example='Routing for customer API', + ) - routing_name: str = Field(..., min_length=1, max_length=50, description='Name of the routing', example='customer-routing') - routing_servers : list[str] = Field(..., min_items=1, description='List of backend servers for the routing', example=['http://localhost:8080', 'http://localhost:8081']) - routing_description: str = Field(None, min_length=1, max_length=255, description='Description of the routing', example='Routing for customer API') - - client_key: Optional[str] = Field(None, min_length=1, max_length=50, description='Client key for the routing', example='client-1') - server_index: Optional[int] = Field(0, ge=0, description='Index of the server to route to', example=0) + client_key: str | None = Field( + None, + min_length=1, + max_length=50, + description='Client key for the routing', + example='client-1', + ) + server_index: int | None = Field( + 0, ge=0, description='Index of the server to route to', example=0 + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/create_user_model.py b/backend-services/models/create_user_model.py index f28a9aa..8a0a74f 100644 --- a/backend-services/models/create_user_model.py +++ b/backend-services/models/create_user_model.py @@ -5,31 +5,93 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class CreateUserModel(BaseModel): + username: str = Field( + ..., min_length=3, max_length=50, description='Username of the user', example='john_doe' + ) + email: str = Field( + ..., + min_length=3, + max_length=127, + description='Email of the user (no strict format validation)', + example='john@mail.com', + ) + password: str = Field( + ..., + min_length=16, + max_length=50, + description='Password of the user', + example='SecurePassword@123', + ) + role: str = Field( + ..., min_length=2, max_length=50, description='Role of the user', example='admin' + ) + groups: list[str] = Field( + default_factory=list, + description='List of groups the user belongs to', + example=['client-1-group'], + ) - username: str = Field(..., min_length=3, max_length=50, description='Username of the user', example='john_doe') - email: str = Field(..., min_length=3, max_length=127, description='Email of the user (no strict format validation)', example='john@mail.com') - password: str = Field(..., min_length=16, max_length=50, description='Password of the user', example='SecurePassword@123') - role: str = Field(..., min_length=2, max_length=50, description='Role of the user', example='admin') - groups: List[str] = Field(default_factory=list, description='List of groups the user belongs to', example=['client-1-group']) - - rate_limit_duration: Optional[int] = Field(None, ge=0, description='Rate limit for the user', example=100) - rate_limit_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour') - rate_limit_enabled: Optional[bool] = Field(None, description='Whether rate limiting is enabled for this user', example=True) - throttle_duration: Optional[int] = Field(None, ge=0, description='Throttle limit for the user', example=10) - throttle_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the throttle limit', example='second') - throttle_wait_duration: Optional[int] = Field(None, ge=0, description='Wait time for the throttle limit', example=5) - throttle_wait_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Wait duration for the throttle limit', example='seconds') - throttle_queue_limit: Optional[int] = Field(None, ge=0, description='Throttle queue limit for the user', example=10) - throttle_enabled: Optional[bool] = Field(None, description='Whether throttling is enabled for this user', example=True) - custom_attributes: Optional[dict] = Field(None, description='Custom attributes for the user', example={'custom_key': 'custom_value'}) - bandwidth_limit_bytes: Optional[int] = Field(None, ge=0, description='Maximum bandwidth allowed within the window (bytes)', example=1073741824) - bandwidth_limit_window: Optional[str] = Field('day', min_length=1, max_length=10, description='Bandwidth window unit (second/minute/hour/day/month)', example='day') - bandwidth_limit_enabled: Optional[bool] = Field(None, description='Whether bandwidth limit enforcement is enabled for this user', example=True) - active: Optional[bool] = Field(True, description='Active status of the user', example=True) - ui_access: Optional[bool] = Field(False, description='UI access for the user', example=False) + rate_limit_duration: int | None = Field( + None, ge=0, description='Rate limit for the user', example=100 + ) + rate_limit_duration_type: str | None = Field( + None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour' + ) + rate_limit_enabled: bool | None = Field( + None, description='Whether rate limiting is enabled for this user', example=True + ) + throttle_duration: int | None = Field( + None, ge=0, description='Throttle limit for the user', example=10 + ) + throttle_duration_type: str | None = Field( + None, + min_length=1, + max_length=7, + description='Duration for the throttle limit', + example='second', + ) + throttle_wait_duration: int | None = Field( + None, ge=0, description='Wait time for the throttle limit', example=5 + ) + throttle_wait_duration_type: str | None = Field( + None, + min_length=1, + max_length=7, + description='Wait duration for the throttle limit', + example='seconds', + ) + throttle_queue_limit: int | None = Field( + None, ge=0, description='Throttle queue limit for the user', example=10 + ) + throttle_enabled: bool | None = Field( + None, description='Whether throttling is enabled for this user', example=True + ) + custom_attributes: dict | None = Field( + None, description='Custom attributes for the user', example={'custom_key': 'custom_value'} + ) + bandwidth_limit_bytes: int | None = Field( + None, + ge=0, + description='Maximum bandwidth allowed within the window (bytes)', + example=1073741824, + ) + bandwidth_limit_window: str | None = Field( + 'day', + min_length=1, + max_length=10, + description='Bandwidth window unit (second/minute/hour/day/month)', + example='day', + ) + bandwidth_limit_enabled: bool | None = Field( + None, + description='Whether bandwidth limit enforcement is enabled for this user', + example=True, + ) + active: bool | None = Field(True, description='Active status of the user', example=True) + ui_access: bool | None = Field(False, description='UI access for the user', example=False) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/create_vault_entry_model.py b/backend-services/models/create_vault_entry_model.py index 10b0fbd..881184f 100644 --- a/backend-services/models/create_vault_entry_model.py +++ b/backend-services/models/create_vault_entry_model.py @@ -5,32 +5,31 @@ See https://github.com/apidoorman/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional class CreateVaultEntryModel(BaseModel): """Model for creating a new vault entry.""" - + key_name: str = Field( - ..., - min_length=1, - max_length=255, + ..., + min_length=1, + max_length=255, description='Unique name for the vault key', - example='api_key_production' + example='api_key_production', ) - + value: str = Field( - ..., - min_length=1, + ..., + min_length=1, description='The secret value to encrypt and store', - example='sk_live_abc123xyz789' + example='sk_live_abc123xyz789', ) - - description: Optional[str] = Field( + + description: str | None = Field( None, max_length=500, description='Optional description of what this key is used for', - example='Production API key for payment gateway' + example='Production API key for payment gateway', ) class Config: diff --git a/backend-services/models/credit_model.py b/backend-services/models/credit_model.py index 16b26a7..31e7274 100644 --- a/backend-services/models/credit_model.py +++ b/backend-services/models/credit_model.py @@ -4,30 +4,58 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List, Optional -from pydantic import BaseModel, Field from datetime import datetime +from pydantic import BaseModel, Field + + class CreditTierModel(BaseModel): - tier_name: str = Field(..., min_length=1, max_length=50, description='Name of the credit tier', example='basic') + tier_name: str = Field( + ..., min_length=1, max_length=50, description='Name of the credit tier', example='basic' + ) credits: int = Field(..., description='Number of credits per reset', example=50) - input_limit: int = Field(..., description='Input limit for paid credits (text or context)', example=150) - output_limit: int = Field(..., description='Output limit for paid credits (text or context)', example=150) - reset_frequency: str = Field(..., description='Frequency of paid credit reset', example='monthly') + input_limit: int = Field( + ..., description='Input limit for paid credits (text or context)', example=150 + ) + output_limit: int = Field( + ..., description='Output limit for paid credits (text or context)', example=150 + ) + reset_frequency: str = Field( + ..., description='Frequency of paid credit reset', example='monthly' + ) class Config: arbitrary_types_allowed = True + class CreditModel(BaseModel): + api_credit_group: str = Field( + ..., + min_length=1, + max_length=50, + description='API group for the credits', + example='ai-group-1', + ) + api_key: str = Field( + ..., description='API key for the credit tier', example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' + ) + api_key_header: str = Field( + ..., description='Header the API key should be sent in', example='x-api-key' + ) + credit_tiers: list[CreditTierModel] = Field( + ..., min_items=1, description='Credit tiers information' + ) - api_credit_group: str = Field(..., min_length=1, max_length=50, description='API group for the credits', example='ai-group-1') - api_key: str = Field(..., description='API key for the credit tier', example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') - api_key_header: str = Field(..., description='Header the API key should be sent in', example='x-api-key') - credit_tiers: List[CreditTierModel] = Field(..., min_items=1, description='Credit tiers information') - - api_key_new: Optional[str] = Field(None, description='New API key during rotation period', example='yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy') - api_key_rotation_expires: Optional[datetime] = Field(None, description='Expiration time for old API key during rotation', example='2025-01-15T10:00:00Z') + api_key_new: str | None = Field( + None, + description='New API key during rotation period', + example='yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy', + ) + api_key_rotation_expires: datetime | None = Field( + None, + description='Expiration time for old API key during rotation', + example='2025-01-15T10:00:00Z', + ) class Config: arbitrary_types_allowed = True - diff --git a/backend-services/models/delete_successfully.py b/backend-services/models/delete_successfully.py index 544520e..ef6da42 100644 --- a/backend-services/models/delete_successfully.py +++ b/backend-services/models/delete_successfully.py @@ -6,9 +6,11 @@ See https://github.com/pypeople-dev/doorman for more information from pydantic import BaseModel, Field -class ResponseMessage(BaseModel): - message: str = Field(None, description='The response message', example='API Deleted Successfully') +class ResponseMessage(BaseModel): + message: str = Field( + None, description='The response message', example='API Deleted Successfully' + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/endpoint_model_response.py b/backend-services/models/endpoint_model_response.py index abf2138..765ae56 100644 --- a/backend-services/models/endpoint_model_response.py +++ b/backend-services/models/endpoint_model_response.py @@ -5,18 +5,47 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional, List + class EndpointModelResponse(BaseModel): - - api_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the API', example='customer') - api_version: Optional[str] = Field(None, min_length=1, max_length=10, description='Version of the API', example='v1') - endpoint_method: Optional[str] = Field(None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET') - endpoint_uri: Optional[str] = Field(None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer') - endpoint_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the endpoint', example='Get customer details') - endpoint_servers: Optional[List[str]] = Field(None, description='Optional list of backend servers for this endpoint (overrides API servers)', example=['http://localhost:8082', 'http://localhost:8083']) - api_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the API, auto-generated', example=None) - endpoint_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the endpoint, auto-generated', example=None) + api_name: str | None = Field( + None, min_length=1, max_length=50, description='Name of the API', example='customer' + ) + api_version: str | None = Field( + None, min_length=1, max_length=10, description='Version of the API', example='v1' + ) + endpoint_method: str | None = Field( + None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET' + ) + endpoint_uri: str | None = Field( + None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer' + ) + endpoint_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the endpoint', + example='Get customer details', + ) + endpoint_servers: list[str] | None = Field( + None, + description='Optional list of backend servers for this endpoint (overrides API servers)', + example=['http://localhost:8082', 'http://localhost:8083'], + ) + api_id: str | None = Field( + None, + min_length=1, + max_length=255, + description='Unique identifier for the API, auto-generated', + example=None, + ) + endpoint_id: str | None = Field( + None, + min_length=1, + max_length=255, + description='Unique identifier for the endpoint, auto-generated', + example=None, + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/endpoint_validation_model_response.py b/backend-services/models/endpoint_validation_model_response.py index 8edc594..1bef4f6 100644 --- a/backend-services/models/endpoint_validation_model_response.py +++ b/backend-services/models/endpoint_validation_model_response.py @@ -8,11 +8,19 @@ from pydantic import BaseModel, Field from models.validation_schema_model import ValidationSchema -class EndpointValidationModelResponse(BaseModel): - endpoint_id: str = Field(..., description='Unique identifier for the endpoint, auto-generated', example='1299f720-e619-4628-b584-48a6570026cf') - validation_enabled: bool = Field(..., description='Whether the validation is enabled', example=True) - validation_schema: ValidationSchema = Field(..., description='The schema to validate the endpoint against', example={}) +class EndpointValidationModelResponse(BaseModel): + endpoint_id: str = Field( + ..., + description='Unique identifier for the endpoint, auto-generated', + example='1299f720-e619-4628-b584-48a6570026cf', + ) + validation_enabled: bool = Field( + ..., description='Whether the validation is enabled', example=True + ) + validation_schema: ValidationSchema = Field( + ..., description='The schema to validate the endpoint against', example={} + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/field_validation_model.py b/backend-services/models/field_validation_model.py index 7f7803a..9b35a1e 100644 --- a/backend-services/models/field_validation_model.py +++ b/backend-services/models/field_validation_model.py @@ -4,17 +4,31 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from typing import List, Union, Optional, Dict, Any +from typing import Any, Optional + from pydantic import BaseModel, Field + class FieldValidation(BaseModel): required: bool = Field(..., description='Whether the field is required') - type: str = Field(..., description='Expected data type (string, number, boolean, array, object)') - min: Optional[Union[int, float]] = Field(None, description='Minimum value for numbers or minimum length for strings/arrays') - max: Optional[Union[int, float]] = Field(None, description='Maximum value for numbers or maximum length for strings/arrays') - pattern: Optional[str] = Field(None, description='Regex pattern for string validation') - enum: Optional[List[Any]] = Field(None, description='List of allowed values') - format: Optional[str] = Field(None, description='Format validation (email, url, date, datetime, uuid, etc.)') - custom_validator: Optional[str] = Field(None, description='Custom validation function name') - nested_schema: Optional[Dict[str, 'FieldValidation']] = Field(None, description='Validation schema for nested objects') - array_items: Optional['FieldValidation'] = Field(None, description='Validation schema for array items') \ No newline at end of file + type: str = Field( + ..., description='Expected data type (string, number, boolean, array, object)' + ) + min: int | float | None = Field( + None, description='Minimum value for numbers or minimum length for strings/arrays' + ) + max: int | float | None = Field( + None, description='Maximum value for numbers or maximum length for strings/arrays' + ) + pattern: str | None = Field(None, description='Regex pattern for string validation') + enum: list[Any] | None = Field(None, description='List of allowed values') + format: str | None = Field( + None, description='Format validation (email, url, date, datetime, uuid, etc.)' + ) + custom_validator: str | None = Field(None, description='Custom validation function name') + nested_schema: dict[str, 'FieldValidation'] | None = Field( + None, description='Validation schema for nested objects' + ) + array_items: Optional['FieldValidation'] = Field( + None, description='Validation schema for array items' + ) diff --git a/backend-services/models/group_model_response.py b/backend-services/models/group_model_response.py index e040d3f..6618a08 100644 --- a/backend-services/models/group_model_response.py +++ b/backend-services/models/group_model_response.py @@ -5,13 +5,22 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class GroupModelResponse(BaseModel): - - group_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the group', example='client-1-group') - group_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the group', example='Group for client 1') - api_access: Optional[List[str]] = Field(None, description='List of APIs the group can access', example=['customer/v1']) + group_name: str | None = Field( + None, min_length=1, max_length=50, description='Name of the group', example='client-1-group' + ) + group_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the group', + example='Group for client 1', + ) + api_access: list[str] | None = Field( + None, description='List of APIs the group can access', example=['customer/v1'] + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/rate_limit_models.py b/backend-services/models/rate_limit_models.py index bd7b4c4..0f92e60 100644 --- a/backend-services/models/rate_limit_models.py +++ b/backend-services/models/rate_limit_models.py @@ -9,59 +9,64 @@ This module defines the data structures for the rate limiting system including: """ from dataclasses import dataclass, field -from typing import Optional, Dict, List, Any -from enum import Enum from datetime import datetime - +from enum import Enum +from typing import Any # ============================================================================ # ENUMS # ============================================================================ + class RuleType(Enum): """Types of rate limit rules""" - PER_USER = "per_user" - PER_API = "per_api" - PER_ENDPOINT = "per_endpoint" - PER_IP = "per_ip" - PER_USER_API = "per_user_api" # Combined: specific user on specific API - PER_USER_ENDPOINT = "per_user_endpoint" # Combined: specific user on specific endpoint - GLOBAL = "global" # Global rate limit for all requests + + PER_USER = 'per_user' + PER_API = 'per_api' + PER_ENDPOINT = 'per_endpoint' + PER_IP = 'per_ip' + PER_USER_API = 'per_user_api' # Combined: specific user on specific API + PER_USER_ENDPOINT = 'per_user_endpoint' # Combined: specific user on specific endpoint + GLOBAL = 'global' # Global rate limit for all requests class TimeWindow(Enum): """Time windows for rate limiting""" - SECOND = "second" - MINUTE = "minute" - HOUR = "hour" - DAY = "day" - MONTH = "month" + + SECOND = 'second' + MINUTE = 'minute' + HOUR = 'hour' + DAY = 'day' + MONTH = 'month' class TierName(Enum): """Predefined tier names""" - FREE = "free" - PRO = "pro" - ENTERPRISE = "enterprise" - CUSTOM = "custom" + + FREE = 'free' + PRO = 'pro' + ENTERPRISE = 'enterprise' + CUSTOM = 'custom' class QuotaType(Enum): """Types of quotas""" - REQUESTS = "requests" - BANDWIDTH = "bandwidth" - COMPUTE_TIME = "compute_time" + + REQUESTS = 'requests' + BANDWIDTH = 'bandwidth' + COMPUTE_TIME = 'compute_time' # ============================================================================ # RATE LIMIT RULE MODELS # ============================================================================ + @dataclass class RateLimitRule: """ Defines a rate limiting rule - + Examples: # Per-user rule: 100 requests per minute RateLimitRule( @@ -71,7 +76,7 @@ class RateLimitRule: limit=100, burst_allowance=20 ) - + # Per-API rule: 1000 requests per hour RateLimitRule( rule_id="rule_002", @@ -81,24 +86,25 @@ class RateLimitRule: limit=1000 ) """ + rule_id: str rule_type: RuleType time_window: TimeWindow limit: int # Maximum requests allowed in time window - + # Optional fields - target_identifier: Optional[str] = None # User ID, API name, endpoint URI, or IP + target_identifier: str | None = None # User ID, API name, endpoint URI, or IP burst_allowance: int = 0 # Additional requests allowed for bursts priority: int = 0 # Higher priority rules are checked first enabled: bool = True - + # Metadata - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - created_by: Optional[str] = None - description: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: + created_at: datetime | None = None + updated_at: datetime | None = None + created_by: str | None = None + description: str | None = None + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary for storage""" return { 'rule_id': self.rule_id, @@ -112,11 +118,11 @@ class RateLimitRule: 'created_at': self.created_at.isoformat() if self.created_at else None, 'updated_at': self.updated_at.isoformat() if self.updated_at else None, 'created_by': self.created_by, - 'description': self.description + 'description': self.description, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'RateLimitRule': + def from_dict(cls, data: dict[str, Any]) -> 'RateLimitRule': """Create from dictionary""" return cls( rule_id=data['rule_id'], @@ -127,10 +133,14 @@ class RateLimitRule: burst_allowance=data.get('burst_allowance', 0), priority=data.get('priority', 0), enabled=data.get('enabled', True), - created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None, - updated_at=datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else None, + created_at=datetime.fromisoformat(data['created_at']) + if data.get('created_at') + else None, + updated_at=datetime.fromisoformat(data['updated_at']) + if data.get('updated_at') + else None, created_by=data.get('created_by'), - description=data.get('description') + description=data.get('description'), ) @@ -138,31 +148,33 @@ class RateLimitRule: # TIER/PLAN MODELS # ============================================================================ + @dataclass class TierLimits: """Rate limits and quotas for a specific tier""" + # Rate limits (requests per time window) - requests_per_second: Optional[int] = None - requests_per_minute: Optional[int] = None - requests_per_hour: Optional[int] = None - requests_per_day: Optional[int] = None - requests_per_month: Optional[int] = None - + requests_per_second: int | None = None + requests_per_minute: int | None = None + requests_per_hour: int | None = None + requests_per_day: int | None = None + requests_per_month: int | None = None + # Burst allowances burst_per_second: int = 0 burst_per_minute: int = 0 burst_per_hour: int = 0 - + # Quotas - monthly_request_quota: Optional[int] = None - daily_request_quota: Optional[int] = None - monthly_bandwidth_quota: Optional[int] = None # In bytes - + monthly_request_quota: int | None = None + daily_request_quota: int | None = None + monthly_bandwidth_quota: int | None = None # In bytes + # Throttling configuration enable_throttling: bool = False # If true, queue/delay requests; if false, hard reject (429) max_queue_time_ms: int = 5000 # Maximum time to queue a request before rejecting (milliseconds) - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary""" return { 'requests_per_second': self.requests_per_second, @@ -177,11 +189,11 @@ class TierLimits: 'daily_request_quota': self.daily_request_quota, 'monthly_bandwidth_quota': self.monthly_bandwidth_quota, 'enable_throttling': self.enable_throttling, - 'max_queue_time_ms': self.max_queue_time_ms + 'max_queue_time_ms': self.max_queue_time_ms, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'TierLimits': + def from_dict(cls, data: dict[str, Any]) -> 'TierLimits': """Create from dictionary""" return cls(**data) @@ -190,7 +202,7 @@ class TierLimits: class Tier: """ Defines a tier/plan with associated rate limits and quotas - + Examples: # Free tier Tier( @@ -203,7 +215,7 @@ class Tier: daily_request_quota=10000 ) ) - + # Pro tier Tier( tier_id="tier_pro", @@ -218,24 +230,25 @@ class Tier: price_monthly=49.99 ) """ + tier_id: str name: TierName display_name: str limits: TierLimits - + # Optional fields - description: Optional[str] = None - price_monthly: Optional[float] = None - price_yearly: Optional[float] = None - features: List[str] = field(default_factory=list) + description: str | None = None + price_monthly: float | None = None + price_yearly: float | None = None + features: list[str] = field(default_factory=list) is_default: bool = False enabled: bool = True - + # Metadata - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - def to_dict(self) -> Dict[str, Any]: + created_at: datetime | None = None + updated_at: datetime | None = None + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary for storage""" return { 'tier_id': self.tier_id, @@ -249,11 +262,11 @@ class Tier: 'is_default': self.is_default, 'enabled': self.enabled, 'created_at': self.created_at.isoformat() if self.created_at else None, - 'updated_at': self.updated_at.isoformat() if self.updated_at else None + 'updated_at': self.updated_at.isoformat() if self.updated_at else None, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Tier': + def from_dict(cls, data: dict[str, Any]) -> 'Tier': """Create from dictionary""" return cls( tier_id=data['tier_id'], @@ -266,30 +279,35 @@ class Tier: features=data.get('features', []), is_default=data.get('is_default', False), enabled=data.get('enabled', True), - created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None, - updated_at=datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else None + created_at=datetime.fromisoformat(data['created_at']) + if data.get('created_at') + else None, + updated_at=datetime.fromisoformat(data['updated_at']) + if data.get('updated_at') + else None, ) @dataclass class UserTierAssignment: """Assigns a user to a tier with optional overrides""" + user_id: str tier_id: str - + # Optional overrides (override tier defaults for this specific user) - override_limits: Optional[TierLimits] = None - + override_limits: TierLimits | None = None + # Scheduling - effective_from: Optional[datetime] = None - effective_until: Optional[datetime] = None - + effective_from: datetime | None = None + effective_until: datetime | None = None + # Metadata - assigned_at: Optional[datetime] = None - assigned_by: Optional[str] = None - notes: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: + assigned_at: datetime | None = None + assigned_by: str | None = None + notes: str | None = None + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary""" return { 'user_id': self.user_id, @@ -299,7 +317,7 @@ class UserTierAssignment: 'effective_until': self.effective_until.isoformat() if self.effective_until else None, 'assigned_at': self.assigned_at.isoformat() if self.assigned_at else None, 'assigned_by': self.assigned_by, - 'notes': self.notes + 'notes': self.notes, } @@ -307,41 +325,43 @@ class UserTierAssignment: # QUOTA TRACKING MODELS # ============================================================================ + @dataclass class QuotaUsage: """ Tracks current quota usage for a user/API/endpoint - + This is stored in Redis for real-time tracking """ + key: str # Redis key (e.g., "quota:user:john_doe:month:2025-12") quota_type: QuotaType current_usage: int limit: int reset_at: datetime # When the quota resets - + # Optional fields burst_usage: int = 0 # Burst tokens used burst_limit: int = 0 - + @property def remaining(self) -> int: """Calculate remaining quota""" return max(0, self.limit - self.current_usage) - + @property def percentage_used(self) -> float: """Calculate percentage of quota used""" if self.limit == 0: return 0.0 return (self.current_usage / self.limit) * 100 - + @property def is_exhausted(self) -> bool: """Check if quota is exhausted""" return self.current_usage >= self.limit - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary""" return { 'key': self.key, @@ -353,7 +373,7 @@ class QuotaUsage: 'reset_at': self.reset_at.isoformat(), 'burst_usage': self.burst_usage, 'burst_limit': self.burst_limit, - 'is_exhausted': self.is_exhausted + 'is_exhausted': self.is_exhausted, } @@ -361,9 +381,10 @@ class QuotaUsage: class RateLimitCounter: """ Real-time counter for rate limiting (stored in Redis) - + Uses sliding window counter algorithm """ + key: str # Redis key (e.g., "ratelimit:user:john_doe:minute:1701504000") window_start: int # Unix timestamp window_size: int # Window size in seconds @@ -371,23 +392,23 @@ class RateLimitCounter: limit: int # Maximum allowed requests burst_count: int = 0 # Burst tokens used burst_limit: int = 0 - + @property def remaining(self) -> int: """Calculate remaining requests""" return max(0, self.limit - self.count) - + @property def is_limited(self) -> bool: """Check if rate limit is exceeded""" return self.count >= self.limit - + @property def reset_at(self) -> int: """Calculate when the window resets (Unix timestamp)""" return self.window_start + self.window_size - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary""" return { 'key': self.key, @@ -399,7 +420,7 @@ class RateLimitCounter: 'reset_at': self.reset_at, 'burst_count': self.burst_count, 'burst_limit': self.burst_limit, - 'is_limited': self.is_limited + 'is_limited': self.is_limited, } @@ -407,28 +428,30 @@ class RateLimitCounter: # HISTORICAL TRACKING MODELS # ============================================================================ + @dataclass class UsageHistoryRecord: """ Historical usage record for analytics - + Stored in time-series database or MongoDB """ + timestamp: datetime - user_id: Optional[str] = None - api_name: Optional[str] = None - endpoint_uri: Optional[str] = None - ip_address: Optional[str] = None - + user_id: str | None = None + api_name: str | None = None + endpoint_uri: str | None = None + ip_address: str | None = None + # Metrics request_count: int = 0 blocked_count: int = 0 # Requests blocked by rate limit burst_used: int = 0 - + # Aggregation period - period: str = "minute" # minute, hour, day - - def to_dict(self) -> Dict[str, Any]: + period: str = 'minute' # minute, hour, day + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary""" return { 'timestamp': self.timestamp.isoformat(), @@ -439,7 +462,7 @@ class UsageHistoryRecord: 'request_count': self.request_count, 'blocked_count': self.blocked_count, 'burst_used': self.burst_used, - 'period': self.period + 'period': self.period, } @@ -447,42 +470,44 @@ class UsageHistoryRecord: # RESPONSE MODELS # ============================================================================ + @dataclass class RateLimitInfo: """ Information about current rate limit status - + Returned in API responses and headers """ + limit: int remaining: int reset_at: int # Unix timestamp - retry_after: Optional[int] = None # Seconds until retry (when limited) - + retry_after: int | None = None # Seconds until retry (when limited) + # Additional info burst_limit: int = 0 burst_remaining: int = 0 - tier: Optional[str] = None - - def to_headers(self) -> Dict[str, str]: + tier: str | None = None + + def to_headers(self) -> dict[str, str]: """Convert to HTTP headers""" headers = { 'X-RateLimit-Limit': str(self.limit), 'X-RateLimit-Remaining': str(self.remaining), - 'X-RateLimit-Reset': str(self.reset_at) + 'X-RateLimit-Reset': str(self.reset_at), } - + if self.retry_after is not None: headers['X-RateLimit-Retry-After'] = str(self.retry_after) headers['Retry-After'] = str(self.retry_after) - + if self.burst_limit > 0: headers['X-RateLimit-Burst-Limit'] = str(self.burst_limit) headers['X-RateLimit-Burst-Remaining'] = str(self.burst_remaining) - + return headers - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: """Convert to dictionary""" return { 'limit': self.limit, @@ -491,7 +516,7 @@ class RateLimitInfo: 'retry_after': self.retry_after, 'burst_limit': self.burst_limit, 'burst_remaining': self.burst_remaining, - 'tier': self.tier + 'tier': self.tier, } @@ -499,6 +524,7 @@ class RateLimitInfo: # HELPER FUNCTIONS # ============================================================================ + def get_time_window_seconds(window: TimeWindow) -> int: """Convert time window enum to seconds""" mapping = { @@ -506,39 +532,32 @@ def get_time_window_seconds(window: TimeWindow) -> int: TimeWindow.MINUTE: 60, TimeWindow.HOUR: 3600, TimeWindow.DAY: 86400, - TimeWindow.MONTH: 2592000 # 30 days + TimeWindow.MONTH: 2592000, # 30 days } return mapping[window] def generate_redis_key( - rule_type: RuleType, - identifier: str, - window: TimeWindow, - window_start: int + rule_type: RuleType, identifier: str, window: TimeWindow, window_start: int ) -> str: """ Generate Redis key for rate limit counter - + Examples: generate_redis_key(RuleType.PER_USER, "john_doe", TimeWindow.MINUTE, 1701504000) # Returns: "ratelimit:user:john_doe:minute:1701504000" """ type_prefix = rule_type.value.replace('per_', '') window_name = window.value - return f"ratelimit:{type_prefix}:{identifier}:{window_name}:{window_start}" + return f'ratelimit:{type_prefix}:{identifier}:{window_name}:{window_start}' -def generate_quota_key( - user_id: str, - quota_type: QuotaType, - period: str -) -> str: +def generate_quota_key(user_id: str, quota_type: QuotaType, period: str) -> str: """ Generate Redis key for quota tracking - + Examples: generate_quota_key("john_doe", QuotaType.REQUESTS, "2025-12") # Returns: "quota:user:john_doe:requests:month:2025-12" """ - return f"quota:user:{user_id}:{quota_type.value}:month:{period}" + return f'quota:user:{user_id}:{quota_type.value}:month:{period}' diff --git a/backend-services/models/request_model.py b/backend-services/models/request_model.py index 917d6eb..35cf0d2 100644 --- a/backend-services/models/request_model.py +++ b/backend-services/models/request_model.py @@ -1,10 +1,10 @@ from pydantic import BaseModel -from typing import Dict, Optional + class RequestModel(BaseModel): method: str path: str - headers: Dict[str, str] - query_params: Dict[str, str] - identity: Optional[str] = None - body: Optional[str] = None \ No newline at end of file + headers: dict[str, str] + query_params: dict[str, str] + identity: str | None = None + body: str | None = None diff --git a/backend-services/models/response_model.py b/backend-services/models/response_model.py index 8cfeb1c..8761b24 100644 --- a/backend-services/models/response_model.py +++ b/backend-services/models/response_model.py @@ -1,13 +1,13 @@ from pydantic import BaseModel, Field -from typing import Optional, Union + class ResponseModel(BaseModel): status_code: int = Field(None) - response_headers: Optional[dict] = Field(None) + response_headers: dict | None = Field(None) - response: Optional[Union[dict, list, str]] = Field(None) - message: Optional[str] = Field(None, min_length=1, max_length=255) + response: dict | list | str | None = Field(None) + message: str | None = Field(None, min_length=1, max_length=255) - error_code: Optional[str] = Field(None, min_length=1, max_length=255) - error_message: Optional[str] = Field(None, min_length=1, max_length=255) \ No newline at end of file + error_code: str | None = Field(None, min_length=1, max_length=255) + error_message: str | None = Field(None, min_length=1, max_length=255) diff --git a/backend-services/models/role_model_response.py b/backend-services/models/role_model_response.py index 1af719a..e01e976 100644 --- a/backend-services/models/role_model_response.py +++ b/backend-services/models/role_model_response.py @@ -5,28 +5,57 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional + class RoleModelResponse(BaseModel): - - role_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the role', example='admin') - role_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the role', example='Administrator role with full access') - manage_users: Optional[bool] = Field(None, description='Permission to manage users', example=True) - manage_apis: Optional[bool] = Field(None, description='Permission to manage APIs', example=True) - manage_endpoints: Optional[bool] = Field(None, description='Permission to manage endpoints', example=True) - manage_groups: Optional[bool] = Field(None, description='Permission to manage groups', example=True) - manage_roles: Optional[bool] = Field(None, description='Permission to manage roles', example=True) - manage_routings: Optional[bool] = Field(None, description='Permission to manage routings', example=True) - manage_gateway: Optional[bool] = Field(None, description='Permission to manage gateway', example=True) - manage_subscriptions: Optional[bool] = Field(None, description='Permission to manage subscriptions', example=True) - manage_security: Optional[bool] = Field(None, description='Permission to manage security settings', example=True) - manage_tiers: Optional[bool] = Field(None, description='Permission to manage pricing tiers', example=True) - manage_rate_limits: Optional[bool] = Field(None, description='Permission to manage rate limiting rules', example=True) - manage_credits: Optional[bool] = Field(None, description='Permission to manage API credits', example=True) - manage_auth: Optional[bool] = Field(None, description='Permission to manage auth (revoke tokens/disable users)', example=True) - view_analytics: Optional[bool] = Field(None, description='Permission to view analytics dashboard', example=True) - view_logs: Optional[bool] = Field(None, description='Permission to view logs', example=True) - export_logs: Optional[bool] = Field(None, description='Permission to export logs', example=True) + role_name: str | None = Field( + None, min_length=1, max_length=50, description='Name of the role', example='admin' + ) + role_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the role', + example='Administrator role with full access', + ) + manage_users: bool | None = Field(None, description='Permission to manage users', example=True) + manage_apis: bool | None = Field(None, description='Permission to manage APIs', example=True) + manage_endpoints: bool | None = Field( + None, description='Permission to manage endpoints', example=True + ) + manage_groups: bool | None = Field( + None, description='Permission to manage groups', example=True + ) + manage_roles: bool | None = Field(None, description='Permission to manage roles', example=True) + manage_routings: bool | None = Field( + None, description='Permission to manage routings', example=True + ) + manage_gateway: bool | None = Field( + None, description='Permission to manage gateway', example=True + ) + manage_subscriptions: bool | None = Field( + None, description='Permission to manage subscriptions', example=True + ) + manage_security: bool | None = Field( + None, description='Permission to manage security settings', example=True + ) + manage_tiers: bool | None = Field( + None, description='Permission to manage pricing tiers', example=True + ) + manage_rate_limits: bool | None = Field( + None, description='Permission to manage rate limiting rules', example=True + ) + manage_credits: bool | None = Field( + None, description='Permission to manage API credits', example=True + ) + manage_auth: bool | None = Field( + None, description='Permission to manage auth (revoke tokens/disable users)', example=True + ) + view_analytics: bool | None = Field( + None, description='Permission to view analytics dashboard', example=True + ) + view_logs: bool | None = Field(None, description='Permission to view logs', example=True) + export_logs: bool | None = Field(None, description='Permission to export logs', example=True) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/routing_model_response.py b/backend-services/models/routing_model_response.py index 55a8cb8..e3f6c8e 100644 --- a/backend-services/models/routing_model_response.py +++ b/backend-services/models/routing_model_response.py @@ -5,16 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional + class RoutingModelResponse(BaseModel): + routing_name: str | None = Field( + None, + min_length=1, + max_length=50, + description='Name of the routing', + example='customer-routing', + ) + routing_servers: list[str] | None = Field( + None, + min_items=1, + description='List of backend servers for the routing', + example=['http://localhost:8080', 'http://localhost:8081'], + ) + routing_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the routing', + example='Routing for customer API', + ) - routing_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the routing', example='customer-routing') - routing_servers : Optional[list[str]] = Field(None, min_items=1, description='List of backend servers for the routing', example=['http://localhost:8080', 'http://localhost:8081']) - routing_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the routing', example='Routing for customer API') - - client_key: Optional[str] = Field(None, min_length=1, max_length=50, description='Client key for the routing', example='client-1') - server_index: Optional[int] = Field(None, exclude=True, ge=0, description='Index of the server to route to', example=0) + client_key: str | None = Field( + None, + min_length=1, + max_length=50, + description='Client key for the routing', + example='client-1', + ) + server_index: int | None = Field( + None, exclude=True, ge=0, description='Index of the server to route to', example=0 + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/security_settings_model.py b/backend-services/models/security_settings_model.py index 358eaca..8a30d54 100644 --- a/backend-services/models/security_settings_model.py +++ b/backend-services/models/security_settings_model.py @@ -1,12 +1,24 @@ from pydantic import BaseModel, Field -from typing import Optional, List + class SecuritySettingsModel(BaseModel): - enable_auto_save: Optional[bool] = Field(default=None) - auto_save_frequency_seconds: Optional[int] = Field(default=None, ge=60, description='How often to auto-save memory dump (seconds)') - dump_path: Optional[str] = Field(default=None, description='Path to write encrypted memory dumps') - ip_whitelist: Optional[List[str]] = Field(default=None, description='List of allowed IPs/CIDRs. If non-empty, only these are allowed.') - ip_blacklist: Optional[List[str]] = Field(default=None, description='List of blocked IPs/CIDRs') - trust_x_forwarded_for: Optional[bool] = Field(default=None, description='If true, use X-Forwarded-For header for client IP') - xff_trusted_proxies: Optional[List[str]] = Field(default=None, description='IPs/CIDRs of proxies allowed to set client IP headers (XFF/X-Real-IP). Empty means trust all when enabled.') - allow_localhost_bypass: Optional[bool] = Field(default=None, description='Allow direct localhost (::1/127.0.0.1) to bypass IP allow/deny lists when no forwarding headers are present') + enable_auto_save: bool | None = Field(default=None) + auto_save_frequency_seconds: int | None = Field( + default=None, ge=60, description='How often to auto-save memory dump (seconds)' + ) + dump_path: str | None = Field(default=None, description='Path to write encrypted memory dumps') + ip_whitelist: list[str] | None = Field( + default=None, description='List of allowed IPs/CIDRs. If non-empty, only these are allowed.' + ) + ip_blacklist: list[str] | None = Field(default=None, description='List of blocked IPs/CIDRs') + trust_x_forwarded_for: bool | None = Field( + default=None, description='If true, use X-Forwarded-For header for client IP' + ) + xff_trusted_proxies: list[str] | None = Field( + default=None, + description='IPs/CIDRs of proxies allowed to set client IP headers (XFF/X-Real-IP). Empty means trust all when enabled.', + ) + allow_localhost_bypass: bool | None = Field( + default=None, + description='Allow direct localhost (::1/127.0.0.1) to bypass IP allow/deny lists when no forwarding headers are present', + ) diff --git a/backend-services/models/subscribe_model.py b/backend-services/models/subscribe_model.py index f110262..3247077 100644 --- a/backend-services/models/subscribe_model.py +++ b/backend-services/models/subscribe_model.py @@ -6,11 +6,29 @@ See https://github.com/pypeople-dev/doorman for more information from pydantic import BaseModel, Field -class SubscribeModel(BaseModel): - username: str = Field(..., min_length=3, max_length=50, description='Username of the subscriber', example='client-1') - api_name: str = Field(..., min_length=3, max_length=50, description='Name of the API to subscribe to', example='customer') - api_version: str = Field(..., min_length=1, max_length=5, description='Version of the API to subscribe to', example='v1') +class SubscribeModel(BaseModel): + username: str = Field( + ..., + min_length=3, + max_length=50, + description='Username of the subscriber', + example='client-1', + ) + api_name: str = Field( + ..., + min_length=3, + max_length=50, + description='Name of the API to subscribe to', + example='customer', + ) + api_version: str = Field( + ..., + min_length=1, + max_length=5, + description='Version of the API to subscribe to', + example='v1', + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/update_api_model.py b/backend-services/models/update_api_model.py index 1a64859..8d182c6 100644 --- a/backend-services/models/update_api_model.py +++ b/backend-services/models/update_api_model.py @@ -5,44 +5,120 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class UpdateApiModel(BaseModel): + api_name: str | None = Field( + None, min_length=1, max_length=25, description='Name of the API', example='customer' + ) + api_version: str | None = Field( + None, min_length=1, max_length=8, description='Version of the API', example='v1' + ) + api_description: str | None = Field( + None, + min_length=1, + max_length=127, + description='Description of the API', + example='New customer onboarding API', + ) + api_allowed_roles: list[str] | None = Field( + None, description='Allowed user roles for the API', example=['admin', 'user'] + ) + api_allowed_groups: list[str] | None = Field( + None, description='Allowed user groups for the API', example=['admin', 'client-1-group'] + ) + api_servers: list[str] | None = Field( + None, + description='List of backend servers for the API', + example=['http://localhost:8080', 'http://localhost:8081'], + ) + api_type: str | None = Field( + None, description="Type of the API. Valid values: 'REST'", example='REST' + ) + api_authorization_field_swap: str | None = Field( + None, + description='Header to swap for backend authorization header', + example='backend-auth-header', + ) + api_allowed_headers: list[str] | None = Field( + None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'] + ) + api_allowed_retry_count: int | None = Field( + None, description='Number of allowed retries for the API', example=0 + ) + api_grpc_package: str | None = Field( + None, + description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', + example='my.pkg', + ) + api_grpc_allowed_packages: list[str] | None = Field( + None, + description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.', + example=['customer_v1'], + ) + api_grpc_allowed_services: list[str] | None = Field( + None, + description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.', + example=['Greeter'], + ) + api_grpc_allowed_methods: list[str] | None = Field( + None, + description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.', + example=['Greeter.SayHello'], + ) + api_credits_enabled: bool | None = Field( + False, description='Enable credit-based authentication for the API', example=True + ) + api_credit_group: str | None = Field( + None, description='API credit group for the API credits', example='ai-group-1' + ) + active: bool | None = Field(None, description='Whether the API is active (enabled)') + api_id: str | None = Field( + None, description='Unique identifier for the API, auto-generated', example=None + ) + api_path: str | None = Field( + None, description='Unqiue path for the API, auto-generated', example=None + ) - api_name: Optional[str] = Field(None, min_length=1, max_length=25, description='Name of the API', example='customer') - api_version: Optional[str] = Field(None, min_length=1, max_length=8, description='Version of the API', example='v1') - api_description: Optional[str] = Field(None, min_length=1, max_length=127, description='Description of the API', example='New customer onboarding API') - api_allowed_roles: Optional[List[str]] = Field(None, description='Allowed user roles for the API', example=['admin', 'user']) - api_allowed_groups: Optional[List[str]] = Field(None, description='Allowed user groups for the API' , example=['admin', 'client-1-group']) - api_servers: Optional[List[str]] = Field(None, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081']) - api_type: Optional[str] = Field(None, description="Type of the API. Valid values: 'REST'", example='REST') - api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header') - api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization']) - api_allowed_retry_count: Optional[int] = Field(None, description='Number of allowed retries for the API', example=0) - api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg') - api_grpc_allowed_packages: Optional[List[str]] = Field(None, description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.', example=['customer_v1']) - api_grpc_allowed_services: Optional[List[str]] = Field(None, description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.', example=['Greeter']) - api_grpc_allowed_methods: Optional[List[str]] = Field(None, description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.', example=['Greeter.SayHello']) - api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True) - api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1') - active: Optional[bool] = Field(None, description='Whether the API is active (enabled)') - api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example=None) - api_path: Optional[str] = Field(None, description='Unqiue path for the API, auto-generated', example=None) + api_cors_allow_origins: list[str] | None = Field( + None, + description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.", + ) + api_cors_allow_methods: list[str] | None = Field( + None, + description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])", + ) + api_cors_allow_headers: list[str] | None = Field( + None, + description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])", + ) + api_cors_allow_credentials: bool | None = Field( + None, description='Whether to include Access-Control-Allow-Credentials=true in responses' + ) + api_cors_expose_headers: list[str] | None = Field( + None, + description='Response headers to expose to the browser via Access-Control-Expose-Headers', + ) - api_cors_allow_origins: Optional[List[str]] = Field(None, description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.") - api_cors_allow_methods: Optional[List[str]] = Field(None, description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])") - api_cors_allow_headers: Optional[List[str]] = Field(None, description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])") - api_cors_allow_credentials: Optional[bool] = Field(None, description='Whether to include Access-Control-Allow-Credentials=true in responses') - api_cors_expose_headers: Optional[List[str]] = Field(None, description='Response headers to expose to the browser via Access-Control-Expose-Headers') + api_public: bool | None = Field( + None, description='If true, this API can be called without authentication or subscription' + ) - api_public: Optional[bool] = Field(None, description='If true, this API can be called without authentication or subscription') + api_auth_required: bool | None = Field( + None, + description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.', + ) - api_auth_required: Optional[bool] = Field(None, description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.') - - api_ip_mode: Optional[str] = Field(None, description="IP policy mode: 'allow_all' or 'whitelist'") - api_ip_whitelist: Optional[List[str]] = Field(None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist') - api_ip_blacklist: Optional[List[str]] = Field(None, description='IPs/CIDRs denied regardless of mode') - api_trust_x_forwarded_for: Optional[bool] = Field(None, description='Override: trust X-Forwarded-For for this API') + api_ip_mode: str | None = Field(None, description="IP policy mode: 'allow_all' or 'whitelist'") + api_ip_whitelist: list[str] | None = Field( + None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist' + ) + api_ip_blacklist: list[str] | None = Field( + None, description='IPs/CIDRs denied regardless of mode' + ) + api_trust_x_forwarded_for: bool | None = Field( + None, description='Override: trust X-Forwarded-For for this API' + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/update_endpoint_model.py b/backend-services/models/update_endpoint_model.py index ce48867..c559e2b 100644 --- a/backend-services/models/update_endpoint_model.py +++ b/backend-services/models/update_endpoint_model.py @@ -5,18 +5,47 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional, List + class UpdateEndpointModel(BaseModel): - - api_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the API', example='customer') - api_version: Optional[str] = Field(None, min_length=1, max_length=10, description='Version of the API', example='v1') - endpoint_method: Optional[str] = Field(None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET') - endpoint_uri: Optional[str] = Field(None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer') - endpoint_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the endpoint', example='Get customer details') - endpoint_servers: Optional[List[str]] = Field(None, description='Optional list of backend servers for this endpoint (overrides API servers)', example=['http://localhost:8082', 'http://localhost:8083']) - api_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the API, auto-generated', example=None) - endpoint_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the endpoint, auto-generated', example=None) + api_name: str | None = Field( + None, min_length=1, max_length=50, description='Name of the API', example='customer' + ) + api_version: str | None = Field( + None, min_length=1, max_length=10, description='Version of the API', example='v1' + ) + endpoint_method: str | None = Field( + None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET' + ) + endpoint_uri: str | None = Field( + None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer' + ) + endpoint_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the endpoint', + example='Get customer details', + ) + endpoint_servers: list[str] | None = Field( + None, + description='Optional list of backend servers for this endpoint (overrides API servers)', + example=['http://localhost:8082', 'http://localhost:8083'], + ) + api_id: str | None = Field( + None, + min_length=1, + max_length=255, + description='Unique identifier for the API, auto-generated', + example=None, + ) + endpoint_id: str | None = Field( + None, + min_length=1, + max_length=255, + description='Unique identifier for the endpoint, auto-generated', + example=None, + ) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/update_endpoint_validation_model.py b/backend-services/models/update_endpoint_validation_model.py index def3c12..08d97a1 100644 --- a/backend-services/models/update_endpoint_validation_model.py +++ b/backend-services/models/update_endpoint_validation_model.py @@ -8,10 +8,14 @@ from pydantic import BaseModel, Field from models.validation_schema_model import ValidationSchema -class UpdateEndpointValidationModel(BaseModel): - validation_enabled: bool = Field(..., description='Whether the validation is enabled', example=True) - validation_schema: ValidationSchema = Field(..., description='The schema to validate the endpoint against', example={}) +class UpdateEndpointValidationModel(BaseModel): + validation_enabled: bool = Field( + ..., description='Whether the validation is enabled', example=True + ) + validation_schema: ValidationSchema = Field( + ..., description='The schema to validate the endpoint against', example={} + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/update_group_model.py b/backend-services/models/update_group_model.py index 7552361..1b5d8c2 100644 --- a/backend-services/models/update_group_model.py +++ b/backend-services/models/update_group_model.py @@ -5,13 +5,22 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class UpdateGroupModel(BaseModel): - - group_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the group', example='client-1-group') - group_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the group', example='Group for client 1') - api_access: Optional[List[str]] = Field(None, description='List of APIs the group can access', example=['customer/v1']) + group_name: str | None = Field( + None, min_length=1, max_length=50, description='Name of the group', example='client-1-group' + ) + group_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the group', + example='Group for client 1', + ) + api_access: list[str] | None = Field( + None, description='List of APIs the group can access', example=['customer/v1'] + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/update_password_model.py b/backend-services/models/update_password_model.py index 12245c8..c280486 100644 --- a/backend-services/models/update_password_model.py +++ b/backend-services/models/update_password_model.py @@ -6,9 +6,15 @@ See https://github.com/pypeople-dev/doorman for more information from pydantic import BaseModel, Field -class UpdatePasswordModel(BaseModel): - new_password: str = Field(..., min_length=6, max_length=36, description='New password of the user', example='NewPassword456!') +class UpdatePasswordModel(BaseModel): + new_password: str = Field( + ..., + min_length=6, + max_length=36, + description='New password of the user', + example='NewPassword456!', + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/update_role_model.py b/backend-services/models/update_role_model.py index 4d869ea..84ef80a 100644 --- a/backend-services/models/update_role_model.py +++ b/backend-services/models/update_role_model.py @@ -5,28 +5,57 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional + class UpdateRoleModel(BaseModel): - - role_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the role', example='admin') - role_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the role', example='Administrator role with full access') - manage_users: Optional[bool] = Field(None, description='Permission to manage users', example=True) - manage_apis: Optional[bool] = Field(None, description='Permission to manage APIs', example=True) - manage_endpoints: Optional[bool] = Field(None, description='Permission to manage endpoints', example=True) - manage_groups: Optional[bool] = Field(None, description='Permission to manage groups', example=True) - manage_roles: Optional[bool] = Field(None, description='Permission to manage roles', example=True) - manage_routings: Optional[bool] = Field(None, description='Permission to manage routings', example=True) - manage_gateway: Optional[bool] = Field(None, description='Permission to manage gateway', example=True) - manage_subscriptions: Optional[bool] = Field(None, description='Permission to manage subscriptions', example=True) - manage_security: Optional[bool] = Field(None, description='Permission to manage security settings', example=True) - manage_tiers: Optional[bool] = Field(None, description='Permission to manage pricing tiers', example=True) - manage_rate_limits: Optional[bool] = Field(None, description='Permission to manage rate limiting rules', example=True) - manage_credits: Optional[bool] = Field(None, description='Permission to manage credits', example=True) - manage_auth: Optional[bool] = Field(None, description='Permission to manage auth (revoke tokens/disable users)', example=True) - view_analytics: Optional[bool] = Field(None, description='Permission to view analytics dashboard', example=True) - view_logs: Optional[bool] = Field(None, description='Permission to view logs', example=True) - export_logs: Optional[bool] = Field(None, description='Permission to export logs', example=True) + role_name: str | None = Field( + None, min_length=1, max_length=50, description='Name of the role', example='admin' + ) + role_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the role', + example='Administrator role with full access', + ) + manage_users: bool | None = Field(None, description='Permission to manage users', example=True) + manage_apis: bool | None = Field(None, description='Permission to manage APIs', example=True) + manage_endpoints: bool | None = Field( + None, description='Permission to manage endpoints', example=True + ) + manage_groups: bool | None = Field( + None, description='Permission to manage groups', example=True + ) + manage_roles: bool | None = Field(None, description='Permission to manage roles', example=True) + manage_routings: bool | None = Field( + None, description='Permission to manage routings', example=True + ) + manage_gateway: bool | None = Field( + None, description='Permission to manage gateway', example=True + ) + manage_subscriptions: bool | None = Field( + None, description='Permission to manage subscriptions', example=True + ) + manage_security: bool | None = Field( + None, description='Permission to manage security settings', example=True + ) + manage_tiers: bool | None = Field( + None, description='Permission to manage pricing tiers', example=True + ) + manage_rate_limits: bool | None = Field( + None, description='Permission to manage rate limiting rules', example=True + ) + manage_credits: bool | None = Field( + None, description='Permission to manage credits', example=True + ) + manage_auth: bool | None = Field( + None, description='Permission to manage auth (revoke tokens/disable users)', example=True + ) + view_analytics: bool | None = Field( + None, description='Permission to view analytics dashboard', example=True + ) + view_logs: bool | None = Field(None, description='Permission to view logs', example=True) + export_logs: bool | None = Field(None, description='Permission to export logs', example=True) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/update_routing_model.py b/backend-services/models/update_routing_model.py index 6e8824e..0042d1b 100644 --- a/backend-services/models/update_routing_model.py +++ b/backend-services/models/update_routing_model.py @@ -5,16 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional + class UpdateRoutingModel(BaseModel): + routing_name: str | None = Field( + None, + min_length=1, + max_length=50, + description='Name of the routing', + example='customer-routing', + ) + routing_servers: list[str] | None = Field( + None, + min_items=1, + description='List of backend servers for the routing', + example=['http://localhost:8080', 'http://localhost:8081'], + ) + routing_description: str | None = Field( + None, + min_length=1, + max_length=255, + description='Description of the routing', + example='Routing for customer API', + ) - routing_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the routing', example='customer-routing') - routing_servers : Optional[list[str]] = Field(None, min_items=1, description='List of backend servers for the routing', example=['http://localhost:8080', 'http://localhost:8081']) - routing_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the routing', example='Routing for customer API') - - client_key: Optional[str] = Field(None, min_length=1, max_length=50, description='Client key for the routing', example='client-1') - server_index: Optional[int] = Field(None, exclude=True, ge=0, description='Index of the server to route to', example=0) + client_key: str | None = Field( + None, + min_length=1, + max_length=50, + description='Client key for the routing', + example='client-1', + ) + server_index: int | None = Field( + None, exclude=True, ge=0, description='Index of the server to route to', example=0 + ) class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/backend-services/models/update_user_model.py b/backend-services/models/update_user_model.py index 6ab0e2b..c292da8 100644 --- a/backend-services/models/update_user_model.py +++ b/backend-services/models/update_user_model.py @@ -5,29 +5,90 @@ See https://github.com/pypeople-dev/doorman for more information """ from pydantic import BaseModel, Field -from typing import List, Optional + class UpdateUserModel(BaseModel): + username: str | None = Field( + None, min_length=3, max_length=50, description='Username of the user', example='john_doe' + ) + email: str | None = Field( + None, + min_length=3, + max_length=127, + description='Email of the user (no strict format validation)', + example='john@mail.com', + ) + password: str | None = Field( + None, + min_length=6, + max_length=50, + description='Password of the user', + example='SecurePassword@123', + ) + role: str | None = Field( + None, min_length=2, max_length=50, description='Role of the user', example='admin' + ) + groups: list[str] | None = Field( + None, description='List of groups the user belongs to', example=['client-1-group'] + ) + rate_limit_duration: int | None = Field( + None, ge=0, description='Rate limit for the user', example=100 + ) + rate_limit_duration_type: str | None = Field( + None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour' + ) + rate_limit_enabled: bool | None = Field( + None, description='Whether rate limiting is enabled for this user', example=True + ) + throttle_duration: int | None = Field( + None, ge=0, description='Throttle limit for the user', example=10 + ) + throttle_duration_type: str | None = Field( + None, + min_length=1, + max_length=7, + description='Duration for the throttle limit', + example='second', + ) + throttle_wait_duration: int | None = Field( + None, ge=0, description='Wait time for the throttle limit', example=5 + ) + throttle_wait_duration_type: str | None = Field( + None, + min_length=1, + max_length=7, + description='Wait duration for the throttle limit', + example='seconds', + ) + throttle_queue_limit: int | None = Field( + None, ge=0, description='Throttle queue limit for the user', example=10 + ) + throttle_enabled: bool | None = Field( + None, description='Whether throttling is enabled for this user', example=True + ) + custom_attributes: dict | None = Field( + None, description='Custom attributes for the user', example={'custom_key': 'custom_value'} + ) + bandwidth_limit_bytes: int | None = Field( + None, + ge=0, + description='Maximum bandwidth allowed within the window (bytes)', + example=1073741824, + ) + bandwidth_limit_window: str | None = Field( + None, + min_length=1, + max_length=10, + description='Bandwidth window unit (second/minute/hour/day/month)', + example='day', + ) + bandwidth_limit_enabled: bool | None = Field( + None, + description='Whether bandwidth limit enforcement is enabled for this user', + example=True, + ) + active: bool | None = Field(None, description='Active status of the user', example=True) + ui_access: bool | None = Field(None, description='UI access for the user', example=False) - username: Optional[str] = Field(None, min_length=3, max_length=50, description='Username of the user', example='john_doe') - email: Optional[str] = Field(None, min_length=3, max_length=127, description='Email of the user (no strict format validation)', example='john@mail.com') - password: Optional[str] = Field(None, min_length=6, max_length=50, description='Password of the user', example='SecurePassword@123') - role: Optional[str] = Field(None, min_length=2, max_length=50, description='Role of the user', example='admin') - groups: Optional[List[str]] = Field(None, description='List of groups the user belongs to', example=['client-1-group']) - rate_limit_duration: Optional[int] = Field(None, ge=0, description='Rate limit for the user', example=100) - rate_limit_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour') - rate_limit_enabled: Optional[bool] = Field(None, description='Whether rate limiting is enabled for this user', example=True) - throttle_duration: Optional[int] = Field(None, ge=0, description='Throttle limit for the user', example=10) - throttle_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the throttle limit', example='second') - throttle_wait_duration: Optional[int] = Field(None, ge=0, description='Wait time for the throttle limit', example=5) - throttle_wait_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Wait duration for the throttle limit', example='seconds') - throttle_queue_limit: Optional[int] = Field(None, ge=0, description='Throttle queue limit for the user', example=10) - throttle_enabled: Optional[bool] = Field(None, description='Whether throttling is enabled for this user', example=True) - custom_attributes: Optional[dict] = Field(None, description='Custom attributes for the user', example={'custom_key': 'custom_value'}) - bandwidth_limit_bytes: Optional[int] = Field(None, ge=0, description='Maximum bandwidth allowed within the window (bytes)', example=1073741824) - bandwidth_limit_window: Optional[str] = Field(None, min_length=1, max_length=10, description='Bandwidth window unit (second/minute/hour/day/month)', example='day') - bandwidth_limit_enabled: Optional[bool] = Field(None, description='Whether bandwidth limit enforcement is enabled for this user', example=True) - active: Optional[bool] = Field(None, description='Active status of the user', example=True) - ui_access: Optional[bool] = Field(None, description='UI access for the user', example=False) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/update_vault_entry_model.py b/backend-services/models/update_vault_entry_model.py index 355097a..698ad97 100644 --- a/backend-services/models/update_vault_entry_model.py +++ b/backend-services/models/update_vault_entry_model.py @@ -5,17 +5,16 @@ See https://github.com/apidoorman/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional class UpdateVaultEntryModel(BaseModel): """Model for updating a vault entry. Only description can be updated, not the value.""" - - description: Optional[str] = Field( + + description: str | None = Field( None, max_length=500, description='Updated description of what this key is used for', - example='Production API key for payment gateway - updated' + example='Production API key for payment gateway - updated', ) class Config: diff --git a/backend-services/models/user_credits_model.py b/backend-services/models/user_credits_model.py index 2b18353..7110a24 100644 --- a/backend-services/models/user_credits_model.py +++ b/backend-services/models/user_credits_model.py @@ -4,23 +4,39 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import Optional, Dict from pydantic import BaseModel, Field + class UserCreditInformationModel(BaseModel): - tier_name: str = Field(..., min_length=1, max_length=50, description='Name of the credit tier', example='basic') + tier_name: str = Field( + ..., min_length=1, max_length=50, description='Name of the credit tier', example='basic' + ) available_credits: int = Field(..., description='Number of available credits', example=50) - reset_date: Optional[str] = Field(None, description='Date when paid credits are reset', example='2023-10-01') - user_api_key: Optional[str] = Field(None, description='User specific API key for the credit tier', example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') + reset_date: str | None = Field( + None, description='Date when paid credits are reset', example='2023-10-01' + ) + user_api_key: str | None = Field( + None, + description='User specific API key for the credit tier', + example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + ) class Config: arbitrary_types_allowed = True + class UserCreditModel(BaseModel): - username: str = Field(..., min_length=3, max_length=50, description='Username of credits owner', example='client-1') - users_credits: Dict[str, UserCreditInformationModel] = Field(..., description='Credits information. Key is the credit group name') + username: str = Field( + ..., + min_length=3, + max_length=50, + description='Username of credits owner', + example='client-1', + ) + users_credits: dict[str, UserCreditInformationModel] = Field( + ..., description='Credits information. Key is the credit group name' + ) class Config: arbitrary_types_allowed = True - diff --git a/backend-services/models/user_model_response.py b/backend-services/models/user_model_response.py index 2018303..b438c3b 100644 --- a/backend-services/models/user_model_response.py +++ b/backend-services/models/user_model_response.py @@ -4,33 +4,96 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from pydantic import BaseModel, Field, EmailStr -from typing import List, Optional +from pydantic import BaseModel, EmailStr, Field + class UserModelResponse(BaseModel): - - username: Optional[str] = Field(None, min_length=3, max_length=50, description='Username of the user', example='john_doe') - email: Optional[EmailStr] = Field(None, description='Email of the user', example='john@mail.com') - password: Optional[str] = Field(None, min_length=6, max_length=50, description='Password of the user', example='SecurePassword@123') - role: Optional[str] = Field(None, min_length=2, max_length=50, description='Role of the user', example='admin') - groups: Optional[List[str]] = Field(None, description='List of groups the user belongs to', example=['client-1-group']) - rate_limit_duration: Optional[int] = Field(None, ge=0, description='Rate limit for the user', example=100) - rate_limit_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour') - rate_limit_enabled: Optional[bool] = Field(None, description='Whether rate limiting is enabled for this user', example=True) - throttle_duration: Optional[int] = Field(None, ge=0, description='Throttle limit for the user', example=10) - throttle_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the throttle limit', example='second') - throttle_wait_duration: Optional[int] = Field(None, ge=0, description='Wait time for the throttle limit', example=5) - throttle_wait_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Wait duration for the throttle limit', example='seconds') - throttle_queue_limit: Optional[int] = Field(None, ge=0, description='Throttle queue limit for the user', example=10) - throttle_enabled: Optional[bool] = Field(None, description='Whether throttling is enabled for this user', example=True) - custom_attributes: Optional[dict] = Field(None, description='Custom attributes for the user', example={'custom_key': 'custom_value'}) - bandwidth_limit_bytes: Optional[int] = Field(None, ge=0, description='Maximum bandwidth allowed within the window (bytes)', example=1073741824) - bandwidth_limit_window: Optional[str] = Field(None, min_length=1, max_length=10, description='Bandwidth window unit (second/minute/hour/day/month)', example='day') - bandwidth_usage_bytes: Optional[int] = Field(None, ge=0, description='Current bandwidth usage in the active window (bytes)', example=123456) - bandwidth_resets_at: Optional[int] = Field(None, description='UTC epoch seconds when the current bandwidth window resets', example=1727481600) - bandwidth_limit_enabled: Optional[bool] = Field(None, description='Whether bandwidth limit enforcement is enabled for this user', example=True) - active: Optional[bool] = Field(None, description='Active status of the user', example=True) - ui_access: Optional[bool] = Field(None, description='UI access for the user', example=False) + username: str | None = Field( + None, min_length=3, max_length=50, description='Username of the user', example='john_doe' + ) + email: EmailStr | None = Field(None, description='Email of the user', example='john@mail.com') + password: str | None = Field( + None, + min_length=6, + max_length=50, + description='Password of the user', + example='SecurePassword@123', + ) + role: str | None = Field( + None, min_length=2, max_length=50, description='Role of the user', example='admin' + ) + groups: list[str] | None = Field( + None, description='List of groups the user belongs to', example=['client-1-group'] + ) + rate_limit_duration: int | None = Field( + None, ge=0, description='Rate limit for the user', example=100 + ) + rate_limit_duration_type: str | None = Field( + None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour' + ) + rate_limit_enabled: bool | None = Field( + None, description='Whether rate limiting is enabled for this user', example=True + ) + throttle_duration: int | None = Field( + None, ge=0, description='Throttle limit for the user', example=10 + ) + throttle_duration_type: str | None = Field( + None, + min_length=1, + max_length=7, + description='Duration for the throttle limit', + example='second', + ) + throttle_wait_duration: int | None = Field( + None, ge=0, description='Wait time for the throttle limit', example=5 + ) + throttle_wait_duration_type: str | None = Field( + None, + min_length=1, + max_length=7, + description='Wait duration for the throttle limit', + example='seconds', + ) + throttle_queue_limit: int | None = Field( + None, ge=0, description='Throttle queue limit for the user', example=10 + ) + throttle_enabled: bool | None = Field( + None, description='Whether throttling is enabled for this user', example=True + ) + custom_attributes: dict | None = Field( + None, description='Custom attributes for the user', example={'custom_key': 'custom_value'} + ) + bandwidth_limit_bytes: int | None = Field( + None, + ge=0, + description='Maximum bandwidth allowed within the window (bytes)', + example=1073741824, + ) + bandwidth_limit_window: str | None = Field( + None, + min_length=1, + max_length=10, + description='Bandwidth window unit (second/minute/hour/day/month)', + example='day', + ) + bandwidth_usage_bytes: int | None = Field( + None, + ge=0, + description='Current bandwidth usage in the active window (bytes)', + example=123456, + ) + bandwidth_resets_at: int | None = Field( + None, + description='UTC epoch seconds when the current bandwidth window resets', + example=1727481600, + ) + bandwidth_limit_enabled: bool | None = Field( + None, + description='Whether bandwidth limit enforcement is enabled for this user', + example=True, + ) + active: bool | None = Field(None, description='Active status of the user', example=True) + ui_access: bool | None = Field(None, description='UI access for the user', example=False) class Config: arbitrary_types_allowed = True diff --git a/backend-services/models/validation_schema_model.py b/backend-services/models/validation_schema_model.py index 9ee60e7..a75c3b8 100644 --- a/backend-services/models/validation_schema_model.py +++ b/backend-services/models/validation_schema_model.py @@ -4,11 +4,11 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from typing import Dict from pydantic import BaseModel, Field from models.field_validation_model import FieldValidation + class ValidationSchema(BaseModel): """Validation schema for endpoint request/response validation. @@ -57,21 +57,12 @@ class ValidationSchema(BaseModel): } } """ - validation_schema: Dict[str, FieldValidation] = Field( + + validation_schema: dict[str, FieldValidation] = Field( ..., description='The schema to validate the endpoint against', example={ - 'user.name': { - 'required': True, - 'type': 'string', - 'min': 2, - 'max': 50 - }, - 'user.age': { - 'required': True, - 'type': 'number', - 'min': 0, - 'max': 120 - } - } - ) \ No newline at end of file + 'user.name': {'required': True, 'type': 'string', 'min': 2, 'max': 50}, + 'user.age': {'required': True, 'type': 'number', 'min': 0, 'max': 120}, + }, + ) diff --git a/backend-services/models/vault_entry_model_response.py b/backend-services/models/vault_entry_model_response.py index de6a96f..ae19313 100644 --- a/backend-services/models/vault_entry_model_response.py +++ b/backend-services/models/vault_entry_model_response.py @@ -5,40 +5,29 @@ See https://github.com/apidoorman/doorman for more information """ from pydantic import BaseModel, Field -from typing import Optional class VaultEntryModelResponse(BaseModel): """Response model for vault entry. Value is never returned.""" - - key_name: str = Field( - ..., - description='Name of the vault key', - example='api_key_production' - ) - - username: str = Field( - ..., - description='Username of the vault entry owner', - example='john_doe' - ) - - description: Optional[str] = Field( + + key_name: str = Field(..., description='Name of the vault key', example='api_key_production') + + username: str = Field(..., description='Username of the vault entry owner', example='john_doe') + + description: str | None = Field( None, description='Description of what this key is used for', - example='Production API key for payment gateway' + example='Production API key for payment gateway', ) - - created_at: Optional[str] = Field( - None, - description='Timestamp when the entry was created', - example='2024-11-22T10:15:30Z' + + created_at: str | None = Field( + None, description='Timestamp when the entry was created', example='2024-11-22T10:15:30Z' ) - - updated_at: Optional[str] = Field( + + updated_at: str | None = Field( None, description='Timestamp when the entry was last updated', - example='2024-11-22T10:15:30Z' + example='2024-11-22T10:15:30Z', ) class Config: diff --git a/backend-services/routes/analytics_routes.py b/backend-services/routes/analytics_routes.py index fcd91c3..d6bad7b 100644 --- a/backend-services/routes/analytics_routes.py +++ b/backend-services/routes/analytics_routes.py @@ -9,20 +9,17 @@ Provides comprehensive analytics endpoints for: - Endpoint performance analysis """ -from fastapi import APIRouter, Request, Query, HTTPException -from pydantic import BaseModel -from typing import Optional, List -import uuid -import time import logging -from datetime import datetime, timedelta +import time +import uuid + +from fastapi import APIRouter, Query, Request from models.response_model import ResponseModel -from utils.response_util import respond_rest, process_response from utils.auth_util import auth_required -from utils.role_util import platform_role_required_bool from utils.enhanced_metrics_util import enhanced_metrics_store -from utils.analytics_aggregator import analytics_aggregator +from utils.response_util import respond_rest +from utils.role_util import platform_role_required_bool analytics_router = APIRouter() logger = logging.getLogger('doorman.analytics') @@ -32,7 +29,9 @@ logger = logging.getLogger('doorman.analytics') # ENDPOINT 1: Dashboard Overview # ============================================================================ -@analytics_router.get('/analytics/overview', + +@analytics_router.get( + '/analytics/overview', description='Get dashboard overview statistics', response_model=ResponseModel, responses={ @@ -50,33 +49,33 @@ logger = logging.getLogger('doorman.analytics') 'p75': 180.0, 'p90': 250.0, 'p95': 300.0, - 'p99': 450.0 + 'p99': 450.0, }, 'unique_users': 150, 'total_bandwidth': 1073741824, 'top_apis': [ {'api': 'rest:customer', 'count': 5000}, - {'api': 'rest:orders', 'count': 3000} + {'api': 'rest:orders', 'count': 3000}, ], 'top_users': [ {'user': 'john_doe', 'count': 500}, - {'user': 'jane_smith', 'count': 300} - ] + {'user': 'jane_smith', 'count': 300}, + ], } } - } + }, } - } + }, ) async def get_analytics_overview( request: Request, - start_ts: Optional[int] = Query(None, description='Start timestamp (Unix seconds)'), - end_ts: Optional[int] = Query(None, description='End timestamp (Unix seconds)'), - range: Optional[str] = Query('24h', description='Time range (1h, 24h, 7d, 30d)') + start_ts: int | None = Query(None, description='Start timestamp (Unix seconds)'), + end_ts: int | None = Query(None, description='End timestamp (Unix seconds)'), + range: str | None = Query('24h', description='Time range (1h, 24h, 7d, 30d)'), ): """ Get dashboard overview statistics. - + Returns summary metrics including: - Total requests and errors - Error rate @@ -89,23 +88,27 @@ async def get_analytics_overview( """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: # Authentication and authorization payload = await auth_required(request) username = payload.get('sub') - - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + if not await platform_role_required_bool(username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: # Use provided timestamps @@ -113,24 +116,19 @@ async def get_analytics_overview( else: # Use range parameter end_ts = int(time.time()) - range_map = { - '1h': 3600, - '24h': 86400, - '7d': 604800, - '30d': 2592000 - } + range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get analytics snapshot snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts) - + # Build response overview = { 'time_range': { 'start_ts': start_ts, 'end_ts': end_ts, - 'duration_seconds': end_ts - start_ts + 'duration_seconds': end_ts - start_ts, }, 'summary': { 'total_requests': snapshot.total_requests, @@ -140,29 +138,31 @@ async def get_analytics_overview( 'unique_users': snapshot.unique_users, 'total_bandwidth': snapshot.total_bytes_in + snapshot.total_bytes_out, 'bandwidth_in': snapshot.total_bytes_in, - 'bandwidth_out': snapshot.total_bytes_out + 'bandwidth_out': snapshot.total_bytes_out, }, 'percentiles': snapshot.percentiles.to_dict(), 'top_apis': snapshot.top_apis, 'top_users': snapshot.top_users, - 'status_distribution': snapshot.status_distribution + 'status_distribution': snapshot.status_distribution, } - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=overview - )) - + + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=overview + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') @@ -172,21 +172,28 @@ async def get_analytics_overview( # ENDPOINT 2: Time-Series Data # ============================================================================ -@analytics_router.get('/analytics/timeseries', + +@analytics_router.get( + '/analytics/timeseries', description='Get time-series analytics data with filtering', - response_model=ResponseModel + response_model=ResponseModel, ) async def get_analytics_timeseries( request: Request, - start_ts: Optional[int] = Query(None, description='Start timestamp (Unix seconds)'), - end_ts: Optional[int] = Query(None, description='End timestamp (Unix seconds)'), - range: Optional[str] = Query('24h', description='Time range (1h, 24h, 7d, 30d)'), - granularity: Optional[str] = Query('auto', description='Data granularity (auto, minute, 5minute, hour, day)'), - metric_type: Optional[str] = Query(None, description='Specific metric to return (request_count, error_rate, latency, bandwidth)') + start_ts: int | None = Query(None, description='Start timestamp (Unix seconds)'), + end_ts: int | None = Query(None, description='End timestamp (Unix seconds)'), + range: str | None = Query('24h', description='Time range (1h, 24h, 7d, 30d)'), + granularity: str | None = Query( + 'auto', description='Data granularity (auto, minute, 5minute, hour, day)' + ), + metric_type: str | None = Query( + None, + description='Specific metric to return (request_count, error_rate, latency, bandwidth)', + ), ): """ Get time-series analytics data. - + Returns series of data points over time with: - Timestamp - Request count @@ -198,19 +205,21 @@ async def get_analytics_timeseries( """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - + if not await platform_role_required_bool(username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: pass @@ -219,10 +228,10 @@ async def get_analytics_timeseries( range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get snapshot with time-series data snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts, granularity) - + # Filter by metric type if specified series = snapshot.series if metric_type: @@ -243,27 +252,31 @@ async def get_analytics_timeseries( filtered_point['bytes_out'] = point['bytes_out'] filtered_series.append(filtered_point) series = filtered_series - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, - 'granularity': granularity, - 'series': series, - 'data_points': len(series) - } - )) - + + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={ + 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, + 'granularity': granularity, + 'series': series, + 'data_points': len(series), + }, + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') @@ -273,20 +286,20 @@ async def get_analytics_timeseries( # ENDPOINT 3: Top APIs # ============================================================================ -@analytics_router.get('/analytics/top-apis', - description='Get most used APIs', - response_model=ResponseModel + +@analytics_router.get( + '/analytics/top-apis', description='Get most used APIs', response_model=ResponseModel ) async def get_top_apis( request: Request, - start_ts: Optional[int] = Query(None), - end_ts: Optional[int] = Query(None), - range: Optional[str] = Query('24h'), - limit: int = Query(10, ge=1, le=100, description='Number of APIs to return') + start_ts: int | None = Query(None), + end_ts: int | None = Query(None), + range: str | None = Query('24h'), + limit: int = Query(10, ge=1, le=100, description='Number of APIs to return'), ): """ Get top N most used APIs. - + Returns list of APIs sorted by request count with: - API name - Total requests @@ -296,19 +309,21 @@ async def get_top_apis( """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - + if not await platform_role_required_bool(username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: pass @@ -317,32 +332,36 @@ async def get_top_apis( range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get snapshot snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts) - + # Get top APIs (already sorted by count) top_apis = snapshot.top_apis[:limit] - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, - 'top_apis': top_apis, - 'total_apis': len(snapshot.top_apis) - } - )) - + + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={ + 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, + 'top_apis': top_apis, + 'total_apis': len(snapshot.top_apis), + }, + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') @@ -352,37 +371,39 @@ async def get_top_apis( # ENDPOINT 4: Top Users # ============================================================================ -@analytics_router.get('/analytics/top-users', - description='Get highest consuming users', - response_model=ResponseModel + +@analytics_router.get( + '/analytics/top-users', description='Get highest consuming users', response_model=ResponseModel ) async def get_top_users( request: Request, - start_ts: Optional[int] = Query(None), - end_ts: Optional[int] = Query(None), - range: Optional[str] = Query('24h'), - limit: int = Query(10, ge=1, le=100, description='Number of users to return') + start_ts: int | None = Query(None), + end_ts: int | None = Query(None), + range: str | None = Query('24h'), + limit: int = Query(10, ge=1, le=100, description='Number of users to return'), ): """ Get top N highest consuming users. - + Returns list of users sorted by request count. """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - + if not await platform_role_required_bool(username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: pass @@ -391,32 +412,36 @@ async def get_top_users( range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get snapshot snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts) - + # Get top users (already sorted by count) top_users = snapshot.top_users[:limit] - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, - 'top_users': top_users, - 'total_users': len(snapshot.top_users) - } - )) - + + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={ + 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, + 'top_users': top_users, + 'total_users': len(snapshot.top_users), + }, + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') @@ -426,21 +451,23 @@ async def get_top_users( # ENDPOINT 5: Top Endpoints # ============================================================================ -@analytics_router.get('/analytics/top-endpoints', + +@analytics_router.get( + '/analytics/top-endpoints', description='Get slowest/most-used endpoints', - response_model=ResponseModel + response_model=ResponseModel, ) async def get_top_endpoints( request: Request, - start_ts: Optional[int] = Query(None), - end_ts: Optional[int] = Query(None), - range: Optional[str] = Query('24h'), + start_ts: int | None = Query(None), + end_ts: int | None = Query(None), + range: str | None = Query('24h'), sort_by: str = Query('count', description='Sort by: count, avg_ms, error_rate'), - limit: int = Query(10, ge=1, le=100) + limit: int = Query(10, ge=1, le=100), ): """ Get top endpoints sorted by usage or performance. - + Returns detailed per-endpoint metrics including: - Request count - Error count and rate @@ -449,19 +476,21 @@ async def get_top_endpoints( """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - + if not await platform_role_required_bool(username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: pass @@ -470,41 +499,45 @@ async def get_top_endpoints( range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get snapshot snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts) - + # Get and sort endpoints endpoints = snapshot.top_endpoints - + if sort_by == 'avg_ms': endpoints.sort(key=lambda x: x['avg_ms'], reverse=True) elif sort_by == 'error_rate': endpoints.sort(key=lambda x: x['error_rate'], reverse=True) # Default is already sorted by count - + top_endpoints = endpoints[:limit] - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, - 'sort_by': sort_by, - 'top_endpoints': top_endpoints, - 'total_endpoints': len(endpoints) - } - )) - + + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={ + 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, + 'sort_by': sort_by, + 'top_endpoints': top_endpoints, + 'total_endpoints': len(endpoints), + }, + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') @@ -514,21 +547,23 @@ async def get_top_endpoints( # ENDPOINT 6: Per-API Breakdown # ============================================================================ -@analytics_router.get('/analytics/api/{api_name}/{version}', + +@analytics_router.get( + '/analytics/api/{api_name}/{version}', description='Get detailed analytics for a specific API', - response_model=ResponseModel + response_model=ResponseModel, ) async def get_api_analytics( request: Request, api_name: str, version: str, - start_ts: Optional[int] = Query(None), - end_ts: Optional[int] = Query(None), - range: Optional[str] = Query('24h') + start_ts: int | None = Query(None), + end_ts: int | None = Query(None), + range: str | None = Query('24h'), ): """ Get detailed analytics for a specific API. - + Returns: - Total requests for this API - Error count and rate @@ -538,19 +573,21 @@ async def get_api_analytics( """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - + if not await platform_role_required_bool(username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: pass @@ -559,55 +596,62 @@ async def get_api_analytics( range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get full snapshot snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts) - + # Filter for this API - api_key = f"rest:{api_name}" # Assuming REST API - + api_key = f'rest:{api_name}' # Assuming REST API + # Find API in top_apis api_data = None for api, count in snapshot.top_apis: if api == api_key: api_data = {'api': api, 'count': count} break - + if not api_data: - return respond_rest(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_code='ANALYTICS404', - error_message=f'No data found for API: {api_name}/{version}' - )) - + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_code='ANALYTICS404', + error_message=f'No data found for API: {api_name}/{version}', + ) + ) + # Filter endpoints for this API api_endpoints = [ - ep for ep in snapshot.top_endpoints + ep + for ep in snapshot.top_endpoints if ep['endpoint_uri'].startswith(f'/{api_name}/{version}') ] - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'api_name': api_name, - 'version': version, - 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, - 'summary': api_data, - 'endpoints': api_endpoints - } - )) - + + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={ + 'api_name': api_name, + 'version': version, + 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, + 'summary': api_data, + 'endpoints': api_endpoints, + }, + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') @@ -617,20 +661,22 @@ async def get_api_analytics( # ENDPOINT 7: Per-User Breakdown # ============================================================================ -@analytics_router.get('/analytics/user/{username}', + +@analytics_router.get( + '/analytics/user/{username}', description='Get detailed analytics for a specific user', - response_model=ResponseModel + response_model=ResponseModel, ) async def get_user_analytics( request: Request, username: str, - start_ts: Optional[int] = Query(None), - end_ts: Optional[int] = Query(None), - range: Optional[str] = Query('24h') + start_ts: int | None = Query(None), + end_ts: int | None = Query(None), + range: str | None = Query('24h'), ): """ Get detailed analytics for a specific user. - + Returns: - Total requests by this user - APIs accessed @@ -638,19 +684,21 @@ async def get_user_analytics( """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) requesting_username = payload.get('sub') - + if not await platform_role_required_bool(requesting_username, 'view_analytics'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='ANALYTICS001', - error_message='You do not have permission to view analytics' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='ANALYTICS001', + error_message='You do not have permission to view analytics', + ) + ) + # Determine time range if start_ts and end_ts: pass @@ -659,44 +707,50 @@ async def get_user_analytics( range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000} seconds = range_map.get(range, 86400) start_ts = end_ts - seconds - + # Get full snapshot snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts) - + # Find user in top_users user_data = None for user, count in snapshot.top_users: if user == username: user_data = {'user': user, 'count': count} break - + if not user_data: - return respond_rest(ResponseModel( - status_code=404, + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_code='ANALYTICS404', + error_message=f'No data found for user: {username}', + ) + ) + + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, - error_code='ANALYTICS404', - error_message=f'No data found for user: {username}' - )) - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'username': username, - 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, - 'summary': user_data - } - )) - + response={ + 'username': username, + 'time_range': {'start_ts': start_ts, 'end_ts': end_ts}, + 'summary': user_data, + }, + ) + ) + except Exception as e: logger.error(f'{request_id} | Error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='ANALYTICS999', - error_message='An unexpected error occurred' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='ANALYTICS999', + error_message='An unexpected error occurred', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') diff --git a/backend-services/routes/api_routes.py b/backend-services/routes/api_routes.py index c3e4a53..4114336 100644 --- a/backend-services/routes/api_routes.py +++ b/backend-services/routes/api_routes.py @@ -4,22 +4,22 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from fastapi import APIRouter, Depends, Request, HTTPException -from typing import List import logging -import uuid import time +import uuid + +from fastapi import APIRouter, HTTPException, Request, Response -from models.response_model import ResponseModel -from services.api_service import ApiService -from utils.auth_util import auth_required -from models.create_api_model import CreateApiModel -from models.update_api_model import UpdateApiModel from models.api_model_response import ApiModelResponse -from utils.response_util import respond_rest, process_response -from utils.constants import ErrorCodes, Messages, Defaults, Roles, Headers -from utils.role_util import platform_role_required_bool +from models.create_api_model import CreateApiModel +from models.response_model import ResponseModel +from models.update_api_model import UpdateApiModel +from services.api_service import ApiService from utils.audit_util import audit +from utils.auth_util import auth_required +from utils.constants import ErrorCodes, Headers, Messages, Roles +from utils.response_util import process_response, respond_rest +from utils.role_util import platform_role_required_bool api_router = APIRouter() logger = logging.getLogger('doorman.gateway') @@ -33,58 +33,65 @@ Response: {} """ -@api_router.post('', + +@api_router.post( + '', description='Add API', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'API created successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'API created successfully'}}}, } - } + }, ) - -async def create_api(request: Request, api_data: CreateApiModel): +async def create_api(request: Request, api_data: CreateApiModel) -> Response: payload = await auth_required(request) username = payload.get('sub') request_id = str(uuid.uuid4()) start_time = time.time() * 1000 - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') try: if not await platform_role_required_bool(username, Roles.MANAGE_APIS): logger.warning(f'{request_id} | Permission denied for user: {username}') - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='API007', - error_message='You do not have permission to create APIs' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API007', + error_message='You do not have permission to create APIs', + ) + ) result = await ApiService.create_api(api_data, request_id) - audit(request, actor=username, action='api.create', target=f'{api_data.api_name}/{api_data.api_version}', status=result.get('status_code'), details={'message': result.get('message')}, request_id=request_id) + audit( + request, + actor=username, + action='api.create', + target=f'{api_data.api_name}/{api_data.api_version}', + status=result.get('status_code'), + details={'message': result.get('message')}, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update API @@ -94,57 +101,66 @@ Response: {} """ -@api_router.put('/{api_name}/{api_version}', + +@api_router.put( + '/{api_name}/{api_version}', description='Update API', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'API updated successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'API updated successfully'}}}, } - } + }, ) - -async def update_api(api_name: str, api_version: str, request: Request, api_data: UpdateApiModel): +async def update_api( + api_name: str, api_version: str, request: Request, api_data: UpdateApiModel +) -> Response: request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_APIS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='API008', - error_message='You do not have permission to update APIs' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API008', + error_message='You do not have permission to update APIs', + ) + ) result = await ApiService.update_api(api_name, api_version, api_data, request_id) - audit(request, actor=username, action='api.update', target=f'{api_name}/{api_version}', status=result.get('status_code'), details={'message': result.get('message')}, request_id=request_id) + audit( + request, + actor=username, + action='api.update', + target=f'{api_name}/{api_version}', + status=result.get('status_code'), + details={'message': result.get('message')}, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Get API @@ -154,46 +170,47 @@ Response: {} """ -@api_router.get('/{api_name}/{api_version}', + +@api_router.get( + '/{api_name}/{api_version}', description='Get API', response_model=ApiModelResponse, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'API retrieved successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'API retrieved successfully'}}}, } - } + }, ) - -async def get_api_by_name_version(api_name: str, api_version: str, request: Request): +async def get_api_by_name_version(api_name: str, api_version: str, request: Request) -> Response: request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - return respond_rest(await ApiService.get_api_by_name_version(api_name, api_version, request_id)) + return respond_rest( + await ApiService.get_api_by_name_version(api_name, api_version, request_id) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete API @@ -203,48 +220,55 @@ Response: {} """ -@api_router.delete('/{api_name}/{api_version}', + +@api_router.delete( + '/{api_name}/{api_version}', description='Delete API', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'API deleted successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'API deleted successfully'}}}, } - } + }, ) - -async def delete_api(api_name: str, api_version: str, request: Request): +async def delete_api(api_name: str, api_version: str, request: Request) -> Response: request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') result = await ApiService.delete_api(api_name, api_version, request_id) - audit(request, actor=username, action='api.delete', target=f'{api_name}/{api_version}', status=result.get('status_code'), details={'message': result.get('message')}, request_id=request_id) + audit( + request, + actor=username, + action='api.delete', + target=f'{api_name}/{api_version}', + status=result.get('status_code'), + details={'message': result.get('message')}, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -254,46 +278,47 @@ Response: {} """ -@api_router.get('/all', - description='Get all APIs', - response_model=List[ApiModelResponse] -) -async def get_all_apis(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): +@api_router.get('/all', description='Get all APIs', response_model=list[ApiModelResponse]) +async def get_all_apis(page: int, page_size: int, request: Request) -> Response: request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await ApiService.get_apis(page, page_size, request_id)) except HTTPException as e: # Surface 401/403 properly for tests that probe unauthorized access - return respond_rest(ResponseModel( - status_code=e.status_code, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='API_AUTH', - error_message=e.detail - )) + return respond_rest( + ResponseModel( + status_code=e.status_code, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API_AUTH', + error_message=e.detail, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') -@api_router.get('', - description='Get all APIs (base path)', - response_model=List[ApiModelResponse] -) -async def get_all_apis_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): + + +@api_router.get('', description='Get all APIs (base path)', response_model=list[ApiModelResponse]) +async def get_all_apis_base(page: int, page_size: int, request: Request) -> Response: """Convenience alias for GET /platform/api/all to support tests and clients that expect listing at the base collection path. """ @@ -302,21 +327,28 @@ async def get_all_apis_base(request: Request, page: int = Defaults.PAGE, page_si try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await ApiService.get_apis(page, page_size, request_id)) except HTTPException as e: - return respond_rest(ResponseModel( - status_code=e.status_code, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='API_AUTH', - error_message=e.detail - )) + return respond_rest( + ResponseModel( + status_code=e.status_code, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API_AUTH', + error_message=e.detail, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) diff --git a/backend-services/routes/authorization_routes.py b/backend-services/routes/authorization_routes.py index 6faa9cf..63b8a59 100644 --- a/backend-services/routes/authorization_routes.py +++ b/backend-services/routes/authorization_routes.py @@ -4,23 +4,30 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from fastapi import APIRouter, Request, Depends, HTTPException, Response -from jose import JWTError -import uuid -import time import logging import os +import time +import uuid + +from fastapi import APIRouter, HTTPException, Request, Response +from jose import JWTError -from models.response_model import ResponseModel -from services.user_service import UserService -from utils.response_util import respond_rest -from utils.auth_util import auth_required, create_access_token -from utils.auth_blacklist import TimedHeap, jwt_blacklist, revoke_all_for_user, unrevoke_all_for_user, is_user_revoked, add_revoked_jti -from utils.role_util import platform_role_required_bool -from utils.role_util import is_admin_user -from models.update_user_model import UpdateUserModel from models.create_user_model import CreateUserModel +from models.response_model import ResponseModel +from models.update_user_model import UpdateUserModel +from services.user_service import UserService +from utils.auth_blacklist import ( + TimedHeap, + add_revoked_jti, + is_user_revoked, + jwt_blacklist, + revoke_all_for_user, + unrevoke_all_for_user, +) +from utils.auth_util import auth_required, create_access_token from utils.limit_throttle_util import limit_by_ip +from utils.response_util import respond_rest +from utils.role_util import is_admin_user, platform_role_required_bool authorization_router = APIRouter() @@ -35,23 +42,18 @@ Response: {} """ -@authorization_router.post('/authorization', + +@authorization_router.post( + '/authorization', description='Create authorization token', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'access_token': '******************' - } - } - } + 'content': {'application/json': {'example': {'access_token': '******************'}}}, } - } + }, ) - async def authorization(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -65,53 +67,55 @@ async def authorization(request: Request): try: data = await request.json() except Exception: - return respond_rest(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='AUTH004', - error_message='Invalid JSON payload' - )) + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH004', + error_message='Invalid JSON payload', + ) + ) email = data.get('email') password = data.get('password') if not email or not password: - return respond_rest(ResponseModel( - status_code=400, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH001', - error_message='Missing email or password' - )) + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH001', + error_message='Missing email or password', + ) + ) user = await UserService.check_password_return_user(email, password) if not user: - return respond_rest(ResponseModel( - status_code=400, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH002', - error_message='Invalid email or password' - )) + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH002', + error_message='Invalid email or password', + ) + ) if not user['active']: - return respond_rest(ResponseModel( - status_code=400, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH007', - error_message='User is not active' - )) + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH007', + error_message='User is not active', + ) + ) access_token = create_access_token({'sub': user['username'], 'role': user['role']}, False) - logger.info(f"Login successful for user: {user['username']}") + logger.info(f'Login successful for user: {user["username"]}') - response = respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id - }, - response={'access_token': access_token} - )) + response = respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={'access_token': access_token}, + ) + ) if rate_limit_info: response.headers['X-RateLimit-Limit'] = str(rate_limit_info['limit']) @@ -121,6 +125,7 @@ async def authorization(request: Request): response.delete_cookie('access_token_cookie') import uuid as _uuid + csrf_token = str(_uuid.uuid4()) _secure_env = os.getenv('COOKIE_SECURE') @@ -131,7 +136,9 @@ async def authorization(request: Request): # Security warning: cookies should be secure in production if not _secure and os.getenv('ENV', '').lower() in ('production', 'prod'): - logger.warning(f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment') + logger.warning( + f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment' + ) _domain = os.getenv('COOKIE_DOMAIN', None) _samesite = (os.getenv('COOKIE_SAMESITE', 'Strict') or 'Strict').strip().lower() @@ -152,7 +159,7 @@ async def authorization(request: Request): samesite=_samesite, path='/', domain=safe_domain, - max_age=1800 + max_age=1800, ) response.set_cookie( @@ -162,7 +169,7 @@ async def authorization(request: Request): secure=_secure, samesite=_samesite, path='/', - max_age=1800 + max_age=1800, ) response.set_cookie( @@ -173,7 +180,7 @@ async def authorization(request: Request): samesite=_samesite, path='/', domain=safe_domain, - max_age=1800 + max_age=1800, ) response.set_cookie( @@ -183,44 +190,44 @@ async def authorization(request: Request): secure=_secure, samesite=_samesite, path='/', - max_age=1800 + max_age=1800, ) return response except HTTPException as e: if getattr(e, 'status_code', None) == 429: headers = getattr(e, 'headers', {}) or {} detail = e.detail if isinstance(e.detail, dict) else {} - return respond_rest(ResponseModel( - status_code=429, - response_headers={ - 'request_id': request_id, - **headers - }, - error_code=str(detail.get('error_code') or 'IP_RATE_LIMIT'), - error_message=str(detail.get('message') or 'Too many requests') - )) - return respond_rest(ResponseModel( - status_code=401, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH003', - error_message='Unable to validate credentials' - )) + return respond_rest( + ResponseModel( + status_code=429, + response_headers={'request_id': request_id, **headers}, + error_code=str(detail.get('error_code') or 'IP_RATE_LIMIT'), + error_message=str(detail.get('message') or 'Too many requests'), + ) + ) + return respond_rest( + ResponseModel( + status_code=401, + response_headers={'request_id': request_id}, + error_code='AUTH003', + error_message='Unable to validate credentials', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Register new user @@ -230,23 +237,18 @@ Response: {} """ -@authorization_router.post('/authorization/register', + +@authorization_router.post( + '/authorization/register', description='Register new user', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'User created successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'User created successfully'}}}, } - } + }, ) - async def register(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -254,59 +256,68 @@ async def register(request: Request): # Rate limit registration to prevent abuse reg_limit = int(os.getenv('REGISTER_IP_RATE_LIMIT', '5')) reg_window = int(os.getenv('REGISTER_IP_RATE_WINDOW', '3600')) - rate_limit_info = await limit_by_ip(request, limit=reg_limit, window=reg_window) + await limit_by_ip(request, limit=reg_limit, window=reg_window) + + logger.info( + f'{request_id} | Register request from: {request.client.host}:{request.client.port}' + ) - logger.info(f'{request_id} | Register request from: {request.client.host}:{request.client.port}') - try: data = await request.json() except Exception: - return respond_rest(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='AUTH004', - error_message='Invalid JSON payload' - )) - + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH004', + error_message='Invalid JSON payload', + ) + ) + # Validate required fields if not data.get('email') or not data.get('password'): - return respond_rest(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='AUTH001', - error_message='Missing email or password' - )) + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH001', + error_message='Missing email or password', + ) + ) # Create user model # Default to 'user' role and active=True user_data = CreateUserModel( - username=data.get('email').split('@')[0], # Simple username derivation + username=data.get('email').split('@')[0], # Simple username derivation email=data.get('email'), password=data.get('password'), role='user', - active=True + active=True, ) # Check if user exists (UserService.create_user handles this but we want clean error) # Actually UserService.create_user will return error if exists. - + result = await UserService.create_user(user_data, request_id) - + # If successful, we could auto-login, but for now just return success return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -316,53 +327,66 @@ Response: {} """ -@authorization_router.post('/authorization/admin/revoke/{username}', - description='Revoke all active tokens for a user (admin)', - response_model=ResponseModel) +@authorization_router.post( + '/authorization/admin/revoke/{username}', + description='Revoke all active tokens for a user (admin)', + response_model=ResponseModel, +) async def admin_revoke_user_tokens(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) admin_user = payload.get('sub') - logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(admin_user, 'manage_auth'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='AUTH900', - error_message='You do not have permission to manage auth' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='AUTH900', + error_message='You do not have permission to manage auth', + ) + ) try: if await is_admin_user(username) and not await is_admin_user(admin_user): - return respond_rest(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_message='User not found' - )) + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_message='User not found', + ) + ) except Exception as e: logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True) revoke_all_for_user(username) - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message=f'All tokens revoked for {username}' - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message=f'All tokens revoked for {username}', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -372,53 +396,66 @@ Response: {} """ -@authorization_router.post('/authorization/admin/unrevoke/{username}', - description='Clear token revocation for a user (admin)', - response_model=ResponseModel) +@authorization_router.post( + '/authorization/admin/unrevoke/{username}', + description='Clear token revocation for a user (admin)', + response_model=ResponseModel, +) async def admin_unrevoke_user_tokens(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) admin_user = payload.get('sub') - logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(admin_user, 'manage_auth'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='AUTH900', - error_message='You do not have permission to manage auth' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='AUTH900', + error_message='You do not have permission to manage auth', + ) + ) try: if await is_admin_user(username) and not await is_admin_user(admin_user): - return respond_rest(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_message='User not found' - )) + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_message='User not found', + ) + ) except Exception as e: logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True) unrevoke_all_for_user(username) - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message=f'Token revocation cleared for {username}' - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message=f'Token revocation cleared for {username}', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -428,56 +465,69 @@ Response: {} """ -@authorization_router.post('/authorization/admin/disable/{username}', - description='Disable a user (admin)', - response_model=ResponseModel) +@authorization_router.post( + '/authorization/admin/disable/{username}', + description='Disable a user (admin)', + response_model=ResponseModel, +) async def admin_disable_user(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) admin_user = payload.get('sub') - logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(admin_user, 'manage_auth'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='AUTH900', - error_message='You do not have permission to manage auth' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='AUTH900', + error_message='You do not have permission to manage auth', + ) + ) try: if await is_admin_user(username) and not await is_admin_user(admin_user): - return respond_rest(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_message='User not found' - )) + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_message='User not found', + ) + ) except Exception as e: logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True) await UserService.update_user(username, UpdateUserModel(active=False), request_id) revoke_all_for_user(username) - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message=f'User {username} disabled and tokens revoked' - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message=f'User {username} disabled and tokens revoked', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -487,54 +537,67 @@ Response: {} """ -@authorization_router.post('/authorization/admin/enable/{username}', - description='Enable a user (admin)', - response_model=ResponseModel) +@authorization_router.post( + '/authorization/admin/enable/{username}', + description='Enable a user (admin)', + response_model=ResponseModel, +) async def admin_enable_user(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) admin_user = payload.get('sub') - logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(admin_user, 'manage_auth'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='AUTH900', - error_message='You do not have permission to manage auth' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='AUTH900', + error_message='You do not have permission to manage auth', + ) + ) try: if await is_admin_user(username) and not await is_admin_user(admin_user): - return respond_rest(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_message='User not found' - )) + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_message='User not found', + ) + ) except Exception as e: logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True) await UserService.update_user(username, UpdateUserModel(active=True), request_id) - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message=f'User {username} enabled' - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message=f'User {username} enabled', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -544,57 +607,65 @@ Response: {} """ -@authorization_router.get('/authorization/admin/status/{username}', - description='Get auth status for a user (admin)', - response_model=ResponseModel) +@authorization_router.get( + '/authorization/admin/status/{username}', + description='Get auth status for a user (admin)', + response_model=ResponseModel, +) async def admin_user_status(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) admin_user = payload.get('sub') - logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(admin_user, 'manage_auth'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='AUTH900', - error_message='You do not have permission to manage auth' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='AUTH900', + error_message='You do not have permission to manage auth', + ) + ) try: if await is_admin_user(username) and not await is_admin_user(admin_user): - return respond_rest(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_message='User not found' - )) + return respond_rest( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_message='User not found', + ) + ) except Exception as e: logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True) user = await UserService.get_user_by_username_helper(username) - status = { - 'active': bool(user.get('active', False)), - 'revoked': is_user_revoked(username) - } - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=status - )) + status = {'active': bool(user.get('active', False)), 'revoked': is_user_revoked(username)} + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=status + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Create authorization refresh token @@ -604,51 +675,49 @@ Response: {} """ -@authorization_router.post('/authorization/refresh', + +@authorization_router.post( + '/authorization/refresh', description='Create authorization refresh token', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'refresh_token': '******************' - } - } - } + 'content': {'application/json': {'example': {'refresh_token': '******************'}}}, } - } + }, ) - async def extended_authorization(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') user = await UserService.get_user_by_username_helper(username) if not user['active']: - return respond_rest(ResponseModel( - status_code=400, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH007', - error_message='User is not active' - )) + return respond_rest( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='AUTH007', + error_message='User is not active', + ) + ) refresh_token = create_access_token({'sub': username, 'role': user['role']}, True) - response = respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id - }, - response={'refresh_token': refresh_token} - )) + response = respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={'refresh_token': refresh_token}, + ) + ) import uuid as _uuid + csrf_token = str(_uuid.uuid4()) _secure_env = os.getenv('COOKIE_SECURE') @@ -659,7 +728,9 @@ async def extended_authorization(request: Request): # Security warning: cookies should be secure in production if not _secure and os.getenv('ENV', '').lower() in ('production', 'prod'): - logger.warning(f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment') + logger.warning( + f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment' + ) _domain = os.getenv('COOKIE_DOMAIN', None) _samesite = (os.getenv('COOKIE_SAMESITE', 'Strict') or 'Strict').strip().lower() @@ -680,7 +751,7 @@ async def extended_authorization(request: Request): samesite=_samesite, path='/', domain=safe_domain, - max_age=604800 + max_age=604800, ) response.set_cookie( @@ -690,7 +761,7 @@ async def extended_authorization(request: Request): secure=_secure, samesite=_samesite, path='/', - max_age=604800 + max_age=604800, ) response.set_cookie( @@ -701,7 +772,7 @@ async def extended_authorization(request: Request): samesite=_samesite, path='/', domain=safe_domain, - max_age=604800 + max_age=604800, ) response.set_cookie( @@ -711,42 +782,43 @@ async def extended_authorization(request: Request): secure=_secure, samesite=_samesite, path='/', - max_age=604800 + max_age=604800, ) return response - except HTTPException as e: - return respond_rest(ResponseModel( - status_code=401, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH003', - error_message='Unable to validate credentials' - )) + except HTTPException: + return respond_rest( + ResponseModel( + status_code=401, + response_headers={'request_id': request_id}, + error_code='AUTH003', + error_message='Unable to validate credentials', + ) + ) except JWTError as e: logging.error(f'Token refresh failed: {str(e)}') - return respond_rest(ResponseModel( - status_code=401, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH004', - error_message='Token refresh failed' - )) + return respond_rest( + ResponseModel( + status_code=401, + response_headers={'request_id': request_id}, + error_code='AUTH004', + error_message='Token refresh failed', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Get authorization token status @@ -756,61 +828,59 @@ Response: {} """ -@authorization_router.get('/authorization/status', + +@authorization_router.get( + '/authorization/status', description='Get authorization token status', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'status': 'authorized' - } - } - } + 'content': {'application/json': {'example': {'status': 'authorized'}}}, } - } + }, ) - async def authorization_status(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - return respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id - }, - message='Token is valid' - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='Token is valid', + ) + ) except JWTError: - return respond_rest(ResponseModel( - status_code=401, - response_headers={ - 'request_id': request_id - }, - error_code='AUTH005', - error_message='Token is invalid' - )) + return respond_rest( + ResponseModel( + status_code=401, + response_headers={'request_id': request_id}, + error_code='AUTH005', + error_message='Token is invalid', + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Invalidate authorization token @@ -820,33 +890,33 @@ Response: {} """ -@authorization_router.post('/authorization/invalidate', + +@authorization_router.post( + '/authorization/invalidate', description='Invalidate authorization token', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Your token has been invalidated' - } - } - } + 'application/json': {'example': {'message': 'Your token has been invalidated'}} + }, } - } + }, ) - async def authorization_invalidate(response: Response, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') try: import time as _t + exp = payload.get('exp') ttl = None if isinstance(exp, (int, float)): @@ -857,29 +927,31 @@ async def authorization_invalidate(response: Response, request: Request): if username not in jwt_blacklist: jwt_blacklist[username] = TimedHeap() jwt_blacklist[username].push(payload.get('jti')) - response = respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id - }, - message='Your token has been invalidated' - )) + response = respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='Your token has been invalidated', + ) + ) _domain = os.getenv('COOKIE_DOMAIN', None) host = request.url.hostname or (request.client.host if request.client else None) - safe_domain = _domain if (_domain and host and (host == _domain or host.endswith(_domain))) else None + safe_domain = ( + _domain if (_domain and host and (host == _domain or host.endswith(_domain))) else None + ) response.delete_cookie('access_token_cookie', domain=safe_domain, path='/') return response except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/config_hot_reload_routes.py b/backend-services/routes/config_hot_reload_routes.py index a380734..5a1f183 100644 --- a/backend-services/routes/config_hot_reload_routes.py +++ b/backend-services/routes/config_hot_reload_routes.py @@ -4,37 +4,33 @@ Configuration Hot Reload Routes API endpoints for configuration management and hot reload. """ -from fastapi import APIRouter, Depends, HTTPException -from typing import Dict, Any import logging +from typing import Any +from fastapi import APIRouter, Depends, HTTPException + +from models.response_model import ResponseModel from utils.auth_util import auth_required from utils.hot_reload_config import hot_config -from models.response_model import ResponseModel logger = logging.getLogger('doorman.gateway') -config_hot_reload_router = APIRouter( - prefix='/config', - tags=['Configuration Hot Reload'] -) +config_hot_reload_router = APIRouter(prefix='/config', tags=['Configuration Hot Reload']) + @config_hot_reload_router.get( '/current', summary='Get Current Configuration', description='Retrieve current hot-reloadable configuration values', - response_model=Dict[str, Any], + response_model=dict[str, Any], ) -async def get_current_config( - payload: dict = Depends(auth_required) -): +async def get_current_config(payload: dict = Depends(auth_required)): """Get current configuration (admin only)""" try: accesses = payload.get('accesses', {}) if not accesses.get('manage_gateway'): raise HTTPException( - status_code=403, - detail='Insufficient permissions: manage_gateway required' + status_code=403, detail='Insufficient permissions: manage_gateway required' ) config = hot_config.dump() @@ -44,101 +40,126 @@ async def get_current_config( data={ 'config': config, 'source': 'Environment variables override config file values', - 'reload_command': 'kill -HUP $(cat doorman.pid)' + 'reload_command': 'kill -HUP $(cat doorman.pid)', }, error_code=None, - error_message=None + error_message=None, ).dict() except HTTPException: raise except Exception as e: logger.error(f'Failed to retrieve configuration: {e}', exc_info=True) - raise HTTPException( - status_code=500, - detail='Failed to retrieve configuration' - ) + raise HTTPException(status_code=500, detail='Failed to retrieve configuration') + @config_hot_reload_router.post( '/reload', summary='Trigger Configuration Reload', description='Manually trigger configuration reload (same as SIGHUP)', - response_model=Dict[str, Any], + response_model=dict[str, Any], ) -async def trigger_config_reload( - payload: dict = Depends(auth_required) -): +async def trigger_config_reload(payload: dict = Depends(auth_required)): """Trigger configuration reload (admin only)""" try: accesses = payload.get('accesses', {}) if not accesses.get('manage_gateway'): raise HTTPException( - status_code=403, - detail='Insufficient permissions: manage_gateway required' + status_code=403, detail='Insufficient permissions: manage_gateway required' ) hot_config.reload() return ResponseModel( status_code=200, - data={ - 'message': 'Configuration reloaded successfully', - 'config': hot_config.dump() - }, + data={'message': 'Configuration reloaded successfully', 'config': hot_config.dump()}, error_code=None, - error_message=None + error_message=None, ).dict() except HTTPException: raise except Exception as e: logger.error(f'Failed to reload configuration: {e}', exc_info=True) - raise HTTPException( - status_code=500, - detail='Failed to reload configuration' - ) + raise HTTPException(status_code=500, detail='Failed to reload configuration') + @config_hot_reload_router.get( '/reloadable-keys', summary='List Reloadable Configuration Keys', description='Get list of configuration keys that support hot reload', - response_model=Dict[str, Any], + response_model=dict[str, Any], ) -async def get_reloadable_keys( - payload: dict = Depends(auth_required) -): +async def get_reloadable_keys(payload: dict = Depends(auth_required)): """Get list of reloadable configuration keys""" try: reloadable_keys = [ - {'key': 'LOG_LEVEL', 'description': 'Log level (DEBUG, INFO, WARNING, ERROR)', 'example': 'INFO'}, + { + 'key': 'LOG_LEVEL', + 'description': 'Log level (DEBUG, INFO, WARNING, ERROR)', + 'example': 'INFO', + }, {'key': 'LOG_FORMAT', 'description': 'Log format (json, text)', 'example': 'json'}, {'key': 'LOG_FILE', 'description': 'Log file path', 'example': 'logs/doorman.log'}, - - {'key': 'GATEWAY_TIMEOUT', 'description': 'Gateway timeout in seconds', 'example': '30'}, - {'key': 'UPSTREAM_TIMEOUT', 'description': 'Upstream timeout in seconds', 'example': '30'}, - {'key': 'CONNECTION_TIMEOUT', 'description': 'Connection timeout in seconds', 'example': '10'}, - + { + 'key': 'GATEWAY_TIMEOUT', + 'description': 'Gateway timeout in seconds', + 'example': '30', + }, + { + 'key': 'UPSTREAM_TIMEOUT', + 'description': 'Upstream timeout in seconds', + 'example': '30', + }, + { + 'key': 'CONNECTION_TIMEOUT', + 'description': 'Connection timeout in seconds', + 'example': '10', + }, {'key': 'RATE_LIMIT_ENABLED', 'description': 'Enable rate limiting', 'example': 'true'}, {'key': 'RATE_LIMIT_REQUESTS', 'description': 'Requests per window', 'example': '100'}, {'key': 'RATE_LIMIT_WINDOW', 'description': 'Window size in seconds', 'example': '60'}, - {'key': 'CACHE_TTL', 'description': 'Cache TTL in seconds', 'example': '300'}, {'key': 'CACHE_MAX_SIZE', 'description': 'Maximum cache entries', 'example': '1000'}, - - {'key': 'CIRCUIT_BREAKER_ENABLED', 'description': 'Enable circuit breaker', 'example': 'true'}, - {'key': 'CIRCUIT_BREAKER_THRESHOLD', 'description': 'Failures before opening', 'example': '5'}, - {'key': 'CIRCUIT_BREAKER_TIMEOUT', 'description': 'Timeout before retry (seconds)', 'example': '60'}, - + { + 'key': 'CIRCUIT_BREAKER_ENABLED', + 'description': 'Enable circuit breaker', + 'example': 'true', + }, + { + 'key': 'CIRCUIT_BREAKER_THRESHOLD', + 'description': 'Failures before opening', + 'example': '5', + }, + { + 'key': 'CIRCUIT_BREAKER_TIMEOUT', + 'description': 'Timeout before retry (seconds)', + 'example': '60', + }, {'key': 'RETRY_ENABLED', 'description': 'Enable retry logic', 'example': 'true'}, {'key': 'RETRY_MAX_ATTEMPTS', 'description': 'Maximum retry attempts', 'example': '3'}, {'key': 'RETRY_BACKOFF', 'description': 'Backoff multiplier', 'example': '1.0'}, - - {'key': 'METRICS_ENABLED', 'description': 'Enable metrics collection', 'example': 'true'}, - {'key': 'METRICS_INTERVAL', 'description': 'Metrics interval (seconds)', 'example': '60'}, - - {'key': 'FEATURE_REQUEST_REPLAY', 'description': 'Enable request replay', 'example': 'false'}, + { + 'key': 'METRICS_ENABLED', + 'description': 'Enable metrics collection', + 'example': 'true', + }, + { + 'key': 'METRICS_INTERVAL', + 'description': 'Metrics interval (seconds)', + 'example': '60', + }, + { + 'key': 'FEATURE_REQUEST_REPLAY', + 'description': 'Enable request replay', + 'example': 'false', + }, {'key': 'FEATURE_AB_TESTING', 'description': 'Enable A/B testing', 'example': 'false'}, - {'key': 'FEATURE_COST_ANALYTICS', 'description': 'Enable cost analytics', 'example': 'false'}, + { + 'key': 'FEATURE_COST_ANALYTICS', + 'description': 'Enable cost analytics', + 'example': 'false', + }, ] return ResponseModel( @@ -150,16 +171,13 @@ async def get_reloadable_keys( 'Environment variables always override config file values', 'Changes take effect immediately after reload', 'Reload via: kill -HUP $(cat doorman.pid)', - 'Or use: POST /config/reload' - ] + 'Or use: POST /config/reload', + ], }, error_code=None, - error_message=None + error_message=None, ).dict() except Exception as e: logger.error(f'Failed to retrieve reloadable keys: {e}', exc_info=True) - raise HTTPException( - status_code=500, - detail='Failed to retrieve reloadable keys' - ) + raise HTTPException(status_code=500, detail='Failed to retrieve reloadable keys') diff --git a/backend-services/routes/config_routes.py b/backend-services/routes/config_routes.py index 76322f0..82046b5 100644 --- a/backend-services/routes/config_routes.py +++ b/backend-services/routes/config_routes.py @@ -2,37 +2,39 @@ Routes to export and import platform configuration (APIs, Endpoints, Roles, Groups, Routings). """ -from fastapi import APIRouter, Request -from typing import Any, Dict, List, Optional -import uuid -import time -import logging import copy +import logging +import time +import uuid +from typing import Any + +from fastapi import APIRouter, Request from models.response_model import ResponseModel -from utils.response_util import process_response -from utils.auth_util import auth_required -from utils.role_util import platform_role_required_bool -from utils.doorman_cache_util import doorman_cache from utils.audit_util import audit +from utils.auth_util import auth_required from utils.database import ( - api_collection, endpoint_collection, group_collection, role_collection, routing_collection, ) +from utils.doorman_cache_util import doorman_cache +from utils.response_util import process_response +from utils.role_util import platform_role_required_bool config_router = APIRouter() logger = logging.getLogger('doorman.gateway') -def _strip_id(doc: Dict[str, Any]) -> Dict[str, Any]: + +def _strip_id(doc: dict[str, Any]) -> dict[str, Any]: d = dict(doc) d.pop('_id', None) return d -def _export_all() -> Dict[str, Any]: + +def _export_all() -> dict[str, Any]: apis = [_strip_id(a) for a in api_collection.find().to_list(length=None)] endpoints = [_strip_id(e) for e in endpoint_collection.find().to_list(length=None)] roles = [_strip_id(r) for r in role_collection.find().to_list(length=None)] @@ -46,6 +48,7 @@ def _export_all() -> Dict[str, Any]: 'routings': routings, } + """ Endpoint @@ -55,11 +58,12 @@ Response: {} """ -@config_router.get('/config/export/all', + +@config_router.get( + '/config/export/all', description='Export all platform configuration (APIs, Endpoints, Roles, Groups, Routings)', response_model=ResponseModel, ) - async def export_all(request: Request): request_id = str(uuid.uuid4()) start = time.time() * 1000 @@ -67,15 +71,39 @@ async def export_all(request: Request): payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel(status_code=403, error_code='CFG001', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG001', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) data = _export_all() - audit(request, actor=username, action='config.export_all', target='all', status='success', details={'counts': {k: len(v) for k,v in data.items()}}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response_headers={'request_id': request_id}, response=data).dict(), 'rest') + audit( + request, + actor=username, + action='config.export_all', + target='all', + status='success', + details={'counts': {k: len(v) for k, v in data.items()}}, + request_id=request_id, + ) + return process_response( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=data + ).dict(), + 'rest', + ) except Exception as e: logger.error(f'{request_id} | export_all error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) finally: - logger.info(f'{request_id} | export_all took {time.time()*1000 - start:.2f}ms') + logger.info(f'{request_id} | export_all took {time.time() * 1000 - start:.2f}ms') + """ Endpoint @@ -86,37 +114,80 @@ Response: {} """ -@config_router.get('/config/export/apis', - description='Export APIs (optionally a single API with its endpoints)', - response_model=ResponseModel) -async def export_apis(request: Request, api_name: Optional[str] = None, api_version: Optional[str] = None): +@config_router.get( + '/config/export/apis', + description='Export APIs (optionally a single API with its endpoints)', + response_model=ResponseModel, +) +async def export_apis( + request: Request, api_name: str | None = None, api_version: str | None = None +): request_id = str(uuid.uuid4()) start = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_apis'): - return process_response(ResponseModel(status_code=403, error_code='CFG002', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG002', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) if api_name and api_version: api = api_collection.find_one({'api_name': api_name, 'api_version': api_version}) if not api: - return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='API not found').dict(), 'rest') - aid = api.get('api_id') - eps = endpoint_collection.find({'api_name': api_name, 'api_version': api_version}).to_list(length=None) - audit(request, actor=username, action='config.export_api', target=f'{api_name}/{api_version}', status='success', details={'endpoints': len(eps)}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={ - 'api': _strip_id(api), - 'endpoints': [_strip_id(e) for e in eps] - }).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, error_code='CFG404', error_message='API not found' + ).dict(), + 'rest', + ) + api.get('api_id') + eps = endpoint_collection.find( + {'api_name': api_name, 'api_version': api_version} + ).to_list(length=None) + audit( + request, + actor=username, + action='config.export_api', + target=f'{api_name}/{api_version}', + status='success', + details={'endpoints': len(eps)}, + request_id=request_id, + ) + return process_response( + ResponseModel( + status_code=200, + response={'api': _strip_id(api), 'endpoints': [_strip_id(e) for e in eps]}, + ).dict(), + 'rest', + ) apis = [_strip_id(a) for a in api_collection.find().to_list(length=None)] - audit(request, actor=username, action='config.export_apis', target='list', status='success', details={'count': len(apis)}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'apis': apis}).dict(), 'rest') + audit( + request, + actor=username, + action='config.export_apis', + target='list', + status='success', + details={'count': len(apis)}, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'apis': apis}).dict(), 'rest' + ) except Exception as e: logger.error(f'{request_id} | export_apis error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) finally: - logger.info(f'{request_id} | export_apis took {time.time()*1000 - start:.2f}ms') + logger.info(f'{request_id} | export_apis took {time.time() * 1000 - start:.2f}ms') + """ Endpoint @@ -127,27 +198,63 @@ Response: {} """ -@config_router.get('/config/export/roles', description='Export Roles', response_model=ResponseModel) -async def export_roles(request: Request, role_name: Optional[str] = None): +@config_router.get('/config/export/roles', description='Export Roles', response_model=ResponseModel) +async def export_roles(request: Request, role_name: str | None = None): request_id = str(uuid.uuid4()) try: payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_roles'): - return process_response(ResponseModel(status_code=403, error_code='CFG003', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG003', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) if role_name: role = role_collection.find_one({'role_name': role_name}) if not role: - return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='Role not found').dict(), 'rest') - audit(request, actor=username, action='config.export_role', target=role_name, status='success', details=None, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'role': _strip_id(role)}).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, error_code='CFG404', error_message='Role not found' + ).dict(), + 'rest', + ) + audit( + request, + actor=username, + action='config.export_role', + target=role_name, + status='success', + details=None, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'role': _strip_id(role)}).dict(), 'rest' + ) roles = [_strip_id(r) for r in role_collection.find().to_list(length=None)] - audit(request, actor=username, action='config.export_roles', target='list', status='success', details={'count': len(roles)}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'roles': roles}).dict(), 'rest') + audit( + request, + actor=username, + action='config.export_roles', + target='list', + status='success', + details={'count': len(roles)}, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'roles': roles}).dict(), 'rest' + ) except Exception as e: logger.error(f'{request_id} | export_roles error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) + """ Endpoint @@ -158,27 +265,65 @@ Response: {} """ -@config_router.get('/config/export/groups', description='Export Groups', response_model=ResponseModel) -async def export_groups(request: Request, group_name: Optional[str] = None): +@config_router.get( + '/config/export/groups', description='Export Groups', response_model=ResponseModel +) +async def export_groups(request: Request, group_name: str | None = None): request_id = str(uuid.uuid4()) try: payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_groups'): - return process_response(ResponseModel(status_code=403, error_code='CFG004', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG004', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) if group_name: group = group_collection.find_one({'group_name': group_name}) if not group: - return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='Group not found').dict(), 'rest') - audit(request, actor=username, action='config.export_group', target=group_name, status='success', details=None, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'group': _strip_id(group)}).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, error_code='CFG404', error_message='Group not found' + ).dict(), + 'rest', + ) + audit( + request, + actor=username, + action='config.export_group', + target=group_name, + status='success', + details=None, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'group': _strip_id(group)}).dict(), 'rest' + ) groups = [_strip_id(g) for g in group_collection.find().to_list(length=None)] - audit(request, actor=username, action='config.export_groups', target='list', status='success', details={'count': len(groups)}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'groups': groups}).dict(), 'rest') + audit( + request, + actor=username, + action='config.export_groups', + target='list', + status='success', + details={'count': len(groups)}, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'groups': groups}).dict(), 'rest' + ) except Exception as e: logger.error(f'{request_id} | export_groups error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) + """ Endpoint @@ -189,27 +334,66 @@ Response: {} """ -@config_router.get('/config/export/routings', description='Export Routings', response_model=ResponseModel) -async def export_routings(request: Request, client_key: Optional[str] = None): +@config_router.get( + '/config/export/routings', description='Export Routings', response_model=ResponseModel +) +async def export_routings(request: Request, client_key: str | None = None): request_id = str(uuid.uuid4()) try: payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_routings'): - return process_response(ResponseModel(status_code=403, error_code='CFG005', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG005', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) if client_key: routing = routing_collection.find_one({'client_key': client_key}) if not routing: - return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='Routing not found').dict(), 'rest') - audit(request, actor=username, action='config.export_routing', target=client_key, status='success', details=None, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'routing': _strip_id(routing)}).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, error_code='CFG404', error_message='Routing not found' + ).dict(), + 'rest', + ) + audit( + request, + actor=username, + action='config.export_routing', + target=client_key, + status='success', + details=None, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'routing': _strip_id(routing)}).dict(), + 'rest', + ) routings = [_strip_id(r) for r in routing_collection.find().to_list(length=None)] - audit(request, actor=username, action='config.export_routings', target='list', status='success', details={'count': len(routings)}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'routings': routings}).dict(), 'rest') + audit( + request, + actor=username, + action='config.export_routings', + target='list', + status='success', + details={'count': len(routings)}, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'routings': routings}).dict(), 'rest' + ) except Exception as e: logger.error(f'{request_id} | export_routings error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) + """ Endpoint @@ -220,30 +404,47 @@ Response: {} """ -@config_router.get('/config/export/endpoints', - description='Export endpoints (optionally filter by api_name/api_version)', - response_model=ResponseModel) -async def export_endpoints(request: Request, api_name: Optional[str] = None, api_version: Optional[str] = None): +@config_router.get( + '/config/export/endpoints', + description='Export endpoints (optionally filter by api_name/api_version)', + response_model=ResponseModel, +) +async def export_endpoints( + request: Request, api_name: str | None = None, api_version: str | None = None +): request_id = str(uuid.uuid4()) try: payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_endpoints'): - return process_response(ResponseModel(status_code=403, error_code='CFG007', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG007', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) query = {} if api_name: query['api_name'] = api_name if api_version: query['api_version'] = api_version eps = [_strip_id(e) for e in endpoint_collection.find(query).to_list(length=None)] - return process_response(ResponseModel(status_code=200, response={'endpoints': eps}).dict(), 'rest') + return process_response( + ResponseModel(status_code=200, response={'endpoints': eps}).dict(), 'rest' + ) except Exception as e: logger.error(f'{request_id} | export_endpoints error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) -def _upsert_api(doc: Dict[str, Any]) -> None: + +def _upsert_api(doc: dict[str, Any]) -> None: api_name = doc.get('api_name') api_version = doc.get('api_version') if not api_name or not api_version: @@ -258,11 +459,14 @@ def _upsert_api(doc: Dict[str, Any]) -> None: to_set.setdefault('api_id', str(uuid.uuid4())) to_set.setdefault('api_path', f'/{api_name}/{api_version}') if existing: - api_collection.update_one({'api_name': api_name, 'api_version': api_version}, {'$set': to_set}) + api_collection.update_one( + {'api_name': api_name, 'api_version': api_version}, {'$set': to_set} + ) else: api_collection.insert_one(to_set) -def _upsert_endpoint(doc: Dict[str, Any]) -> None: + +def _upsert_endpoint(doc: dict[str, Any]) -> None: api_name = doc.get('api_name') api_version = doc.get('api_version') method = doc.get('endpoint_method') @@ -275,23 +479,29 @@ def _upsert_endpoint(doc: Dict[str, Any]) -> None: if api_doc: to_set['api_id'] = api_doc.get('api_id') to_set.setdefault('endpoint_id', str(uuid.uuid4())) - existing = endpoint_collection.find_one({ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': method, - 'endpoint_uri': uri, - }) - if existing: - endpoint_collection.update_one({ + existing = endpoint_collection.find_one( + { 'api_name': api_name, 'api_version': api_version, 'endpoint_method': method, 'endpoint_uri': uri, - }, {'$set': to_set}) + } + ) + if existing: + endpoint_collection.update_one( + { + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': method, + 'endpoint_uri': uri, + }, + {'$set': to_set}, + ) else: endpoint_collection.insert_one(to_set) -def _upsert_role(doc: Dict[str, Any]) -> None: + +def _upsert_role(doc: dict[str, Any]) -> None: name = doc.get('role_name') if not name: return @@ -302,7 +512,8 @@ def _upsert_role(doc: Dict[str, Any]) -> None: else: role_collection.insert_one(to_set) -def _upsert_group(doc: Dict[str, Any]) -> None: + +def _upsert_group(doc: dict[str, Any]) -> None: name = doc.get('group_name') if not name: return @@ -313,7 +524,8 @@ def _upsert_group(doc: Dict[str, Any]) -> None: else: group_collection.insert_one(to_set) -def _upsert_routing(doc: Dict[str, Any]) -> None: + +def _upsert_routing(doc: dict[str, Any]) -> None: key = doc.get('client_key') if not key: return @@ -324,6 +536,7 @@ def _upsert_routing(doc: Dict[str, Any]) -> None: else: routing_collection.insert_one(to_set) + """ Endpoint @@ -333,11 +546,13 @@ Response: {} """ -@config_router.post('/config/import', - description='Import platform configuration (any subset of apis, endpoints, roles, groups, routings)', - response_model=ResponseModel) -async def import_all(request: Request, body: Dict[str, Any]): +@config_router.post( + '/config/import', + description='Import platform configuration (any subset of apis, endpoints, roles, groups, routings)', + response_model=ResponseModel, +) +async def import_all(request: Request, body: dict[str, Any]): request_id = str(uuid.uuid4()) start = time.time() * 1000 try: @@ -345,27 +560,52 @@ async def import_all(request: Request, body: Dict[str, Any]): username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel(status_code=403, error_code='CFG006', error_message='Insufficient permissions').dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, error_code='CFG006', error_message='Insufficient permissions' + ).dict(), + 'rest', + ) counts = {'apis': 0, 'endpoints': 0, 'roles': 0, 'groups': 0, 'routings': 0} for api in body.get('apis', []) or []: - _upsert_api(api); counts['apis'] += 1 + _upsert_api(api) + counts['apis'] += 1 for ep in body.get('endpoints', []) or []: - _upsert_endpoint(ep); counts['endpoints'] += 1 + _upsert_endpoint(ep) + counts['endpoints'] += 1 for r in body.get('roles', []) or []: - _upsert_role(r); counts['roles'] += 1 + _upsert_role(r) + counts['roles'] += 1 for g in body.get('groups', []) or []: - _upsert_group(g); counts['groups'] += 1 + _upsert_group(g) + counts['groups'] += 1 for rt in body.get('routings', []) or []: - _upsert_routing(rt); counts['routings'] += 1 + _upsert_routing(rt) + counts['routings'] += 1 try: doorman_cache.clear_all_caches() except Exception: pass - audit(request, actor=username, action='config.import', target='bulk', status='success', details={'imported': counts}, request_id=request_id) - return process_response(ResponseModel(status_code=200, response={'imported': counts}).dict(), 'rest') + audit( + request, + actor=username, + action='config.import', + target='bulk', + status='success', + details={'imported': counts}, + request_id=request_id, + ) + return process_response( + ResponseModel(status_code=200, response={'imported': counts}).dict(), 'rest' + ) except Exception as e: logger.error(f'{request_id} | import_all error: {e}') - return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, error_code='GTW999', error_message='An unexpected error occurred' + ).dict(), + 'rest', + ) finally: - logger.info(f'{request_id} | import_all took {time.time()*1000 - start:.2f}ms') + logger.info(f'{request_id} | import_all took {time.time() * 1000 - start:.2f}ms') diff --git a/backend-services/routes/credit_routes.py b/backend-services/routes/credit_routes.py index 854612a..ae341b7 100644 --- a/backend-services/routes/credit_routes.py +++ b/backend-services/routes/credit_routes.py @@ -4,20 +4,20 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -from fastapi import APIRouter, Depends, Request -import uuid -import time import logging +import time +import uuid +from fastapi import APIRouter, Request + +from models.credit_model import CreditModel from models.response_model import ResponseModel from models.user_credits_model import UserCreditModel -from models.credit_model import CreditModel from services.credit_service import CreditService -from utils.auth_util import auth_required -from utils.response_util import respond_rest, process_response -from utils.role_util import platform_role_required_bool from utils.audit_util import audit +from utils.auth_util import auth_required +from utils.response_util import process_response, respond_rest +from utils.role_util import platform_role_required_bool credit_router = APIRouter() @@ -32,41 +32,46 @@ Response: {} """ -@credit_router.get('/defs', + +@credit_router.get( + '/defs', description='List credit definitions', response_model=ResponseModel, - responses={ - 200: {'description': 'Successful Response'} - } + responses={200: {'description': 'Successful Response'}}, ) - async def list_credit_definitions(request: Request, page: int = 1, page_size: int = 50): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): - return respond_rest(ResponseModel( - status_code=403, - error_code='CRD002', - error_message='Unable to retrieve credits' - )) + return respond_rest( + ResponseModel( + status_code=403, error_code='CRD002', error_message='Unable to retrieve credits' + ) + ) return respond_rest(await CreditService.list_credit_defs(page, page_size, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -76,38 +81,43 @@ Response: {} """ -@credit_router.get('/defs/{api_credit_group}', - description='Get a credit definition', - response_model=ResponseModel, -) +@credit_router.get( + '/defs/{api_credit_group}', description='Get a credit definition', response_model=ResponseModel +) async def get_credit_definition(api_credit_group: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): - return respond_rest(ResponseModel( - status_code=403, - error_code='CRD002', - error_message='Unable to retrieve credits' - )) + return respond_rest( + ResponseModel( + status_code=403, error_code='CRD002', error_message='Unable to retrieve credits' + ) + ) return respond_rest(await CreditService.get_credit_def(api_credit_group, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Create a credit definition @@ -117,7 +127,9 @@ Response: {} """ -@credit_router.post('', + +@credit_router.post( + '', description='Create a credit definition', response_model=ResponseModel, responses={ @@ -125,22 +137,21 @@ Response: 'description': 'Successful Response', 'content': { 'application/json': { - 'example': { - 'message': 'Credit definition created successfully' - } + 'example': {'message': 'Credit definition created successfully'} } - } + }, } - } + }, ) - async def create_credit(credit_data: CreditModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): return respond_rest( @@ -148,24 +159,35 @@ async def create_credit(credit_data: CreditModel, request: Request): status_code=403, error_code='CRD001', error_message='You do not have permission to manage credits', - )) + ) + ) result = await CreditService.create_credit(credit_data, request_id) - audit(request, actor=username, action='credit_def.create', target=credit_data.api_credit_group, status=result.get('status_code'), details=None, request_id=request_id) + audit( + request, + actor=username, + action='credit_def.create', + target=credit_data.api_credit_group, + status=result.get('status_code'), + details=None, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update a credit definition @@ -175,7 +197,9 @@ Response: {} """ -@credit_router.put('/{api_credit_group}', + +@credit_router.put( + '/{api_credit_group}', description='Update a credit definition', response_model=ResponseModel, responses={ @@ -183,22 +207,21 @@ Response: 'description': 'Successful Response', 'content': { 'application/json': { - 'example': { - 'message': 'Credit definition updated successfully' - } + 'example': {'message': 'Credit definition updated successfully'} } - } + }, } - } + }, ) - -async def update_credit(api_credit_group:str, credit_data: CreditModel, request: Request): +async def update_credit(api_credit_group: str, credit_data: CreditModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): return respond_rest( @@ -206,24 +229,35 @@ async def update_credit(api_credit_group:str, credit_data: CreditModel, request: status_code=403, error_code='CRD001', error_message='You do not have permission to manage credits', - )) + ) + ) result = await CreditService.update_credit(api_credit_group, credit_data, request_id) - audit(request, actor=username, action='credit_def.update', target=api_credit_group, status=result.get('status_code'), details=None, request_id=request_id) + audit( + request, + actor=username, + action='credit_def.update', + target=api_credit_group, + status=result.get('status_code'), + details=None, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete a credit definition @@ -233,7 +267,9 @@ Response: {} """ -@credit_router.delete('/{api_credit_group}', + +@credit_router.delete( + '/{api_credit_group}', description='Delete a credit definition', response_model=ResponseModel, responses={ @@ -241,22 +277,21 @@ Response: 'description': 'Successful Response', 'content': { 'application/json': { - 'example': { - 'message': 'Credit definition deleted successfully' - } + 'example': {'message': 'Credit definition deleted successfully'} } - } + }, } - } + }, ) - -async def delete_credit(api_credit_group:str, request: Request): +async def delete_credit(api_credit_group: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): return respond_rest( @@ -264,24 +299,35 @@ async def delete_credit(api_credit_group:str, request: Request): status_code=403, error_code='CRD001', error_message='You do not have permission to manage credits', - )) + ) + ) result = await CreditService.delete_credit(api_credit_group, request_id) - audit(request, actor=username, action='credit_def.delete', target=api_credit_group, status=result.get('status_code'), details=None, request_id=request_id) + audit( + request, + actor=username, + action='credit_def.delete', + target=api_credit_group, + status=result.get('status_code'), + details=None, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Add credits for a user @@ -291,30 +337,27 @@ Response: {} """ -@credit_router.post('/{username}', + +@credit_router.post( + '/{username}', description='Add credits for a user', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Credits saved successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Credits saved successfully'}}}, } - } + }, ) - async def add_user_credits(username: str, credit_data: UserCreditModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): return respond_rest( @@ -322,24 +365,35 @@ async def add_user_credits(username: str, credit_data: UserCreditModel, request: status_code=403, error_code='CRD001', error_message='You do not have permission to manage credits', - )) + ) + ) result = await CreditService.add_credits(username, credit_data, request_id) - audit(request, actor=username, action='user_credits.save', target=username, status=result.get('status_code'), details={'groups': list((credit_data.users_credits or {}).keys())}, request_id=request_id) + audit( + request, + actor=username, + action='user_credits.save', + target=username, + status=result.get('status_code'), + details={'groups': list((credit_data.users_credits or {}).keys())}, + request_id=request_id, + ) return respond_rest(result) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -349,18 +403,19 @@ Response: {} """ -@credit_router.get('/all', - description='Get all user credits', - response_model=List[UserCreditModel] -) -async def get_all_users_credits(request: Request, page: int = 1, page_size: int = 10, search: str = ''): +@credit_router.get('/all', description='Get all user credits', response_model=list[UserCreditModel]) +async def get_all_users_credits( + request: Request, page: int = 1, page_size: int = 10, search: str = '' +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_credits'): return respond_rest( @@ -368,22 +423,27 @@ async def get_all_users_credits(request: Request, page: int = 1, page_size: int status_code=403, error_code='CRD002', error_message='Unable to retrieve credits for all users', - )) - return respond_rest(await CreditService.get_all_credits(page, page_size, request_id, search=search)) + ) + ) + return respond_rest( + await CreditService.get_all_credits(page, page_size, request_id, search=search) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -393,38 +453,41 @@ Response: {} """ -@credit_router.get('/{username}', - description='Get credits for a user', - response_model=UserCreditModel -) +@credit_router.get( + '/{username}', description='Get credits for a user', response_model=UserCreditModel +) async def get_credits(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) - if not payload.get('sub') == username and not await platform_role_required_bool(payload.get('sub'), 'manage_credits'): + if not payload.get('sub') == username and not await platform_role_required_bool( + payload.get('sub'), 'manage_credits' + ): return respond_rest( ResponseModel( status_code=403, error_code='CRD003', error_message='Unable to retrieve credits for user', - )) + ) + ) return respond_rest(await CreditService.get_user_credits(username, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Rotate API key for a user @@ -434,30 +497,25 @@ Response: {} """ -@credit_router.post('/rotate-key', + +@credit_router.post( + '/rotate-key', description='Rotate API key for the authenticated user', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'api_key': '******************' - } - } - } + 'content': {'application/json': {'example': {'api_key': '******************'}}}, } - } + }, ) - async def rotate_key(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - + # Get group from body try: body = await request.json() @@ -466,29 +524,33 @@ async def rotate_key(request: Request): group = None if not group: - return respond_rest(ResponseModel( - status_code=400, - error_code='CRD020', - error_message='api_credit_group is required' - )) + return respond_rest( + ResponseModel( + status_code=400, + error_code='CRD020', + error_message='api_credit_group is required', + ) + ) - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + # No special role required, just authentication - + return respond_rest(await CreditService.rotate_api_key(username, group, request_id)) - + except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/dashboard_routes.py b/backend-services/routes/dashboard_routes.py index 2ec96d8..82b2228 100644 --- a/backend-services/routes/dashboard_routes.py +++ b/backend-services/routes/dashboard_routes.py @@ -4,18 +4,18 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from fastapi import APIRouter, Request -from typing import Dict, List -import uuid -import time import logging -from datetime import datetime, timedelta +import time +import uuid +from datetime import datetime + +from fastapi import APIRouter, Request from models.response_model import ResponseModel from utils.auth_util import auth_required -from utils.response_util import respond_rest -from utils.database import user_collection, api_collection, subscriptions_collection +from utils.database import api_collection, subscriptions_collection, user_collection from utils.metrics_util import metrics_store +from utils.response_util import respond_rest dashboard_router = APIRouter() logger = logging.getLogger('doorman.gateway') @@ -29,11 +29,8 @@ Response: {} """ -@dashboard_router.get('', - description='Get dashboard data', - response_model=ResponseModel -) +@dashboard_router.get('', description='Get dashboard data', response_model=ResponseModel) async def get_dashboard_data(request: Request): """Get dashboard statistics and data""" request_id = str(uuid.uuid4()) @@ -41,14 +38,16 @@ async def get_dashboard_data(request: Request): try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') total_users = user_collection.count_documents({'active': True}) total_apis = api_collection.count_documents({}) snap = metrics_store.snapshot('30d') - monthly_usage: Dict[str, int] = {} + monthly_usage: dict[str, int] = {} for pt in snap.get('series', []): try: ts = datetime.fromtimestamp(pt['timestamp']) @@ -61,15 +60,12 @@ async def get_dashboard_data(request: Request): for username, reqs in snap.get('top_users', [])[:5]: subs = subscriptions_collection.find_one({'username': username}) or {} subscribers = len(subs.get('apis', [])) if isinstance(subs.get('apis'), list) else 0 - active_users_list.append({ - 'username': username, - 'requests': f'{int(reqs):,}', - 'subscribers': subscribers - }) + active_users_list.append( + {'username': username, 'requests': f'{int(reqs):,}', 'subscribers': subscribers} + ) popular_apis = [] for api_key, reqs in snap.get('top_apis', [])[:10]: - try: name = api_key @@ -81,11 +77,9 @@ async def get_dashboard_data(request: Request): count += 1 except Exception: count = 0 - popular_apis.append({ - 'name': name, - 'requests': f'{int(reqs):,}', - 'subscribers': count - }) + popular_apis.append( + {'name': name, 'requests': f'{int(reqs):,}', 'subscribers': count} + ) except Exception: continue @@ -95,23 +89,27 @@ async def get_dashboard_data(request: Request): 'newApis': total_apis, 'monthlyUsage': monthly_usage, 'activeUsersList': active_users_list, - 'popularApis': popular_apis + 'popularApis': popular_apis, } - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=dashboard_data - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response=dashboard_data, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/demo_routes.py b/backend-services/routes/demo_routes.py index 0ee59f8..1ca80f5 100644 --- a/backend-services/routes/demo_routes.py +++ b/backend-services/routes/demo_routes.py @@ -3,17 +3,17 @@ Protected demo seeding routes for populating the running server with dummy data. Only available to users with 'manage_gateway' OR 'manage_credits'. """ -from fastapi import APIRouter, Request -from typing import Optional -import uuid -import time import logging +import time +import uuid + +from fastapi import APIRouter, Request from models.response_model import ResponseModel -from utils.response_util import respond_rest -from utils.role_util import platform_role_required_bool, is_admin_user from utils.auth_util import auth_required from utils.demo_seed_util import run_seed +from utils.response_util import respond_rest +from utils.role_util import is_admin_user demo_router = APIRouter() logger = logging.getLogger('doorman.gateway') @@ -27,38 +27,55 @@ Response: {} """ -@demo_router.post('/seed', - description='Seed the running server with demo data', - response_model=ResponseModel -) -async def demo_seed(request: Request, - users: int = 40, - apis: int = 15, - endpoints: int = 6, - groups: int = 8, - protos: int = 6, - logs: int = 1500, - seed: Optional[int] = None): +@demo_router.post( + '/seed', description='Seed the running server with demo data', response_model=ResponseModel +) +async def demo_seed( + request: Request, + users: int = 40, + apis: int = 15, + endpoints: int = 6, + groups: int = 8, + protos: int = 6, + logs: int = 1500, + seed: int | None = None, +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await is_admin_user(username): - return respond_rest(ResponseModel( - status_code=403, - error_code='DEMO001', - error_message='Permission denied to run seeder' - )) - res = run_seed(users=users, apis=apis, endpoints=endpoints, groups=groups, protos=protos, logs=logs, seed=seed) + return respond_rest( + ResponseModel( + status_code=403, + error_code='DEMO001', + error_message='Permission denied to run seeder', + ) + ) + res = run_seed( + users=users, + apis=apis, + endpoints=endpoints, + groups=groups, + protos=protos, + logs=logs, + seed=seed, + ) return respond_rest(ResponseModel(status_code=200, response=res, message='Seed completed')) except Exception as e: logger.error(f'{request_id} | Demo seed error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel(status_code=500, error_code='DEMO999', error_message='Failed to seed demo data')) + return respond_rest( + ResponseModel( + status_code=500, error_code='DEMO999', error_message='Failed to seed demo data' + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/endpoint_routes.py b/backend-services/routes/endpoint_routes.py index 4f22f6c..7cfe436 100644 --- a/backend-services/routes/endpoint_routes.py +++ b/backend-services/routes/endpoint_routes.py @@ -4,12 +4,13 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -from fastapi import APIRouter, Depends, Request, HTTPException -import uuid -import time import logging +import time +import uuid +from fastapi import APIRouter, HTTPException, Request + +from models.create_endpoint_model import CreateEndpointModel from models.create_endpoint_validation_model import CreateEndpointValidationModel from models.endpoint_model_response import EndpointModelResponse from models.endpoint_validation_model_response import EndpointValidationModelResponse @@ -18,9 +19,8 @@ from models.update_endpoint_model import UpdateEndpointModel from models.update_endpoint_validation_model import UpdateEndpointValidationModel from services.endpoint_service import EndpointService from utils.auth_util import auth_required -from models.create_endpoint_model import CreateEndpointModel -from utils.response_util import respond_rest, process_response -from utils.constants import Headers, Roles, ErrorCodes, Messages +from utils.constants import ErrorCodes, Headers, Messages, Roles +from utils.response_util import process_response, respond_rest from utils.role_util import platform_role_required_bool endpoint_router = APIRouter() @@ -36,57 +36,58 @@ Response: {} """ -@endpoint_router.post('', + +@endpoint_router.post( + '', description='Add endpoint', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Endpoint created successfully' - } - } - } + 'application/json': {'example': {'message': 'Endpoint created successfully'}} + }, } - } + }, ) - async def create_endpoint(endpoint_data: CreateEndpointModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='END010', - error_message='You do not have permission to create endpoints' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='END010', + error_message='You do not have permission to create endpoints', + ) + ) return respond_rest(await EndpointService.create_endpoint(endpoint_data, request_id)) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update endpoint @@ -96,57 +97,74 @@ Response: {} """ -@endpoint_router.put('/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}', + +@endpoint_router.put( + '/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}', description='Update endpoint', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Endpoint updated successfully' - } - } - } + 'application/json': {'example': {'message': 'Endpoint updated successfully'}} + }, } - } + }, ) - -async def update_endpoint(endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, endpoint_data: UpdateEndpointModel, request: Request): +async def update_endpoint( + endpoint_method: str, + api_name: str, + api_version: str, + endpoint_uri: str, + endpoint_data: UpdateEndpointModel, + request: Request, +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='END011', - error_message='You do not have permission to update endpoints' - )) - return respond_rest(await EndpointService.update_endpoint(endpoint_method, api_name, api_version, '/' + endpoint_uri, endpoint_data, request_id)) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='END011', + error_message='You do not have permission to update endpoints', + ) + ) + return respond_rest( + await EndpointService.update_endpoint( + endpoint_method, + api_name, + api_version, + '/' + endpoint_uri, + endpoint_data, + request_id, + ) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete endpoint @@ -156,57 +174,64 @@ Response: {} """ -@endpoint_router.delete('/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}', + +@endpoint_router.delete( + '/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}', description='Delete endpoint', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Endpoint deleted successfully' - } - } - } + 'application/json': {'example': {'message': 'Endpoint deleted successfully'}} + }, } - } + }, ) - -async def delete_endpoint(endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request): +async def delete_endpoint( + endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='END012', - error_message='You do not have permission to delete endpoints' - )) - return respond_rest(await EndpointService.delete_endpoint(endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id)) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='END012', + error_message='You do not have permission to delete endpoints', + ) + ) + return respond_rest( + await EndpointService.delete_endpoint( + endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id + ) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -216,36 +241,47 @@ Response: {} """ -@endpoint_router.get('/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}', - description='Get endpoint by API name, API version and endpoint uri', - response_model=EndpointModelResponse -) -async def get_endpoint(endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request): +@endpoint_router.get( + '/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}', + description='Get endpoint by API name, API version and endpoint uri', + response_model=EndpointModelResponse, +) +async def get_endpoint( + endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - return respond_rest(await EndpointService.get_endpoint(endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id)) + return respond_rest( + await EndpointService.get_endpoint( + endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id + ) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -255,36 +291,43 @@ Response: {} """ -@endpoint_router.get('/{api_name}/{api_version}', - description='Get all endpoints for an API', - response_model=List[EndpointModelResponse] -) +@endpoint_router.get( + '/{api_name}/{api_version}', + description='Get all endpoints for an API', + response_model=list[EndpointModelResponse], +) async def get_endpoints_by_name_version(api_name: str, api_version: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - return respond_rest(await EndpointService.get_endpoints_by_name_version(api_name, api_version, request_id)) + return respond_rest( + await EndpointService.get_endpoints_by_name_version(api_name, api_version, request_id) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Create a new endpoint validation @@ -294,7 +337,9 @@ Response: {} """ -@endpoint_router.post('/endpoint/validation', + +@endpoint_router.post( + '/endpoint/validation', description='Create a new endpoint validation', response_model=ResponseModel, responses={ @@ -302,49 +347,54 @@ Response: 'description': 'Successful Response', 'content': { 'application/json': { - 'example': { - 'message': 'Endpoint validation created successfully' - } + 'example': {'message': 'Endpoint validation created successfully'} } - } + }, } - } + }, ) - -async def create_endpoint_validation(endpoint_validation_data: CreateEndpointValidationModel, request: Request): +async def create_endpoint_validation( + endpoint_validation_data: CreateEndpointValidationModel, request: Request +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='END013', - error_message='You do not have permission to create endpoint validations' - )) - return respond_rest(await EndpointService.create_endpoint_validation(endpoint_validation_data, request_id)) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='END013', + error_message='You do not have permission to create endpoint validations', + ) + ) + return respond_rest( + await EndpointService.create_endpoint_validation(endpoint_validation_data, request_id) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -354,45 +404,56 @@ Response: {} """ -@endpoint_router.put('/endpoint/validation/{endpoint_id}', + +@endpoint_router.put( + '/endpoint/validation/{endpoint_id}', description='Update an endpoint validation by endpoint ID', - response_model=ResponseModel + response_model=ResponseModel, ) - -async def update_endpoint_validation(endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request): +async def update_endpoint_validation( + endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='END014', - error_message='You do not have permission to update endpoint validations' - )) - return respond_rest(await EndpointService.update_endpoint_validation(endpoint_id, endpoint_validation_data, request_id)) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='END014', + error_message='You do not have permission to update endpoint validations', + ) + ) + return respond_rest( + await EndpointService.update_endpoint_validation( + endpoint_id, endpoint_validation_data, request_id + ) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -402,45 +463,52 @@ Response: {} """ -@endpoint_router.delete('/endpoint/validation/{endpoint_id}', - description='Delete an endpoint validation by endpoint ID', - response_model=ResponseModel -) +@endpoint_router.delete( + '/endpoint/validation/{endpoint_id}', + description='Delete an endpoint validation by endpoint ID', + response_model=ResponseModel, +) async def delete_endpoint_validation(endpoint_id: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='END015', - error_message='You do not have permission to delete endpoint validations' - )) - return respond_rest(await EndpointService.delete_endpoint_validation(endpoint_id, request_id)) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='END015', + error_message='You do not have permission to delete endpoint validations', + ) + ) + return respond_rest( + await EndpointService.delete_endpoint_validation(endpoint_id, request_id) + ) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -450,36 +518,41 @@ Response: {} """ -@endpoint_router.get('/endpoint/validation/{endpoint_id}', - description='Get an endpoint validation by endpoint ID', - response_model=EndpointValidationModelResponse -) +@endpoint_router.get( + '/endpoint/validation/{endpoint_id}', + description='Get an endpoint validation by endpoint ID', + response_model=EndpointValidationModelResponse, +) async def get_endpoint_validation(endpoint_id: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await EndpointService.get_endpoint_validation(endpoint_id, request_id)) except HTTPException as he: raise he except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -489,13 +562,18 @@ Response: {} """ -@endpoint_router.post('/validation', - description='Create a new endpoint validation (alias)', - response_model=ResponseModel) -async def create_endpoint_validation_alias(endpoint_validation_data: CreateEndpointValidationModel, request: Request): +@endpoint_router.post( + '/validation', + description='Create a new endpoint validation (alias)', + response_model=ResponseModel, +) +async def create_endpoint_validation_alias( + endpoint_validation_data: CreateEndpointValidationModel, request: Request +): return await create_endpoint_validation(endpoint_validation_data, request) + """ Endpoint @@ -505,13 +583,18 @@ Response: {} """ -@endpoint_router.put('/validation/{endpoint_id}', - description='Update endpoint validation by endpoint ID (alias)', - response_model=ResponseModel) -async def update_endpoint_validation_alias(endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request): +@endpoint_router.put( + '/validation/{endpoint_id}', + description='Update endpoint validation by endpoint ID (alias)', + response_model=ResponseModel, +) +async def update_endpoint_validation_alias( + endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request +): return await update_endpoint_validation(endpoint_id, endpoint_validation_data, request) + """ Endpoint @@ -521,13 +604,16 @@ Response: {} """ -@endpoint_router.delete('/validation/{endpoint_id}', - description='Delete endpoint validation by endpoint ID (alias)', - response_model=ResponseModel) +@endpoint_router.delete( + '/validation/{endpoint_id}', + description='Delete endpoint validation by endpoint ID (alias)', + response_model=ResponseModel, +) async def delete_endpoint_validation_alias(endpoint_id: str, request: Request): return await delete_endpoint_validation(endpoint_id, request) + """ Endpoint @@ -537,9 +623,11 @@ Response: {} """ -@endpoint_router.get('/validation/{endpoint_id}', - description='Get endpoint validation by endpoint ID (alias)', - response_model=EndpointValidationModelResponse) +@endpoint_router.get( + '/validation/{endpoint_id}', + description='Get endpoint validation by endpoint ID (alias)', + response_model=EndpointValidationModelResponse, +) async def get_endpoint_validation_alias(endpoint_id: str, request: Request): return await get_endpoint_validation(endpoint_id, request) diff --git a/backend-services/routes/gateway_routes.py b/backend-services/routes/gateway_routes.py index 5d987f3..4840b2b 100644 --- a/backend-services/routes/gateway_routes.py +++ b/backend-services/routes/gateway_routes.py @@ -4,30 +4,36 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from fastapi import APIRouter, HTTPException, Request, Depends -import os -import uuid -import time -import logging import json +import logging import re +import time +import uuid from datetime import datetime +from fastapi import APIRouter, Depends, HTTPException, Request + from models.response_model import ResponseModel +from services.gateway_service import GatewayService from utils import api_util -from utils.doorman_cache_util import doorman_cache -from utils.limit_throttle_util import limit_and_throttle -from utils.bandwidth_util import enforce_pre_request_limit +from utils.audit_util import audit from utils.auth_util import auth_required +from utils.bandwidth_util import enforce_pre_request_limit +from utils.doorman_cache_util import doorman_cache from utils.group_util import group_required +from utils.health_check_util import ( + check_mongodb, + check_redis, + get_active_connections, + get_memory_usage, + get_uptime, +) +from utils.ip_policy_util import enforce_api_ip_policy +from utils.limit_throttle_util import limit_and_throttle from utils.response_util import process_response from utils.role_util import platform_role_required_bool from utils.subscription_util import subscription_required -from utils.health_check_util import check_mongodb, check_redis, get_memory_usage, get_active_connections, get_uptime -from services.gateway_service import GatewayService from utils.validation_util import validation_util -from utils.audit_util import audit -from utils.ip_policy_util import enforce_api_ip_policy gateway_router = APIRouter() @@ -42,10 +48,13 @@ Response: {} """ -@gateway_router.api_route('/status', methods=['GET'], - description='Gateway status (requires manage_gateway)', - response_model=ResponseModel) +@gateway_router.api_route( + '/status', + methods=['GET'], + description='Gateway status (requires manage_gateway)', + response_model=ResponseModel, +) async def status(request: Request): """Restricted status endpoint. @@ -57,53 +66,67 @@ async def status(request: Request): payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='GTW013', - error_message='Forbidden' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='GTW013', + error_message='Forbidden', + ).dict(), + 'rest', + ) mongodb_status = await check_mongodb() redis_status = await check_redis() memory_usage = get_memory_usage() active_connections = get_active_connections() uptime = get_uptime() - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={ - 'status': 'online', - 'mongodb': mongodb_status, - 'redis': redis_status, - 'memory_usage': memory_usage, - 'active_connections': active_connections, - 'uptime': uptime - } - ).dict(), 'rest') - except Exception as e: - if hasattr(e, 'status_code') and getattr(e, 'status_code') == 401: - return process_response(ResponseModel( - status_code=401, + return process_response( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, - error_code='GTW401', - error_message='Unauthorized' - ).dict(), 'rest') + response={ + 'status': 'online', + 'mongodb': mongodb_status, + 'redis': redis_status, + 'memory_usage': memory_usage, + 'active_connections': active_connections, + 'uptime': uptime, + }, + ).dict(), + 'rest', + ) + except Exception as e: + if hasattr(e, 'status_code') and e.status_code == 401: + return process_response( + ResponseModel( + status_code=401, + response_headers={'request_id': request_id}, + error_code='GTW401', + error_message='Unauthorized', + ).dict(), + 'rest', + ) logger.error(f'{request_id} | Status check failed: {str(e)}') - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW006', - error_message='Internal server error' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW006', + error_message='Internal server error', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Status check time {end_time - start_time}ms') + @gateway_router.get('/health', description='Public health probe', include_in_schema=False) async def health(): return {'status': 'online'} + """ Clear all caches @@ -122,26 +145,20 @@ Response: {} """ -@gateway_router.api_route('/caches', methods=['DELETE', 'OPTIONS'], + +@gateway_router.api_route( + '/caches', + methods=['DELETE', 'OPTIONS'], description='Clear all caches', response_model=ResponseModel, - dependencies=[ - Depends(auth_required) - ], + dependencies=[Depends(auth_required)], responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'All caches cleared' - } - } - } + 'content': {'application/json': {'example': {'message': 'All caches cleared'}}}, } - } + }, ) - async def clear_all_caches(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -149,35 +166,53 @@ async def clear_all_caches(request: Request): payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='GTW008', - error_message='You do not have permission to clear caches' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='GTW008', + error_message='You do not have permission to clear caches', + ).dict(), + 'rest', + ) doorman_cache.clear_all_caches() try: from utils.limit_throttle_util import reset_counters as _reset_rate + _reset_rate() except Exception: pass - audit(request, actor=username, action='gateway.clear_caches', target='all', status='success', details=None) - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message='All caches cleared' - ).dict(), 'rest') - except Exception as e: - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + audit( + request, + actor=username, + action='gateway.clear_caches', + target='all', + status='success', + details=None, + ) + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='All caches cleared', + ).dict(), + 'rest', + ) + except Exception: + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Clear caches took {end_time - start_time:.2f}ms') + """ Endpoint @@ -196,15 +231,22 @@ Response: {} """ -@gateway_router.api_route('/rest/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'], + +@gateway_router.api_route( + '/rest/{path:path}', + methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'], description='REST gateway endpoint', response_model=ResponseModel, - include_in_schema=False) + include_in_schema=False, +) async def gateway(request: Request, path: str): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: - parts = [p for p in (path or '').split('/') if p] api_public = False api_auth_required = True @@ -216,25 +258,45 @@ async def gateway(request: Request, path: str): try: enforce_api_ip_policy(request, resolved_api) except HTTPException as e: - return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + error_code=e.detail, + error_message='IP restricted', + ).dict(), + 'rest', + ) endpoint_uri = '/' + '/'.join(parts[2:]) if len(parts) > 2 else '/' try: endpoints = await api_util.get_api_endpoints(resolved_api.get('api_id')) import re as _re + regex_pattern = _re.compile(r'\{[^/]+\}') - method_to_match = 'GET' if str(request.method).upper() == 'HEAD' else request.method + method_to_match = ( + 'GET' if str(request.method).upper() == 'HEAD' else request.method + ) composite = method_to_match + endpoint_uri - if not any(_re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in (endpoints or [])): - return process_response(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_code='GTW003', - error_message='Endpoint does not exist for the requested API' - ).dict(), 'rest') + if not any( + _re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) + for ep in (endpoints or []) + ): + return process_response( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_code='GTW003', + error_message='Endpoint does not exist for the requested API', + ).dict(), + 'rest', + ) except Exception: pass api_public = bool(resolved_api.get('api_public')) if resolved_api else False - api_auth_required = bool(resolved_api.get('api_auth_required')) if resolved_api and resolved_api.get('api_auth_required') is not None else True + api_auth_required = ( + bool(resolved_api.get('api_auth_required')) + if resolved_api and resolved_api.get('api_auth_required') is not None + else True + ) username = None if not api_public: if api_auth_required: @@ -245,59 +307,104 @@ async def gateway(request: Request, path: str): username = payload.get('sub') await enforce_pre_request_limit(request, username) else: - pass - logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms") - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms' + ) + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - return process_response(await GatewayService.rest_gateway(username, request, request_id, start_time, path), 'rest') + return process_response( + await GatewayService.rest_gateway(username, request, request_id, start_time, path), + 'rest', + ) except HTTPException as e: - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={ - 'request_id': request_id - }, - error_code=e.detail, - error_message=e.detail - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code=e.detail, + error_message=e.detail, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') -@gateway_router.get('/rest/{path:path}', description='REST gateway endpoint (GET)', response_model=ResponseModel, operation_id='rest_get') + +@gateway_router.get( + '/rest/{path:path}', + description='REST gateway endpoint (GET)', + response_model=ResponseModel, + operation_id='rest_get', +) async def rest_get(request: Request, path: str): return await gateway(request, path) -@gateway_router.post('/rest/{path:path}', description='REST gateway endpoint (POST)', response_model=ResponseModel, operation_id='rest_post') + +@gateway_router.post( + '/rest/{path:path}', + description='REST gateway endpoint (POST)', + response_model=ResponseModel, + operation_id='rest_post', +) async def rest_post(request: Request, path: str): return await gateway(request, path) -@gateway_router.put('/rest/{path:path}', description='REST gateway endpoint (PUT)', response_model=ResponseModel, operation_id='rest_put') + +@gateway_router.put( + '/rest/{path:path}', + description='REST gateway endpoint (PUT)', + response_model=ResponseModel, + operation_id='rest_put', +) async def rest_put(request: Request, path: str): return await gateway(request, path) -@gateway_router.patch('/rest/{path:path}', description='REST gateway endpoint (PATCH)', response_model=ResponseModel, operation_id='rest_patch') + +@gateway_router.patch( + '/rest/{path:path}', + description='REST gateway endpoint (PATCH)', + response_model=ResponseModel, + operation_id='rest_patch', +) async def rest_patch(request: Request, path: str): return await gateway(request, path) -@gateway_router.delete('/rest/{path:path}', description='REST gateway endpoint (DELETE)', response_model=ResponseModel, operation_id='rest_delete') + +@gateway_router.delete( + '/rest/{path:path}', + description='REST gateway endpoint (DELETE)', + response_model=ResponseModel, + operation_id='rest_delete', +) async def rest_delete(request: Request, path: str): return await gateway(request, path) -@gateway_router.head('/rest/{path:path}', description='REST gateway endpoint (HEAD)', response_model=ResponseModel, operation_id='rest_head') + +@gateway_router.head( + '/rest/{path:path}', + description='REST gateway endpoint (HEAD)', + response_model=ResponseModel, + operation_id='rest_head', +) async def rest_head(request: Request, path: str): return await gateway(request, path) + """ Endpoint @@ -316,16 +423,22 @@ Response: {} """ -@gateway_router.api_route('/rest/{path:path}', methods=['OPTIONS'], - description='REST gateway CORS preflight', include_in_schema=False) +@gateway_router.api_route( + '/rest/{path:path}', + methods=['OPTIONS'], + description='REST gateway CORS preflight', + include_in_schema=False, +) async def rest_preflight(request: Request, path: str): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: + import os as _os + from utils import api_util as _api_util from utils.doorman_cache_util import doorman_cache as _cache - import os as _os + parts = [p for p in (path or '').split('/') if p] name_ver = '' if len(parts) >= 2 and parts[1].startswith('v') and parts[1][1:].isdigit(): @@ -334,15 +447,17 @@ async def rest_preflight(request: Request, path: str): api = await _api_util.get_api(api_key, name_ver) if not api: from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers={'request_id': request_id}) # Optionally enforce 405 for unregistered endpoints when requested try: - if _os.getenv('STRICT_OPTIONS_405', 'false').lower() in ('1','true','yes','on'): + if _os.getenv('STRICT_OPTIONS_405', 'false').lower() in ('1', 'true', 'yes', 'on'): endpoint_uri = '/' + '/'.join(parts[2:]) if len(parts) > 2 else '/' try: endpoints = await _api_util.get_api_endpoints(api.get('api_id')) import re as _re + regex_pattern = _re.compile(r'\{[^/]+\}') # For preflight, only care that the endpoint exists for any method exists = any( @@ -356,15 +471,22 @@ async def rest_preflight(request: Request, path: str): ) if not exists: from fastapi.responses import Response as StarletteResponse - return StarletteResponse(status_code=405, headers={'request_id': request_id}) + + return StarletteResponse( + status_code=405, headers={'request_id': request_id} + ) except Exception: pass except Exception: pass origin = request.headers.get('origin') or request.headers.get('Origin') - req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method') - req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers') + req_method = request.headers.get('access-control-request-method') or request.headers.get( + 'Access-Control-Request-Method' + ) + req_headers = request.headers.get('access-control-request-headers') or request.headers.get( + 'Access-Control-Request-Headers' + ) ok, headers = GatewayService._compute_api_cors_headers(api, origin, req_method, req_headers) # Deterministic: always decide ACAO here from API config, regardless of computation above. # 1) Remove any existing ACAO/Vary from computed headers @@ -385,14 +507,17 @@ async def rest_preflight(request: Request, path: str): pass headers = {**(headers or {}), 'request_id': request_id} from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers=headers) except Exception: from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers={'request_id': request_id}) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -411,12 +536,19 @@ Response: {} """ -@gateway_router.api_route('/soap/{path:path}', methods=['POST'], - description='SOAP gateway endpoint', - response_model=ResponseModel) +@gateway_router.api_route( + '/soap/{path:path}', + methods=['POST'], + description='SOAP gateway endpoint', + response_model=ResponseModel, +) async def soap_gateway(request: Request, path: str): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: parts = [p for p in (path or '').split('/') if p] @@ -426,12 +558,23 @@ async def soap_gateway(request: Request, path: str): api_key = doorman_cache.get_cache('api_id_cache', f'/{parts[0]}/{parts[1]}') api = await api_util.get_api(api_key, f'/{parts[0]}/{parts[1]}') api_public = bool(api.get('api_public')) if api else False - api_auth_required = bool(api.get('api_auth_required')) if api and api.get('api_auth_required') is not None else True + api_auth_required = ( + bool(api.get('api_auth_required')) + if api and api.get('api_auth_required') is not None + else True + ) if api: try: enforce_api_ip_policy(request, api) except HTTPException as e: - return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'soap') + return process_response( + ResponseModel( + status_code=e.status_code, + error_code=e.detail, + error_message='IP restricted', + ).dict(), + 'soap', + ) username = None if not api_public: if api_auth_required: @@ -443,33 +586,43 @@ async def soap_gateway(request: Request, path: str): await enforce_pre_request_limit(request, username) else: pass - logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms") - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms' + ) + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - return process_response(await GatewayService.soap_gateway(username, request, request_id, start_time, path), 'soap') + return process_response( + await GatewayService.soap_gateway(username, request, request_id, start_time, path), + 'soap', + ) except HTTPException as e: - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={ - 'request_id': request_id - }, - error_code=e.detail, - error_message=e.detail - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code=e.detail, + error_message=e.detail, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'soap') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'soap', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -488,15 +641,20 @@ Response: {} """ -@gateway_router.api_route('/soap/{path:path}', methods=['OPTIONS'], - description='SOAP gateway CORS preflight', include_in_schema=False) +@gateway_router.api_route( + '/soap/{path:path}', + methods=['OPTIONS'], + description='SOAP gateway CORS preflight', + include_in_schema=False, +) async def soap_preflight(request: Request, path: str): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: from utils import api_util as _api_util from utils.doorman_cache_util import doorman_cache as _cache + parts = [p for p in (path or '').split('/') if p] name_ver = '' if len(parts) >= 2 and parts[1].startswith('v') and parts[1][1:].isdigit(): @@ -505,10 +663,15 @@ async def soap_preflight(request: Request, path: str): api = await _api_util.get_api(api_key, name_ver) if not api: from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers={'request_id': request_id}) origin = request.headers.get('origin') or request.headers.get('Origin') - req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method') - req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers') + req_method = request.headers.get('access-control-request-method') or request.headers.get( + 'Access-Control-Request-Method' + ) + req_headers = request.headers.get('access-control-request-headers') or request.headers.get( + 'Access-Control-Request-Headers' + ) ok, headers = GatewayService._compute_api_cors_headers(api, origin, req_method, req_headers) if not ok and headers: try: @@ -518,14 +681,17 @@ async def soap_preflight(request: Request, path: str): pass headers = {**(headers or {}), 'request_id': request_id} from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers=headers) except Exception: from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers={'request_id': request_id}) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -544,27 +710,49 @@ Response: {} """ -@gateway_router.api_route('/graphql/{path:path}', methods=['POST'], - description='GraphQL gateway endpoint', - response_model=ResponseModel) +@gateway_router.api_route( + '/graphql/{path:path}', + methods=['POST'], + description='GraphQL gateway endpoint', + response_model=ResponseModel, +) async def graphql_gateway(request: Request, path: str): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: if not request.headers.get('X-API-Version'): raise HTTPException(status_code=400, detail='X-API-Version header is required') - api_name = re.sub(r'^.*/', '',request.url.path) - api_key = doorman_cache.get_cache('api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0')) - api = await api_util.get_api(api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0')) + api_name = re.sub(r'^.*/', '', request.url.path) + api_key = doorman_cache.get_cache( + 'api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0') + ) + api = await api_util.get_api( + api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0') + ) if api: try: enforce_api_ip_policy(request, api) except HTTPException as e: - return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'graphql') + return process_response( + ResponseModel( + status_code=e.status_code, + error_code=e.detail, + error_message='IP restricted', + ).dict(), + 'graphql', + ) api_public = bool(api.get('api_public')) if api else False - api_auth_required = bool(api.get('api_auth_required')) if api and api.get('api_auth_required') is not None else True + api_auth_required = ( + bool(api.get('api_auth_required')) + if api and api.get('api_auth_required') is not None + else True + ) username = None if not api_public: if api_auth_required: @@ -576,8 +764,12 @@ async def graphql_gateway(request: Request, path: str): await enforce_pre_request_limit(request, username) else: pass - logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms") - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms' + ) + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if api and api.get('validation_enabled'): body = await request.json() @@ -586,36 +778,45 @@ async def graphql_gateway(request: Request, path: str): try: await validation_util.validate_graphql_request(api.get('api_id'), query, variables) except Exception as e: - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='GTW011', - error_message=str(e) - ).dict(), 'graphql') - return process_response(await GatewayService.graphql_gateway(username, request, request_id, start_time, path), 'graphql') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='GTW011', + error_message=str(e), + ).dict(), + 'graphql', + ) + return process_response( + await GatewayService.graphql_gateway(username, request, request_id, start_time, path), + 'graphql', + ) except HTTPException as e: - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={ - 'request_id': request_id - }, - error_code=e.detail, - error_message=e.detail - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code=e.detail, + error_message=e.detail, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'graphql') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'graphql', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -634,15 +835,24 @@ Response: {} """ -@gateway_router.api_route('/graphql/{path:path}', methods=['OPTIONS'], - description='GraphQL gateway CORS preflight', include_in_schema=False) +@gateway_router.api_route( + '/graphql/{path:path}', + methods=['OPTIONS'], + description='GraphQL gateway CORS preflight', + include_in_schema=False, +) async def graphql_preflight(request: Request, path: str): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: from utils import api_util as _api_util from utils.doorman_cache_util import doorman_cache as _cache + api_name = path.replace('graphql/', '') api_version = request.headers.get('X-API-Version', 'v1') api_path = f'/{api_name}/{api_version}' @@ -650,10 +860,15 @@ async def graphql_preflight(request: Request, path: str): api = await _api_util.get_api(api_key, f'{api_name}/{api_version}') if not api: from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers={'request_id': request_id}) origin = request.headers.get('origin') or request.headers.get('Origin') - req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method') - req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers') + req_method = request.headers.get('access-control-request-method') or request.headers.get( + 'Access-Control-Request-Method' + ) + req_headers = request.headers.get('access-control-request-headers') or request.headers.get( + 'Access-Control-Request-Headers' + ) ok, headers = GatewayService._compute_api_cors_headers(api, origin, req_method, req_headers) if not ok and headers: try: @@ -663,14 +878,17 @@ async def graphql_preflight(request: Request, path: str): pass headers = {**(headers or {}), 'request_id': request_id} from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers=headers) except Exception: from fastapi.responses import Response as StarletteResponse + return StarletteResponse(status_code=204, headers={'request_id': request_id}) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -689,26 +907,48 @@ Response: {} """ -@gateway_router.api_route('/grpc/{path:path}', methods=['POST'], - description='gRPC gateway endpoint', - response_model=ResponseModel) +@gateway_router.api_route( + '/grpc/{path:path}', + methods=['POST'], + description='gRPC gateway endpoint', + response_model=ResponseModel, +) async def grpc_gateway(request: Request, path: str): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: if not request.headers.get('X-API-Version'): raise HTTPException(status_code=400, detail='X-API-Version header is required') api_name = re.sub(r'^.*/', '', request.url.path) - api_key = doorman_cache.get_cache('api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0')) - api = await api_util.get_api(api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0')) + api_key = doorman_cache.get_cache( + 'api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0') + ) + api = await api_util.get_api( + api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0') + ) if api: try: enforce_api_ip_policy(request, api) except HTTPException as e: - return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'grpc') + return process_response( + ResponseModel( + status_code=e.status_code, + error_code=e.detail, + error_message='IP restricted', + ).dict(), + 'grpc', + ) api_public = bool(api.get('api_public')) if api else False - api_auth_required = bool(api.get('api_auth_required')) if api and api.get('api_auth_required') is not None else True + api_auth_required = ( + bool(api.get('api_auth_required')) + if api and api.get('api_auth_required') is not None + else True + ) username = None if not api_public: if api_auth_required: @@ -720,8 +960,12 @@ async def grpc_gateway(request: Request, path: str): await enforce_pre_request_limit(request, username) else: pass - logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms") - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms' + ) + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if api and api.get('validation_enabled'): body = await request.json() @@ -729,40 +973,47 @@ async def grpc_gateway(request: Request, path: str): try: await validation_util.validate_grpc_request(api.get('api_id'), request_data) except Exception as e: - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='GTW011', - error_message=str(e) - ).dict(), 'grpc') - svc_resp = await GatewayService.grpc_gateway(username, request, request_id, start_time, path) + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='GTW011', + error_message=str(e), + ).dict(), + 'grpc', + ) + svc_resp = await GatewayService.grpc_gateway( + username, request, request_id, start_time, path + ) if not isinstance(svc_resp, dict): svc_resp = ResponseModel( status_code=500, response_headers={'request_id': request_id}, error_code='GTW006', - error_message='Internal server error' + error_message='Internal server error', ).dict() return process_response(svc_resp, 'grpc') except HTTPException as e: - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={ - 'request_id': request_id - }, - error_code=e.detail, - error_message=e.detail - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code=e.detail, + error_message=e.detail, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'grpc') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'grpc', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/group_routes.py b/backend-services/routes/group_routes.py index 25ac0b6..dce49e3 100644 --- a/backend-services/routes/group_routes.py +++ b/backend-services/routes/group_routes.py @@ -4,20 +4,20 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -from fastapi import APIRouter, Depends, Request -import uuid -import time import logging +import time +import uuid +from fastapi import APIRouter, Request + +from models.create_group_model import CreateGroupModel from models.group_model_response import GroupModelResponse from models.response_model import ResponseModel from models.update_group_model import UpdateGroupModel from services.group_service import GroupService from utils.auth_util import auth_required -from models.create_group_model import CreateGroupModel -from utils.response_util import respond_rest, process_response -from utils.constants import Headers, Roles, ErrorCodes, Messages, Defaults +from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles +from utils.response_util import process_response, respond_rest from utils.role_util import platform_role_required_bool group_router = APIRouter() @@ -33,55 +33,53 @@ Response: {} """ -@group_router.post('', + +@group_router.post( + '', description='Add group', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Group created successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Group created successfully'}}}, } - } + }, ) - async def create_group(api_data: CreateGroupModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_GROUPS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='GRP008', - error_message='You do not have permission to create groups' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='GRP008', + error_message='You do not have permission to create groups', + ) + ) return respond_rest(await GroupService.create_group(api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update group @@ -91,55 +89,53 @@ Response: {} """ -@group_router.put('/{group_name}', + +@group_router.put( + '/{group_name}', description='Update group', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Group updated successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Group updated successfully'}}}, } - } + }, ) - async def update_group(group_name: str, api_data: UpdateGroupModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_GROUPS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='GRP009', - error_message='You do not have permission to update groups' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='GRP009', + error_message='You do not have permission to update groups', + ) + ) return respond_rest(await GroupService.update_group(group_name, api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete group @@ -149,55 +145,53 @@ Response: {} """ -@group_router.delete('/{group_name}', + +@group_router.delete( + '/{group_name}', description='Delete group', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Group deleted successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Group deleted successfully'}}}, } - } + }, ) - async def delete_group(group_name: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_GROUPS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='GRP010', - error_message='You do not have permission to delete groups' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='GRP010', + error_message='You do not have permission to delete groups', + ) + ) return respond_rest(await GroupService.delete_group(group_name, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -207,34 +201,37 @@ Response: {} """ -@group_router.get('/all', - description='Get all groups', - response_model=List[GroupModelResponse] -) -async def get_groups(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): +@group_router.get('/all', description='Get all groups', response_model=list[GroupModelResponse]) +async def get_groups( + request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await GroupService.get_groups(page, page_size, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -244,30 +241,30 @@ Response: {} """ -@group_router.get('/{group_name}', - description='Get group', - response_model=GroupModelResponse -) +@group_router.get('/{group_name}', description='Get group', response_model=GroupModelResponse) async def get_group(group_name: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await GroupService.get_group(group_name, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/logging_routes.py b/backend-services/routes/logging_routes.py index c73a985..9c65222 100644 --- a/backend-services/routes/logging_routes.py +++ b/backend-services/routes/logging_routes.py @@ -4,19 +4,19 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List, Optional -from fastapi import APIRouter, Depends, Request, Query, HTTPException -from fastapi.responses import StreamingResponse -import uuid -import time -import logging import io +import logging +import time +import uuid + +from fastapi import APIRouter, HTTPException, Query, Request +from fastapi.responses import StreamingResponse from models.response_model import ResponseModel from services.logging_service import LoggingService from utils.auth_util import auth_required -from utils.response_util import respond_rest, process_response -from utils.constants import Headers, Roles, ErrorCodes, Messages +from utils.constants import ErrorCodes, Headers, Messages, Roles +from utils.response_util import process_response, respond_rest from utils.role_util import platform_role_required_bool logging_router = APIRouter() @@ -32,7 +32,9 @@ Response: {} """ -@logging_router.get('/logs', + +@logging_router.get( + '/logs', description='Get logs with filtering', response_model=ResponseModel, responses={ @@ -55,34 +57,33 @@ Response: 'response_time': '150.5', 'ip_address': '192.168.1.1', 'protocol': 'HTTP/1.1', - 'request_id': '123e4567-e89b-12d3-a456-426614174000' + 'request_id': '123e4567-e89b-12d3-a456-426614174000', } ], 'total': 100, - 'has_more': False + 'has_more': False, } } - } + }, } - } + }, ) - async def get_logs( request: Request, - start_date: Optional[str] = Query(None, description='Start date (YYYY-MM-DD)'), - end_date: Optional[str] = Query(None, description='End date (YYYY-MM-DD)'), - start_time: Optional[str] = Query(None, description='Start time (HH:MM)'), - end_time: Optional[str] = Query(None, description='End time (HH:MM)'), - user: Optional[str] = Query(None, description='Filter by user'), - endpoint: Optional[str] = Query(None, description='Filter by endpoint'), - request_id: Optional[str] = Query(None, description='Filter by request ID'), - method: Optional[str] = Query(None, description='Filter by HTTP method'), - ip_address: Optional[str] = Query(None, description='Filter by IP address'), - min_response_time: Optional[str] = Query(None, description='Minimum response time (ms)'), - max_response_time: Optional[str] = Query(None, description='Maximum response time (ms)'), - level: Optional[str] = Query(None, description='Filter by log level'), + start_date: str | None = Query(None, description='Start date (YYYY-MM-DD)'), + end_date: str | None = Query(None, description='End date (YYYY-MM-DD)'), + start_time: str | None = Query(None, description='Start time (HH:MM)'), + end_time: str | None = Query(None, description='End time (HH:MM)'), + user: str | None = Query(None, description='Filter by user'), + endpoint: str | None = Query(None, description='Filter by endpoint'), + request_id: str | None = Query(None, description='Filter by request ID'), + method: str | None = Query(None, description='Filter by HTTP method'), + ip_address: str | None = Query(None, description='Filter by IP address'), + min_response_time: str | None = Query(None, description='Minimum response time (ms)'), + max_response_time: str | None = Query(None, description='Maximum response time (ms)'), + level: str | None = Query(None, description='Filter by log level'), limit: int = Query(100, description='Number of logs to return', ge=1, le=1000), - offset: int = Query(0, description='Number of logs to skip', ge=0) + offset: int = Query(0, description='Number of logs to skip', ge=0), ): request_id_param = str(uuid.uuid4()) start_time_param = time.time() * 1000 @@ -90,18 +91,20 @@ async def get_logs( payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id_param} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id_param} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id_param} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.VIEW_LOGS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id_param - }, - error_code='LOG001', - error_message='You do not have permission to view logs' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id_param}, + error_code='LOG001', + error_message='You do not have permission to view logs', + ) + ) logging_service = LoggingService() result = await logging_service.get_logs( @@ -119,33 +122,32 @@ async def get_logs( level=level, limit=limit, offset=offset, - request_id_param=request_id_param + request_id_param=request_id_param, ) - return respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id_param - }, - response=result - )) + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id_param}, response=result + ) + ) except HTTPException as e: raise e except Exception as e: logger.critical(f'{request_id_param} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id_param - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id_param}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time_param = time.time() * 1000 logger.info(f'{request_id_param} | Total time: {str(end_time_param - start_time_param)}ms') + """ Endpoint @@ -155,60 +157,59 @@ Response: {} """ -@logging_router.get('/logs/files', - description='Get list of available log files', - response_model=ResponseModel -) +@logging_router.get( + '/logs/files', description='Get list of available log files', response_model=ResponseModel +) async def get_log_files(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.VIEW_LOGS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='LOG005', - error_message='You do not have permission to view log files' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='LOG005', + error_message='You do not have permission to view log files', + ) + ) logging_service = LoggingService() log_files = logging_service.get_available_log_files() - return respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id - }, - response={ - 'log_files': log_files, - 'count': len(log_files) - } - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={'log_files': log_files, 'count': len(log_files)}, + ) + ) except HTTPException as e: raise e except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Get log statistics for dashboard @@ -218,7 +219,9 @@ Response: {} """ -@logging_router.get('/logs/statistics', + +@logging_router.get( + '/logs/statistics', description='Get log statistics for dashboard', response_model=ResponseModel, responses={ @@ -235,69 +238,70 @@ Response: 'avg_response_time': 150.5, 'top_apis': [ {'name': 'customer', 'count': 500}, - {'name': 'orders', 'count': 300} + {'name': 'orders', 'count': 300}, ], 'top_users': [ {'name': 'john_doe', 'count': 200}, - {'name': 'jane_smith', 'count': 150} + {'name': 'jane_smith', 'count': 150}, ], 'top_endpoints': [ {'name': '/api/customer/v1/users', 'count': 100}, - {'name': '/api/orders/v1/orders', 'count': 80} - ] + {'name': '/api/orders/v1/orders', 'count': 80}, + ], } } - } + }, } - } + }, ) - async def get_log_statistics(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.VIEW_LOGS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='LOG002', - error_message='You do not have permission to view log statistics' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='LOG002', + error_message='You do not have permission to view log statistics', + ) + ) logging_service = LoggingService() statistics = await logging_service.get_log_statistics(request_id) - return respond_rest(ResponseModel( - status_code=200, - response_headers={ - 'request_id': request_id - }, - response=statistics - )) + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=statistics + ) + ) except HTTPException as e: raise e except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Export logs in various formats @@ -307,7 +311,9 @@ Response: {} """ -@logging_router.get('/logs/export', + +@logging_router.get( + '/logs/export', description='Export logs in various formats', response_model=ResponseModel, responses={ @@ -317,42 +323,44 @@ Response: 'application/json': { 'example': { 'format': 'json', - 'data': '[{\"timestamp\": \"2024-01-01T12:00:00\", \"level\": \"INFO\"}]', - 'filename': 'logs_export_20240101_120000.json' + 'data': '[{"timestamp": "2024-01-01T12:00:00", "level": "INFO"}]', + 'filename': 'logs_export_20240101_120000.json', } } - } + }, } - } + }, ) - async def export_logs( request: Request, format: str = Query('json', description='Export format (json, csv)'), - start_date: Optional[str] = Query(None, description='Start date (YYYY-MM-DD)'), - end_date: Optional[str] = Query(None, description='End date (YYYY-MM-DD)'), - user: Optional[str] = Query(None, description='Filter by user'), - api: Optional[str] = Query(None, description='Filter by API'), - endpoint: Optional[str] = Query(None, description='Filter by endpoint'), - level: Optional[str] = Query(None, description='Filter by log level') + start_date: str | None = Query(None, description='Start date (YYYY-MM-DD)'), + end_date: str | None = Query(None, description='End date (YYYY-MM-DD)'), + user: str | None = Query(None, description='Filter by user'), + api: str | None = Query(None, description='Filter by API'), + endpoint: str | None = Query(None, description='Filter by endpoint'), + level: str | None = Query(None, description='Filter by log level'), ): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.EXPORT_LOGS): - return process_response(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='LOG003', - error_message='You do not have permission to export logs' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='LOG003', + error_message='You do not have permission to export logs', + ).dict(), + 'rest', + ) logging_service = LoggingService() @@ -371,31 +379,33 @@ async def export_logs( start_date=start_date, end_date=end_date, filters=filters, - request_id=request_id + request_id=request_id, ) - return respond_rest(ResponseModel( - status_code=200, - response_headers={ - Headers.REQUEST_ID: request_id - }, - response=export_result - )) + return respond_rest( + ResponseModel( + status_code=200, + response_headers={Headers.REQUEST_ID: request_id}, + response=export_result, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Download logs as file @@ -405,47 +415,45 @@ Response: {} """ -@logging_router.get('/logs/download', + +@logging_router.get( + '/logs/download', description='Download logs as file', include_in_schema=False, responses={ - 200: { - 'description': 'File download', - 'content': { - 'application/json': {}, - 'text/csv': {} - } - } - } + 200: {'description': 'File download', 'content': {'application/json': {}, 'text/csv': {}}} + }, ) - async def download_logs( request: Request, format: str = Query('json', description='Export format (json, csv)'), - start_date: Optional[str] = Query(None, description='Start date (YYYY-MM-DD)'), - end_date: Optional[str] = Query(None, description='End date (YYYY-MM-DD)'), - user: Optional[str] = Query(None, description='Filter by user'), - api: Optional[str] = Query(None, description='Filter by API'), - endpoint: Optional[str] = Query(None, description='Filter by endpoint'), - level: Optional[str] = Query(None, description='Filter by log level') + start_date: str | None = Query(None, description='Start date (YYYY-MM-DD)'), + end_date: str | None = Query(None, description='End date (YYYY-MM-DD)'), + user: str | None = Query(None, description='Filter by user'), + api: str | None = Query(None, description='Filter by API'), + endpoint: str | None = Query(None, description='Filter by endpoint'), + level: str | None = Query(None, description='Filter by log level'), ): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.EXPORT_LOGS): - return process_response(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='LOG004', - error_message='You do not have permission to download logs' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='LOG004', + error_message='You do not have permission to download logs', + ).dict(), + 'rest', + ) logging_service = LoggingService() @@ -464,11 +472,11 @@ async def download_logs( start_date=start_date, end_date=end_date, filters=filters, - request_id=request_id + request_id=request_id, ) file_data = export_result['data'].encode('utf-8') - file_obj = io.BytesIO(file_data) + io.BytesIO(file_data) content_type = 'application/json' if format.lower() == 'json' else 'text/csv' @@ -476,21 +484,22 @@ async def download_logs( io.BytesIO(file_data), media_type=content_type, headers={ - 'Content-Disposition': f"attachment; filename={export_result['filename']}", - Headers.REQUEST_ID: request_id - } + 'Content-Disposition': f'attachment; filename={export_result["filename"]}', + Headers.REQUEST_ID: request_id, + }, ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/memory_routes.py b/backend-services/routes/memory_routes.py index 1fbedd5..01e46ab 100644 --- a/backend-services/routes/memory_routes.py +++ b/backend-services/routes/memory_routes.py @@ -2,26 +2,28 @@ Routes for dumping and restoring in-memory database state. """ -from fastapi import APIRouter, Request -from typing import Optional -from pydantic import BaseModel -import os -import uuid -import time import logging +import os +import time +import uuid + +from fastapi import APIRouter, Request +from pydantic import BaseModel -from utils.response_util import process_response from models.response_model import ResponseModel from utils.auth_util import auth_required -from utils.role_util import platform_role_required_bool from utils.database import database from utils.memory_dump_util import dump_memory_to_file, restore_memory_from_file +from utils.response_util import process_response +from utils.role_util import platform_role_required_bool memory_router = APIRouter() logger = logging.getLogger('doorman.gateway') + class DumpRequest(BaseModel): - path: Optional[str] = None + path: str | None = None + """ Endpoint @@ -32,68 +34,88 @@ Response: {} """ -@memory_router.post('/memory/dump', + +@memory_router.post( + '/memory/dump', description='Dump in-memory database to an encrypted file', response_model=ResponseModel, ) - -async def memory_dump(request: Request, body: Optional[DumpRequest] = None): +async def memory_dump(request: Request, body: DumpRequest | None = None): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_security'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='SEC003', - error_message='You do not have permission to perform memory dump' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SEC003', + error_message='You do not have permission to perform memory dump', + ).dict(), + 'rest', + ) if not database.memory_only: - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='MEM001', - error_message='Memory dump available only in memory-only mode' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='MEM001', + error_message='Memory dump available only in memory-only mode', + ).dict(), + 'rest', + ) if not os.getenv('MEM_ENCRYPTION_KEY'): - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='MEM002', - error_message='MEM_ENCRYPTION_KEY is not configured' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='MEM002', + error_message='MEM_ENCRYPTION_KEY is not configured', + ).dict(), + 'rest', + ) path = None if body and body.path: path = body.path dump_path = dump_memory_to_file(path) - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message='Memory dump created successfully', - response={'response': {'path': dump_path}} - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='Memory dump created successfully', + response={'response': {'path': dump_path}}, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + class RestoreRequest(BaseModel): - path: Optional[str] = None + path: str | None = None + """ Endpoint @@ -104,69 +126,90 @@ Response: {} """ -@memory_router.post('/memory/restore', + +@memory_router.post( + '/memory/restore', description='Restore in-memory database from an encrypted file', response_model=ResponseModel, ) - -async def memory_restore(request: Request, body: Optional[RestoreRequest] = None): +async def memory_restore(request: Request, body: RestoreRequest | None = None): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_security'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='SEC004', - error_message='You do not have permission to perform memory restore' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SEC004', + error_message='You do not have permission to perform memory restore', + ).dict(), + 'rest', + ) if not database.memory_only: - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='MEM001', - error_message='Memory restore available only in memory-only mode' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='MEM001', + error_message='Memory restore available only in memory-only mode', + ).dict(), + 'rest', + ) if not os.getenv('MEM_ENCRYPTION_KEY'): - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='MEM002', - error_message='MEM_ENCRYPTION_KEY is not configured' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='MEM002', + error_message='MEM_ENCRYPTION_KEY is not configured', + ).dict(), + 'rest', + ) path = None if body and body.path: path = body.path info = restore_memory_from_file(path) - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message='Memory restore completed', - response={'response': info} - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='Memory restore completed', + response={'response': info}, + ).dict(), + 'rest', + ) except FileNotFoundError as e: - return process_response(ResponseModel( - status_code=404, - response_headers={'request_id': request_id}, - error_code='MEM003', - error_message=str(e) - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, + response_headers={'request_id': request_id}, + error_code='MEM003', + error_message=str(e), + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/monitor_routes.py b/backend-services/routes/monitor_routes.py index 5d01460..46f046e 100644 --- a/backend-services/routes/monitor_routes.py +++ b/backend-services/routes/monitor_routes.py @@ -2,28 +2,31 @@ Routes to expose gateway metrics to the web client. """ -from fastapi import APIRouter, Request -from pydantic import BaseModel -import uuid -import time -import logging -import io import csv +import io +import logging +import time +import uuid + +from fastapi import APIRouter, Request from fastapi.responses import Response as FastAPIResponse +from pydantic import BaseModel from models.response_model import ResponseModel -from utils.response_util import process_response -from utils.metrics_util import metrics_store from services.logging_service import LoggingService from utils.auth_util import auth_required -from utils.role_util import platform_role_required_bool -from utils.health_check_util import check_mongodb, check_redis -from utils.doorman_cache_util import doorman_cache from utils.database import database +from utils.doorman_cache_util import doorman_cache +from utils.health_check_util import check_mongodb, check_redis +from utils.metrics_util import metrics_store +from utils.response_util import process_response +from utils.role_util import platform_role_required_bool + class LivenessResponse(BaseModel): status: str + class ReadinessResponse(BaseModel): status: str mongodb: bool | None = None @@ -31,6 +34,7 @@ class ReadinessResponse(BaseModel): mode: str | None = None cache_backend: str | None = None + monitor_router = APIRouter() logger = logging.getLogger('doorman.gateway') @@ -43,26 +47,32 @@ Response: {} """ -@monitor_router.get('/monitor/metrics', - description='Get aggregated gateway metrics', - response_model=ResponseModel, -) -async def get_metrics(request: Request, range: str = '24h', group: str = 'minute', sort: str = 'asc'): +@monitor_router.get( + '/monitor/metrics', description='Get aggregated gateway metrics', response_model=ResponseModel +) +async def get_metrics( + request: Request, range: str = '24h', group: str = 'minute', sort: str = 'asc' +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='MON001', - error_message='You do not have permission to view monitor metrics' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='MON001', + error_message='You do not have permission to view monitor metrics', + ).dict(), + 'rest', + ) grp = (group or 'minute').lower() if grp not in ('minute', 'day'): grp = 'minute' @@ -70,23 +80,28 @@ async def get_metrics(request: Request, range: str = '24h', group: str = 'minute if srt not in ('asc', 'desc'): srt = 'asc' snap = metrics_store.snapshot(range, group=grp, sort=srt) - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=snap - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=snap + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -96,12 +111,16 @@ Response: {} """ -@monitor_router.get('/monitor/liveness', + +@monitor_router.get( + '/monitor/liveness', description='Kubernetes liveness probe endpoint (no auth)', - response_model=LivenessResponse) + response_model=LivenessResponse, +) async def liveness(request: Request): return {'status': 'alive'} + """ Endpoint @@ -111,9 +130,12 @@ Response: {} """ -@monitor_router.get('/monitor/readiness', + +@monitor_router.get( + '/monitor/readiness', description='Kubernetes readiness probe endpoint. Detailed status requires manage_gateway permission.', - response_model=ReadinessResponse) + response_model=ReadinessResponse, +) async def readiness(request: Request): """Readiness probe endpoint. @@ -128,7 +150,9 @@ async def readiness(request: Request): try: payload = await auth_required(request) username = payload.get('sub') - authorized = await platform_role_required_bool(username, 'manage_gateway') if username else False + authorized = ( + await platform_role_required_bool(username, 'manage_gateway') if username else False + ) except Exception: authorized = False @@ -150,6 +174,7 @@ async def readiness(request: Request): except Exception: return {'status': 'degraded'} + """ Endpoint @@ -159,10 +184,12 @@ Response: {} """ -@monitor_router.get('/monitor/report', - description='Generate a CSV report for a date range (requires manage_gateway)', - include_in_schema=False) +@monitor_router.get( + '/monitor/report', + description='Generate a CSV report for a date range (requires manage_gateway)', + include_in_schema=False, +) async def generate_report(request: Request, start: str, end: str): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -170,15 +197,19 @@ async def generate_report(request: Request, start: str, end: str): payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='MON002', - error_message='You do not have permission to generate reports' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='MON002', + error_message='You do not have permission to generate reports', + ).dict(), + 'rest', + ) def _parse_ts(s: str) -> int: from datetime import datetime + fmt_variants = ['%Y-%m-%dT%H:%M', '%Y-%m-%d'] for fmt in fmt_variants: try: @@ -199,9 +230,11 @@ async def generate_report(request: Request, start: str, end: str): ls = LoggingService() import datetime as _dt + def _to_date_time(ts: int): dt = _dt.datetime.utcfromtimestamp(ts) return dt.strftime('%Y-%m-%d'), dt.strftime('%H:%M') + start_date, start_time_str = _to_date_time(start_ts) end_date, end_time_str = _to_date_time(end_ts) @@ -211,13 +244,17 @@ async def generate_report(request: Request, start: str, end: str): max_pages = 100 for _ in range(max_pages): batch = await ls.get_logs( - start_date=start_date, end_date=end_date, - start_time=start_time_str, end_time=end_time_str, - limit=page_limit, offset=offset, request_id_param=request_id + start_date=start_date, + end_date=end_date, + start_time=start_time_str, + end_time=end_time_str, + limit=page_limit, + offset=offset, + request_id_param=request_id, ) chunk = batch.get('logs', []) logs.extend(chunk) - total_batch = batch.get('total', 0) + batch.get('total', 0) offset += page_limit if not batch.get('has_more') or not chunk: break @@ -228,12 +265,13 @@ async def generate_report(request: Request, start: str, end: str): parts = ep.split('/') return f'rest:{parts[3]}' if len(parts) > 3 else 'rest:unknown' if ep.startswith('/api/graphql/'): - return f"graphql:{ep.split('/')[-1] or 'unknown'}" + return f'graphql:{ep.split("/")[-1] or "unknown"}' if ep.startswith('/api/soap/'): - return f"soap:{ep.split('/')[-1] or 'unknown'}" + return f'soap:{ep.split("/")[-1] or "unknown"}' return 'platform' except Exception: return 'unknown' + total = 0 errors = 0 total_ms = 0.0 @@ -244,17 +282,20 @@ async def generate_report(request: Request, start: str, end: str): for e in logs: ep = str(e.get('endpoint') or '') if not ep: - continue total += 1 status_code = None try: - status_code = int(e.get('status_code')) if e.get('status_code') is not None else None + status_code = ( + int(e.get('status_code')) if e.get('status_code') is not None else None + ) except Exception: status_code = None level = (e.get('level') or '').upper() - is_error = (status_code is not None and status_code >= 400) or (level not in ('INFO', 'DEBUG')) + is_error = (status_code is not None and status_code >= 400) or ( + level not in ('INFO', 'DEBUG') + ) if is_error: errors += 1 if status_code is not None: @@ -292,6 +333,7 @@ async def generate_report(request: Request, start: str, end: str): total_bytes_in = sum(getattr(b, 'bytes_in', 0) for b in sel) total_bytes_out = sum(getattr(b, 'bytes_out', 0) for b in sel) from collections import defaultdict + daily_bw = defaultdict(lambda: {'in': 0, 'out': 0}) for b in sel: day_ts = int((b.start_ts // 86400) * 86400) @@ -308,7 +350,9 @@ async def generate_report(request: Request, start: str, end: str): w.writerow(['total_requests', total]) w.writerow(['total_errors', errors]) w.writerow(['successes', max(total - errors, 0)]) - w.writerow(['success_rate', f'{(0 if total == 0 else (100.0 * (total - errors) / total)):.2f}%']) + w.writerow( + ['success_rate', f'{(0 if total == 0 else (100.0 * (total - errors) / total)):.2f}%'] + ) w.writerow(['avg_response_ms', f'{avg_ms:.2f}']) w.writerow([]) w.writerow(['Bandwidth Overview']) @@ -341,7 +385,6 @@ async def generate_report(request: Request, start: str, end: str): w.writerow(['Bandwidth (per day, UTC)']) w.writerow(['date', 'bytes_in', 'bytes_out', 'total']) for day_ts in sorted(daily_bw.keys()): - import datetime as _dt date_str = _dt.datetime.utcfromtimestamp(day_ts).strftime('%Y-%m-%d') bi = int(daily_bw[day_ts]['in']) bo = int(daily_bw[day_ts]['out']) @@ -349,24 +392,32 @@ async def generate_report(request: Request, start: str, end: str): csv_bytes = buf.getvalue().encode('utf-8') filename = f'doorman_report_{start}_to_{end}.csv' - return FastAPIResponse(content=csv_bytes, media_type='text/csv', headers={ - 'Content-Disposition': f'attachment; filename={filename}' - }) + return FastAPIResponse( + content=csv_bytes, + media_type='text/csv', + headers={'Content-Disposition': f'attachment; filename={filename}'}, + ) except ValueError as ve: - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='MON003', - error_message=str(ve) - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='MON003', + error_message=str(ve), + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error in report: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/proto_routes.py b/backend-services/routes/proto_routes.py index cfa1946..cc481be 100644 --- a/backend-services/routes/proto_routes.py +++ b/backend-services/routes/proto_routes.py @@ -4,21 +4,21 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from fastapi import APIRouter, Depends, Request, UploadFile, File, HTTPException -from pathlib import Path +import logging import os import re -import logging -import uuid -import time -from datetime import datetime -import sys import subprocess +import sys +import time +import uuid +from pathlib import Path + +from fastapi import APIRouter, File, HTTPException, Request, UploadFile from models.response_model import ResponseModel from utils.auth_util import auth_required +from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles from utils.response_util import process_response -from utils.constants import Headers, Defaults, Roles, ErrorCodes, Messages from utils.role_util import platform_role_required_bool proto_router = APIRouter() @@ -26,6 +26,7 @@ logger = logging.getLogger('doorman.gateway') PROJECT_ROOT = Path(__file__).parent.resolve() + def sanitize_filename(filename: str): """Sanitize and validate filename with comprehensive security checks""" if not filename: @@ -55,10 +56,13 @@ def sanitize_filename(filename: str): safe_pattern = re.compile(r'^[a-zA-Z0-9_\-\.]+$') if not safe_pattern.match(sanitized): - raise ValueError('Filename contains invalid characters (use only letters, numbers, underscore, dash, dot)') + raise ValueError( + 'Filename contains invalid characters (use only letters, numbers, underscore, dash, dot)' + ) return sanitized + def validate_path(base_path: Path, target_path: Path): try: base_path = Path(os.path.realpath(base_path)) @@ -71,6 +75,7 @@ def validate_path(base_path: Path, target_path: Path): logger.error(f'Path validation error: {str(e)}') return False + def validate_proto_content(content: bytes, max_size: int = 1024 * 1024) -> str: """Validate proto file content for security and correctness""" if len(content) > max_size: @@ -84,20 +89,21 @@ def validate_proto_content(content: bytes, max_size: int = 1024 * 1024) -> str: except UnicodeDecodeError: raise ValueError('Invalid proto file: not valid UTF-8') - if 'syntax' not in content_str and 'message' not in content_str and 'service' not in content_str: + if ( + 'syntax' not in content_str + and 'message' not in content_str + and 'service' not in content_str + ): raise ValueError('Invalid proto file: missing proto syntax (syntax/message/service)') - suspicious_patterns = [ - r'`', - r'\$\(', - r';\s*(?:rm|mv|cp|chmod|cat|wget|curl)', - ] + suspicious_patterns = [r'`', r'\$\(', r';\s*(?:rm|mv|cp|chmod|cat|wget|curl)'] for pattern in suspicious_patterns: if re.search(pattern, content_str): raise ValueError('Invalid proto file: suspicious content detected') return content_str + def get_safe_proto_path(api_name: str, api_version: str): try: safe_api_name = sanitize_filename(api_name) @@ -108,19 +114,16 @@ def get_safe_proto_path(api_name: str, api_version: str): proto_dir.mkdir(exist_ok=True) generated_dir.mkdir(exist_ok=True) proto_path = (proto_dir / f'{key}.proto').resolve() - if not validate_path(PROJECT_ROOT, proto_path) or not validate_path(PROJECT_ROOT, generated_dir): + if not validate_path(PROJECT_ROOT, proto_path) or not validate_path( + PROJECT_ROOT, generated_dir + ): raise ValueError('Invalid path detected') return proto_path, generated_dir except ValueError as e: - raise HTTPException( - status_code=400, - detail=f'Path validation error: {str(e)}' - ) + raise HTTPException(status_code=400, detail=f'Path validation error: {str(e)}') except Exception as e: - raise HTTPException( - status_code=500, - detail=f'Failed to create safe paths: {str(e)}' - ) + raise HTTPException(status_code=500, detail=f'Failed to create safe paths: {str(e)}') + """ Upload proto file @@ -131,36 +134,43 @@ Response: {} """ -@proto_router.post('/{api_name}/{api_version}', + +@proto_router.post( + '/{api_name}/{api_version}', description='Upload proto file', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Proto file uploaded successfully' - } - } - } + 'application/json': {'example': {'message': 'Proto file uploaded successfully'}} + }, } - }) - -async def upload_proto_file(api_name: str, api_version: str, file: UploadFile = File(...), request: Request = None): + }, +) +async def upload_proto_file( + api_name: str, api_version: str, file: UploadFile = File(...), request: Request = None +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: - max_size = int(os.getenv(Defaults.MAX_MULTIPART_SIZE_BYTES_ENV, Defaults.MAX_MULTIPART_SIZE_BYTES_DEFAULT)) + max_size = int( + os.getenv( + Defaults.MAX_MULTIPART_SIZE_BYTES_ENV, Defaults.MAX_MULTIPART_SIZE_BYTES_DEFAULT + ) + ) cl = request.headers.get('content-length') if request else None try: if cl and int(cl) > max_size: - return process_response(ResponseModel( - status_code=413, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.REQUEST_TOO_LARGE, - error_message=Messages.FILE_TOO_LARGE - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=413, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.REQUEST_TOO_LARGE, + error_message=Messages.FILE_TOO_LARGE, + ).dict(), + 'rest', + ) except Exception: pass payload = await auth_required(request) @@ -168,20 +178,26 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile = logger.info(f'{request_id} | Username: {username}') logger.info(f'{request_id} | Endpoint: POST /proto/{api_name}/{api_version}') if not await platform_role_required_bool(username, Roles.MANAGE_APIS): - return process_response(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.AUTH_REQUIRED, - error_message=Messages.PERMISSION_MANAGE_APIS - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.AUTH_REQUIRED, + error_message=Messages.PERMISSION_MANAGE_APIS, + ).dict(), + 'rest', + ) original_name = file.filename or '' if not original_name.lower().endswith('.proto'): - return process_response(ResponseModel( - status_code=400, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.REQUEST_FILE_TYPE, - error_message=Messages.ONLY_PROTO_ALLOWED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.REQUEST_FILE_TYPE, + error_message=Messages.ONLY_PROTO_ALLOWED, + ).dict(), + 'rest', + ) proto_path, generated_dir = get_safe_proto_path(api_name, api_version) content = await file.read() @@ -189,40 +205,58 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile = max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024)) proto_content = validate_proto_content(content, max_size=max_proto_size) except ValueError as e: - return process_response(ResponseModel( - status_code=400, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.REQUEST_FILE_TYPE, - error_message=f'Invalid proto file: {str(e)}' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.REQUEST_FILE_TYPE, + error_message=f'Invalid proto file: {str(e)}', + ).dict(), + 'rest', + ) safe_api_name = sanitize_filename(api_name) safe_api_version = sanitize_filename(api_version) if 'package' in proto_content: - proto_content = re.sub(r'package\s+[^;]+;', f'package {safe_api_name}_{safe_api_version};', proto_content) + proto_content = re.sub( + r'package\s+[^;]+;', f'package {safe_api_name}_{safe_api_version};', proto_content + ) else: - proto_content = re.sub(r'syntax\s*=\s*"proto3";', f'syntax = "proto3";\n\npackage {safe_api_name}_{safe_api_version};', proto_content) + proto_content = re.sub( + r'syntax\s*=\s*"proto3";', + f'syntax = "proto3";\n\npackage {safe_api_name}_{safe_api_version};', + proto_content, + ) proto_path.write_text(proto_content) try: - subprocess.run([ - sys.executable, '-m', 'grpc_tools.protoc', - f'--proto_path={proto_path.parent}', - f'--python_out={generated_dir}', - f'--grpc_python_out={generated_dir}', - str(proto_path) - ], check=True) - logger.info(f"{request_id} | Proto compiled: src={proto_path} out={generated_dir}") + subprocess.run( + [ + sys.executable, + '-m', + 'grpc_tools.protoc', + f'--proto_path={proto_path.parent}', + f'--python_out={generated_dir}', + f'--grpc_python_out={generated_dir}', + str(proto_path), + ], + check=True, + ) + logger.info(f'{request_id} | Proto compiled: src={proto_path} out={generated_dir}') init_path = (generated_dir / '__init__.py').resolve() if not validate_path(generated_dir, init_path): raise ValueError('Invalid init path') if not init_path.exists(): init_path.write_text('"""Generated gRPC code."""\n') - pb2_grpc_file = (generated_dir / f'{safe_api_name}_{safe_api_version}_pb2_grpc.py').resolve() + pb2_grpc_file = ( + generated_dir / f'{safe_api_name}_{safe_api_version}_pb2_grpc.py' + ).resolve() if not validate_path(generated_dir, pb2_grpc_file): raise ValueError('Invalid grpc file path') if pb2_grpc_file.exists(): content = pb2_grpc_file.read_text() # Double-check sanitized values contain only safe characters before using in regex - if not re.match(r'^[a-zA-Z0-9_\-\.]+$', safe_api_name) or not re.match(r'^[a-zA-Z0-9_\-\.]+$', safe_api_version): + if not re.match(r'^[a-zA-Z0-9_\-\.]+$', safe_api_name) or not re.match( + r'^[a-zA-Z0-9_\-\.]+$', safe_api_version + ): raise ValueError('Invalid characters in sanitized API name or version') escaped_mod = re.escape(f'{safe_api_name}_{safe_api_version}_pb2') import_pattern = rf'^import {escaped_mod} as (.+)$' @@ -231,22 +265,34 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile = for i, line in enumerate(lines, 1): if 'import' in line and 'pb2' in line: logger.info(f'{request_id} | Line {i}: {repr(line)}') - new_content = re.sub(import_pattern, rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \1', content, flags=re.MULTILINE) + new_content = re.sub( + import_pattern, + rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \1', + content, + flags=re.MULTILINE, + ) if new_content != content: logger.info(f'{request_id} | Import fix applied successfully') pb2_grpc_file.write_text(new_content) - logger.info(f"{request_id} | Wrote fixed pb2_grpc at {pb2_grpc_file}") + logger.info(f'{request_id} | Wrote fixed pb2_grpc at {pb2_grpc_file}') pycache_dir = (generated_dir / '__pycache__').resolve() if not validate_path(generated_dir, pycache_dir): - logger.warning(f'{request_id} | Unsafe pycache path detected. Skipping cache cleanup.') + logger.warning( + f'{request_id} | Unsafe pycache path detected. Skipping cache cleanup.' + ) elif pycache_dir.exists(): - for pyc_file in pycache_dir.glob(f'{safe_api_name}_{safe_api_version}*.pyc'): + for pyc_file in pycache_dir.glob( + f'{safe_api_name}_{safe_api_version}*.pyc' + ): try: pyc_file.unlink() logger.info(f'{request_id} | Deleted cache file: {pyc_file.name}') except Exception as e: - logger.warning(f'{request_id} | Failed to delete cache file {pyc_file.name}: {e}') + logger.warning( + f'{request_id} | Failed to delete cache file {pyc_file.name}: {e}' + ) import sys as sys_import + pb2_module_name = f'{safe_api_name}_{safe_api_version}_pb2' pb2_grpc_module_name = f'{safe_api_name}_{safe_api_version}_pb2_grpc' if pb2_module_name in sys_import.modules: @@ -254,51 +300,78 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile = logger.info(f'{request_id} | Cleared {pb2_module_name} from sys.modules') if pb2_grpc_module_name in sys_import.modules: del sys_import.modules[pb2_grpc_module_name] - logger.info(f'{request_id} | Cleared {pb2_grpc_module_name} from sys.modules') + logger.info( + f'{request_id} | Cleared {pb2_grpc_module_name} from sys.modules' + ) else: - logger.warning(f'{request_id} | Import fix pattern did not match - no changes made') + logger.warning( + f'{request_id} | Import fix pattern did not match - no changes made' + ) try: # Reuse escaped_mod which was already validated above rel_pattern = rf'^from \\. import {escaped_mod} as (.+)$' content2 = pb2_grpc_file.read_text() - new2 = re.sub(rel_pattern, rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \\1', content2, flags=re.MULTILINE) + new2 = re.sub( + rel_pattern, + rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \\1', + content2, + flags=re.MULTILINE, + ) if new2 != content2: pb2_grpc_file.write_text(new2) - logger.info(f"{request_id} | Applied relative import rewrite for module {safe_api_name}_{safe_api_version}_pb2") + logger.info( + f'{request_id} | Applied relative import rewrite for module {safe_api_name}_{safe_api_version}_pb2' + ) except Exception as e: - logger.warning(f"{request_id} | Failed relative import rewrite: {e}") - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message='Proto file uploaded and gRPC code generated successfully' - ).dict(), 'rest') + logger.warning(f'{request_id} | Failed relative import rewrite: {e}') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='Proto file uploaded and gRPC code generated successfully', + ).dict(), + 'rest', + ) except subprocess.CalledProcessError as e: logger.error(f'{request_id} | Failed to generate gRPC code: {str(e)}') - return process_response(ResponseModel( + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.GRPC_GENERATION_FAILED, + error_message=f'{Messages.GRPC_GEN_FAILED}: {str(e)}', + ).dict(), + 'rest', + ) + except HTTPException as e: + logger.error(f'{request_id} | Path validation error: {str(e)}') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.PATH_VALIDATION, + error_message=str(e.detail), + ).dict(), + 'rest', + ) + except Exception as e: + logger.error( + f'{request_id} | Error uploading proto file: {type(e).__name__}: {str(e)}', + exc_info=True, + ) + return process_response( + ResponseModel( status_code=500, response_headers={Headers.REQUEST_ID: request_id}, error_code=ErrorCodes.GRPC_GENERATION_FAILED, - error_message=f'{Messages.GRPC_GEN_FAILED}: {str(e)}' - ).dict(), 'rest') - except HTTPException as e: - logger.error(f'{request_id} | Path validation error: {str(e)}') - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.PATH_VALIDATION, - error_message=str(e.detail) - ).dict(), 'rest') - except Exception as e: - logger.error(f'{request_id} | Error uploading proto file: {type(e).__name__}: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.GRPC_GENERATION_FAILED, - error_message=f'Failed to upload proto file: {str(e)}' - ).dict(), 'rest') + error_message=f'Failed to upload proto file: {str(e)}', + ).dict(), + 'rest', + ) finally: logger.info(f'{request_id} | Total time: {time.time() * 1000 - start_time}ms') + """ Get proto file @@ -308,23 +381,20 @@ Response: {} """ -@proto_router.get('/{api_name}/{api_version}', + +@proto_router.get( + '/{api_name}/{api_version}', description='Get proto file', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Proto file retrieved successfully' - } - } - } + 'application/json': {'example': {'message': 'Proto file retrieved successfully'}} + }, } - } + }, ) - async def get_proto_file(api_name: str, api_version: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -334,46 +404,62 @@ async def get_proto_file(api_name: str, api_version: str, request: Request): logger.info(f'{request_id} | Endpoint: {request.method} {request.url.path}') try: if not await platform_role_required_bool(username, Roles.MANAGE_APIS): - return process_response(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.AUTH_REQUIRED, - error_message=Messages.PERMISSION_MANAGE_APIS - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.AUTH_REQUIRED, + error_message=Messages.PERMISSION_MANAGE_APIS, + ).dict(), + 'rest', + ) proto_path, _ = get_safe_proto_path(api_name, api_version) if not proto_path.exists(): - return process_response(ResponseModel( - status_code=404, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.API_NOT_FOUND, - error_message=f'Proto file not found for API {api_name}/{api_version}' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.API_NOT_FOUND, + error_message=f'Proto file not found for API {api_name}/{api_version}', + ).dict(), + 'rest', + ) proto_content = proto_path.read_text() - return process_response(ResponseModel( - status_code=200, - response_headers={Headers.REQUEST_ID: request_id}, - message='Proto file retrieved successfully', - response={'content': proto_content} - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={Headers.REQUEST_ID: request_id}, + message='Proto file retrieved successfully', + response={'content': proto_content}, + ).dict(), + 'rest', + ) except HTTPException as e: logger.error(f'{request_id} | Path validation error: {str(e)}') - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.PATH_VALIDATION, - error_message=str(e.detail) - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.PATH_VALIDATION, + error_message=str(e.detail), + ).dict(), + 'rest', + ) except Exception as e: logger.error(f'{request_id} | Failed to get proto file: {str(e)}') - return process_response(ResponseModel( - status_code=500, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.API_NOT_FOUND, - error_message=f'Failed to get proto file: {str(e)}' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.API_NOT_FOUND, + error_message=f'Failed to get proto file: {str(e)}', + ).dict(), + 'rest', + ) finally: logger.info(f'{request_id} | Total time: {time.time() * 1000 - start_time}ms') + """ Update proto file @@ -383,46 +469,53 @@ Response: {} """ -@proto_router.put('/{api_name}/{api_version}', + +@proto_router.put( + '/{api_name}/{api_version}', description='Update proto file', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Proto file updated successfully' - } - } - } + 'application/json': {'example': {'message': 'Proto file updated successfully'}} + }, } - } + }, ) - -async def update_proto_file(api_name: str, api_version: str, request: Request, proto_file: UploadFile = File(...)): +async def update_proto_file( + api_name: str, api_version: str, request: Request, proto_file: UploadFile = File(...) +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_APIS): - return process_response(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='API008', - error_message='You do not have permission to update proto files' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API008', + error_message='You do not have permission to update proto files', + ).dict(), + 'rest', + ) original_name = proto_file.filename or '' if not original_name.lower().endswith('.proto'): - return process_response(ResponseModel( - status_code=400, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.REQUEST_FILE_TYPE, - error_message=Messages.ONLY_PROTO_ALLOWED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.REQUEST_FILE_TYPE, + error_message=Messages.ONLY_PROTO_ALLOWED, + ).dict(), + 'rest', + ) proto_path, generated_dir = get_safe_proto_path(api_name, api_version) content = await proto_file.read() @@ -430,55 +523,76 @@ async def update_proto_file(api_name: str, api_version: str, request: Request, p max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024)) proto_content = validate_proto_content(content, max_size=max_proto_size) except ValueError as e: - return process_response(ResponseModel( - status_code=400, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.REQUEST_FILE_TYPE, - error_message=f'Invalid proto file: {str(e)}' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.REQUEST_FILE_TYPE, + error_message=f'Invalid proto file: {str(e)}', + ).dict(), + 'rest', + ) proto_path.write_text(proto_content) try: - subprocess.run([ - sys.executable, '-m', 'grpc_tools.protoc', - f'--proto_path={proto_path.parent}', - f'--python_out={generated_dir}', - f'--grpc_python_out={generated_dir}', - str(proto_path) - ], check=True) + subprocess.run( + [ + sys.executable, + '-m', + 'grpc_tools.protoc', + f'--proto_path={proto_path.parent}', + f'--python_out={generated_dir}', + f'--grpc_python_out={generated_dir}', + str(proto_path), + ], + check=True, + ) except subprocess.CalledProcessError as e: logger.error(f'{request_id} | Failed to generate gRPC code: {str(e)}') - return process_response(ResponseModel( - status_code=500, + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API009', + error_message='Failed to generate gRPC code from proto file', + ).dict(), + 'rest', + ) + return process_response( + ResponseModel( + status_code=200, response_headers={Headers.REQUEST_ID: request_id}, - error_code='API009', - error_message='Failed to generate gRPC code from proto file' - ).dict(), 'rest') - return process_response(ResponseModel( - status_code=200, - response_headers={Headers.REQUEST_ID: request_id}, - message='Proto file updated successfully' - ).dict(), 'rest') + message='Proto file updated successfully', + ).dict(), + 'rest', + ) except HTTPException as e: logger.error(f'{request_id} | Path validation error: {str(e)}') - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={'request_id': request_id}, - error_code='GTW013', - error_message=str(e.detail) - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code='GTW013', + error_message=str(e.detail), + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete proto file @@ -488,38 +602,40 @@ Response: {} """ -@proto_router.delete('/{api_name}/{api_version}', + +@proto_router.delete( + '/{api_name}/{api_version}', description='Delete proto file', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Proto file deleted successfully' - } - } - } + 'application/json': {'example': {'message': 'Proto file deleted successfully'}} + }, } - } + }, ) - async def delete_proto_file(api_name: str, api_version: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_APIS): - return process_response(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='API008', - error_message='You do not have permission to delete proto files' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='API008', + error_message='You do not have permission to delete proto files', + ).dict(), + 'rest', + ) proto_path, generated_dir = get_safe_proto_path(api_name, api_version) safe_api_name = sanitize_filename(api_name) safe_api_version = sanitize_filename(api_version) @@ -529,36 +645,52 @@ async def delete_proto_file(api_name: str, api_version: str, request: Request): raise ValueError('Unsafe proto file path detected') proto_path.unlink() logger.info(f'{request_id} | Deleted proto file: {proto_path}') - generated_files = [f'{key}_pb2.py', f'{key}_pb2.pyc', f'{key}_pb2_grpc.py', f'{key}_pb2_grpc.pyc'] + generated_files = [ + f'{key}_pb2.py', + f'{key}_pb2.pyc', + f'{key}_pb2_grpc.py', + f'{key}_pb2_grpc.pyc', + ] for file in generated_files: file_path = (generated_dir / file).resolve() if not validate_path(generated_dir, file_path): - logger.warning(f'{request_id} | Unsafe file path detected: {file_path}. Skipping deletion.') + logger.warning( + f'{request_id} | Unsafe file path detected: {file_path}. Skipping deletion.' + ) continue if file_path.exists(): file_path.unlink() logger.info(f'{request_id} | Deleted generated file: {file_path}') - return process_response(ResponseModel( - status_code=200, - response_headers={Headers.REQUEST_ID: request_id}, - message='Proto file and generated files deleted successfully' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={Headers.REQUEST_ID: request_id}, + message='Proto file and generated files deleted successfully', + ).dict(), + 'rest', + ) except ValueError as e: logger.error(f'{request_id} | Path validation error: {str(e)}') - return process_response(ResponseModel( - status_code=400, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.PATH_VALIDATION, - error_message=str(e) - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.PATH_VALIDATION, + error_message=str(e), + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={Headers.REQUEST_ID: request_id}, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/quota_routes.py b/backend-services/routes/quota_routes.py index f552cad..3cad36c 100644 --- a/backend-services/routes/quota_routes.py +++ b/backend-services/routes/quota_routes.py @@ -6,15 +6,15 @@ User-facing endpoints for checking current usage and limits. """ import logging -from typing import Optional from datetime import datetime -from fastapi import APIRouter, HTTPException, Depends, status + +from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel -from models.rate_limit_models import QuotaType, TierLimits +from models.rate_limit_models import QuotaType from services.tier_service import TierService, get_tier_service -from utils.quota_tracker import QuotaTracker, get_quota_tracker from utils.database_async import async_database +from utils.quota_tracker import QuotaTracker, get_quota_tracker logger = logging.getLogger(__name__) @@ -25,8 +25,10 @@ quota_router = APIRouter() # RESPONSE MODELS # ============================================================================ + class QuotaStatusResponse(BaseModel): """Response model for quota status""" + quota_type: str current_usage: int limit: int @@ -43,16 +45,18 @@ class QuotaStatusResponse(BaseModel): class TierInfoResponse(BaseModel): """Response model for tier information""" + tier_id: str tier_name: str display_name: str limits: dict - price_monthly: Optional[float] + price_monthly: float | None features: list class QuotaDashboardResponse(BaseModel): """Complete quota dashboard response""" + user_id: str tier_info: TierInfoResponse quotas: list @@ -63,6 +67,7 @@ class QuotaDashboardResponse(BaseModel): # DEPENDENCY INJECTION # ============================================================================ + async def get_quota_tracker_dep() -> QuotaTracker: """Dependency to get quota tracker""" return get_quota_tracker() @@ -77,27 +82,28 @@ async def get_tier_service_dep() -> TierService: def get_current_user_id() -> str: """ Get current user ID from request context - + In production, this should extract from JWT token or session. For now, returns a placeholder. """ # TODO: Extract from auth middleware - return "current_user" + return 'current_user' # ============================================================================ # QUOTA STATUS ENDPOINTS # ============================================================================ -@quota_router.get("/status", response_model=QuotaDashboardResponse) + +@quota_router.get('/status', response_model=QuotaDashboardResponse) async def get_quota_status( user_id: str = Depends(get_current_user_id), quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ Get complete quota status for current user - + Returns: - Current tier information - All quota usages @@ -106,82 +112,78 @@ async def get_quota_status( try: # Get user's tier tier = await tier_service.get_user_tier(user_id) - + if not tier: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No tier assigned to user" + status_code=status.HTTP_404_NOT_FOUND, detail='No tier assigned to user' ) - + # Get user's effective limits (including overrides) limits = await tier_service.get_user_limits(user_id) - + if not limits: limits = tier.limits - + # Check all quotas quotas = [] - + # Monthly request quota if limits.monthly_request_quota: result = quota_tracker.check_quota( - user_id, - QuotaType.REQUESTS, - limits.monthly_request_quota, - 'month' + user_id, QuotaType.REQUESTS, limits.monthly_request_quota, 'month' ) - quotas.append(QuotaStatusResponse( - quota_type='monthly_requests', - current_usage=result.current_usage, - limit=result.limit, - remaining=result.remaining, - percentage_used=result.percentage_used, - reset_at=result.reset_at.isoformat(), - is_warning=result.is_warning, - is_critical=result.is_critical, - is_exhausted=result.is_exhausted - )) - + quotas.append( + QuotaStatusResponse( + quota_type='monthly_requests', + current_usage=result.current_usage, + limit=result.limit, + remaining=result.remaining, + percentage_used=result.percentage_used, + reset_at=result.reset_at.isoformat(), + is_warning=result.is_warning, + is_critical=result.is_critical, + is_exhausted=result.is_exhausted, + ) + ) + # Daily request quota if limits.daily_request_quota: result = quota_tracker.check_quota( - user_id, - QuotaType.REQUESTS, - limits.daily_request_quota, - 'day' + user_id, QuotaType.REQUESTS, limits.daily_request_quota, 'day' ) - quotas.append(QuotaStatusResponse( - quota_type='daily_requests', - current_usage=result.current_usage, - limit=result.limit, - remaining=result.remaining, - percentage_used=result.percentage_used, - reset_at=result.reset_at.isoformat(), - is_warning=result.is_warning, - is_critical=result.is_critical, - is_exhausted=result.is_exhausted - )) - + quotas.append( + QuotaStatusResponse( + quota_type='daily_requests', + current_usage=result.current_usage, + limit=result.limit, + remaining=result.remaining, + percentage_used=result.percentage_used, + reset_at=result.reset_at.isoformat(), + is_warning=result.is_warning, + is_critical=result.is_critical, + is_exhausted=result.is_exhausted, + ) + ) + # Monthly bandwidth quota if limits.monthly_bandwidth_quota: result = quota_tracker.check_quota( - user_id, - QuotaType.BANDWIDTH, - limits.monthly_bandwidth_quota, - 'month' + user_id, QuotaType.BANDWIDTH, limits.monthly_bandwidth_quota, 'month' ) - quotas.append(QuotaStatusResponse( - quota_type='monthly_bandwidth', - current_usage=result.current_usage, - limit=result.limit, - remaining=result.remaining, - percentage_used=result.percentage_used, - reset_at=result.reset_at.isoformat(), - is_warning=result.is_warning, - is_critical=result.is_critical, - is_exhausted=result.is_exhausted - )) - + quotas.append( + QuotaStatusResponse( + quota_type='monthly_bandwidth', + current_usage=result.current_usage, + limit=result.limit, + remaining=result.remaining, + percentage_used=result.percentage_used, + reset_at=result.reset_at.isoformat(), + is_warning=result.is_warning, + is_critical=result.is_critical, + is_exhausted=result.is_exhausted, + ) + ) + # Build tier info tier_info = TierInfoResponse( tier_id=tier.tier_id, @@ -189,85 +191,79 @@ async def get_quota_status( display_name=tier.display_name, limits=limits.to_dict(), price_monthly=tier.price_monthly, - features=tier.features + features=tier.features, ) - + # Build usage summary total_usage = sum(q.current_usage for q in quotas if 'requests' in q.quota_type) total_limit = sum(q.limit for q in quotas if 'requests' in q.quota_type) - + usage_summary = { 'total_requests_used': total_usage, 'total_requests_limit': total_limit, 'has_warnings': any(q.is_warning for q in quotas), 'has_critical': any(q.is_critical for q in quotas), - 'has_exhausted': any(q.is_exhausted for q in quotas) + 'has_exhausted': any(q.is_exhausted for q in quotas), } - + return QuotaDashboardResponse( - user_id=user_id, - tier_info=tier_info, - quotas=quotas, - usage_summary=usage_summary + user_id=user_id, tier_info=tier_info, quotas=quotas, usage_summary=usage_summary ) - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting quota status: {e}") + logger.error(f'Error getting quota status: {e}') raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get quota status" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get quota status' ) -@quota_router.get("/status/{quota_type}") +@quota_router.get('/status/{quota_type}') async def get_specific_quota_status( quota_type: str, user_id: str = Depends(get_current_user_id), quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ Get status for a specific quota type - + Args: quota_type: Type of quota (monthly_requests, daily_requests, monthly_bandwidth) """ try: # Get user's limits limits = await tier_service.get_user_limits(user_id) - + if not limits: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No limits found for user" + status_code=status.HTTP_404_NOT_FOUND, detail='No limits found for user' ) - + # Map quota type to limit and period quota_mapping = { 'monthly_requests': (QuotaType.REQUESTS, limits.monthly_request_quota, 'month'), 'daily_requests': (QuotaType.REQUESTS, limits.daily_request_quota, 'day'), - 'monthly_bandwidth': (QuotaType.BANDWIDTH, limits.monthly_bandwidth_quota, 'month') + 'monthly_bandwidth': (QuotaType.BANDWIDTH, limits.monthly_bandwidth_quota, 'month'), } - + if quota_type not in quota_mapping: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid quota type: {quota_type}" + status_code=status.HTTP_400_BAD_REQUEST, detail=f'Invalid quota type: {quota_type}' ) - + q_type, limit, period = quota_mapping[quota_type] - + if not limit: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Quota {quota_type} not configured for user" + detail=f'Quota {quota_type} not configured for user', ) - + # Check quota result = quota_tracker.check_quota(user_id, q_type, limit, period) - + return QuotaStatusResponse( quota_type=quota_type, current_usage=result.current_usage, @@ -277,180 +273,159 @@ async def get_specific_quota_status( reset_at=result.reset_at.isoformat(), is_warning=result.is_warning, is_critical=result.is_critical, - is_exhausted=result.is_exhausted + is_exhausted=result.is_exhausted, ) - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting quota status: {e}") + logger.error(f'Error getting quota status: {e}') raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get quota status" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get quota status' ) -@quota_router.get("/usage/history") +@quota_router.get('/usage/history') async def get_usage_history( user_id: str = Depends(get_current_user_id), - quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep) + quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep), ): """ Get historical usage data - + Returns usage history for the past 6 months. """ try: # Get history from quota tracker - history = quota_tracker.get_quota_history( - user_id, - QuotaType.REQUESTS, - months=6 - ) - + history = quota_tracker.get_quota_history(user_id, QuotaType.REQUESTS, months=6) + return { 'user_id': user_id, 'history': history, - 'note': 'Historical tracking not yet fully implemented' + 'note': 'Historical tracking not yet fully implemented', } - + except Exception as e: - logger.error(f"Error getting usage history: {e}") + logger.error(f'Error getting usage history: {e}') raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get usage history" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get usage history' ) -@quota_router.post("/usage/export") +@quota_router.post('/usage/export') async def export_usage_data( format: str = 'json', user_id: str = Depends(get_current_user_id), quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ Export usage data in JSON or CSV format - + Args: format: Export format ('json' or 'csv') """ try: # Get current quota status limits = await tier_service.get_user_limits(user_id) - + if not limits: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No limits found for user" + status_code=status.HTTP_404_NOT_FOUND, detail='No limits found for user' ) - + # Collect all quota data - export_data = { - 'user_id': user_id, - 'export_date': datetime.now().isoformat(), - 'quotas': [] - } - + export_data = {'user_id': user_id, 'export_date': datetime.now().isoformat(), 'quotas': []} + # Add monthly requests if limits.monthly_request_quota: result = quota_tracker.check_quota( - user_id, - QuotaType.REQUESTS, - limits.monthly_request_quota, - 'month' + user_id, QuotaType.REQUESTS, limits.monthly_request_quota, 'month' ) - export_data['quotas'].append({ - 'type': 'monthly_requests', - 'current_usage': result.current_usage, - 'limit': result.limit, - 'remaining': result.remaining, - 'percentage_used': result.percentage_used, - 'reset_at': result.reset_at.isoformat() - }) - + export_data['quotas'].append( + { + 'type': 'monthly_requests', + 'current_usage': result.current_usage, + 'limit': result.limit, + 'remaining': result.remaining, + 'percentage_used': result.percentage_used, + 'reset_at': result.reset_at.isoformat(), + } + ) + # Add daily requests if limits.daily_request_quota: result = quota_tracker.check_quota( - user_id, - QuotaType.REQUESTS, - limits.daily_request_quota, - 'day' + user_id, QuotaType.REQUESTS, limits.daily_request_quota, 'day' ) - export_data['quotas'].append({ - 'type': 'daily_requests', - 'current_usage': result.current_usage, - 'limit': result.limit, - 'remaining': result.remaining, - 'percentage_used': result.percentage_used, - 'reset_at': result.reset_at.isoformat() - }) - + export_data['quotas'].append( + { + 'type': 'daily_requests', + 'current_usage': result.current_usage, + 'limit': result.limit, + 'remaining': result.remaining, + 'percentage_used': result.percentage_used, + 'reset_at': result.reset_at.isoformat(), + } + ) + if format == 'csv': # Convert to CSV format csv_lines = ['Type,Current Usage,Limit,Remaining,Percentage Used,Reset At'] for quota in export_data['quotas']: csv_lines.append( - f"{quota['type']},{quota['current_usage']},{quota['limit']}," - f"{quota['remaining']},{quota['percentage_used']:.2f},{quota['reset_at']}" + f'{quota["type"]},{quota["current_usage"]},{quota["limit"]},' + f'{quota["remaining"]},{quota["percentage_used"]:.2f},{quota["reset_at"]}' ) - - return { - 'format': 'csv', - 'data': '\n'.join(csv_lines) - } + + return {'format': 'csv', 'data': '\n'.join(csv_lines)} else: # Return JSON - return { - 'format': 'json', - 'data': export_data - } - + return {'format': 'json', 'data': export_data} + except HTTPException: raise except Exception as e: - logger.error(f"Error exporting usage data: {e}") + logger.error(f'Error exporting usage data: {e}') raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to export usage data" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to export usage data' ) -@quota_router.get("/tier/info") +@quota_router.get('/tier/info') async def get_tier_info( user_id: str = Depends(get_current_user_id), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ Get current tier information for user - + Returns tier details, benefits, and upgrade options. """ try: # Get user's tier tier = await tier_service.get_user_tier(user_id) - + if not tier: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No tier assigned to user" + status_code=status.HTTP_404_NOT_FOUND, detail='No tier assigned to user' ) - + # Get all available tiers for upgrade options all_tiers = await tier_service.list_tiers(enabled_only=True) - + # Filter tiers higher than current (for upgrades) upgrade_options = [ { 'tier_id': t.tier_id, 'display_name': t.display_name, 'price_monthly': t.price_monthly, - 'features': t.features + 'features': t.features, } for t in all_tiers if t.tier_id != tier.tier_id ] - + return { 'current_tier': TierInfoResponse( tier_id=tier.tier_id, @@ -458,45 +433,43 @@ async def get_tier_info( display_name=tier.display_name, limits=tier.limits.to_dict(), price_monthly=tier.price_monthly, - features=tier.features + features=tier.features, ), - 'upgrade_options': upgrade_options + 'upgrade_options': upgrade_options, } - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting tier info: {e}") + logger.error(f'Error getting tier info: {e}') raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get tier info" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get tier info' ) -@quota_router.get("/burst/status") +@quota_router.get('/burst/status') async def get_burst_status( user_id: str = Depends(get_current_user_id), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ Get burst usage status for current user - + Returns burst token consumption across different time windows. """ try: # Get user's tier and limits tier = await tier_service.get_user_tier(user_id) - + if not tier: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No tier assigned to user" + status_code=status.HTTP_404_NOT_FOUND, detail='No tier assigned to user' ) - + limits = await tier_service.get_user_limits(user_id) if not limits: limits = tier.limits - + # TODO: Get actual burst usage from Redis # For now, return structure with placeholder data burst_status = { @@ -504,23 +477,22 @@ async def get_burst_status( 'burst_limits': { 'per_minute': limits.burst_per_minute, 'per_hour': limits.burst_per_hour, - 'per_second': limits.burst_per_second + 'per_second': limits.burst_per_second, }, 'burst_usage': { 'per_minute': 0, # TODO: Get from Redis 'per_hour': 0, - 'per_second': 0 + 'per_second': 0, }, - 'note': 'Burst tracking requires rate limiter integration' + 'note': 'Burst tracking requires rate limiter integration', } - + return burst_status - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting burst status: {e}") + logger.error(f'Error getting burst status: {e}') raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get burst status" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get burst status' ) diff --git a/backend-services/routes/rate_limit_rule_routes.py b/backend-services/routes/rate_limit_rule_routes.py index 914fad6..b61f7a7 100644 --- a/backend-services/routes/rate_limit_rule_routes.py +++ b/backend-services/routes/rate_limit_rule_routes.py @@ -5,8 +5,8 @@ FastAPI routes for managing rate limit rules. """ import logging -from typing import List, Optional -from fastapi import APIRouter, HTTPException, Depends, status, Query + +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow @@ -22,58 +22,69 @@ rate_limit_rule_router = APIRouter() # REQUEST/RESPONSE MODELS # ============================================================================ + class RuleCreateRequest(BaseModel): """Request model for creating a rule""" - rule_id: str = Field(..., description="Unique rule identifier") - rule_type: str = Field(..., description="Rule type (per_user, per_api, per_endpoint, per_ip, global)") - time_window: str = Field(..., description="Time window (second, minute, hour, day, month)") - limit: int = Field(..., gt=0, description="Maximum requests allowed") - target_identifier: Optional[str] = Field(None, description="Target (user ID, API name, endpoint, IP)") - burst_allowance: int = Field(0, ge=0, description="Additional burst requests") - priority: int = Field(0, description="Rule priority (higher = checked first)") - enabled: bool = Field(True, description="Whether rule is enabled") - description: Optional[str] = Field(None, description="Rule description") + + rule_id: str = Field(..., description='Unique rule identifier') + rule_type: str = Field( + ..., description='Rule type (per_user, per_api, per_endpoint, per_ip, global)' + ) + time_window: str = Field(..., description='Time window (second, minute, hour, day, month)') + limit: int = Field(..., gt=0, description='Maximum requests allowed') + target_identifier: str | None = Field( + None, description='Target (user ID, API name, endpoint, IP)' + ) + burst_allowance: int = Field(0, ge=0, description='Additional burst requests') + priority: int = Field(0, description='Rule priority (higher = checked first)') + enabled: bool = Field(True, description='Whether rule is enabled') + description: str | None = Field(None, description='Rule description') class RuleUpdateRequest(BaseModel): """Request model for updating a rule""" - limit: Optional[int] = Field(None, gt=0) - target_identifier: Optional[str] = None - burst_allowance: Optional[int] = Field(None, ge=0) - priority: Optional[int] = None - enabled: Optional[bool] = None - description: Optional[str] = None + + limit: int | None = Field(None, gt=0) + target_identifier: str | None = None + burst_allowance: int | None = Field(None, ge=0) + priority: int | None = None + enabled: bool | None = None + description: str | None = None class BulkRuleRequest(BaseModel): """Request model for bulk operations""" - rule_ids: List[str] + + rule_ids: list[str] class RuleDuplicateRequest(BaseModel): """Request model for duplicating a rule""" + new_rule_id: str class RuleResponse(BaseModel): """Response model for rule""" + rule_id: str rule_type: str time_window: str limit: int - target_identifier: Optional[str] + target_identifier: str | None burst_allowance: int priority: int enabled: bool - description: Optional[str] - created_at: Optional[str] - updated_at: Optional[str] + description: str | None + created_at: str | None + updated_at: str | None # ============================================================================ # DEPENDENCY INJECTION # ============================================================================ + async def get_rule_service_dep() -> RateLimitRuleService: """Dependency to get rule service""" return get_rate_limit_rule_service(async_database.db) @@ -83,10 +94,10 @@ async def get_rule_service_dep() -> RateLimitRuleService: # RULE CRUD ENDPOINTS # ============================================================================ -@rate_limit_rule_router.post("/", response_model=RuleResponse, status_code=status.HTTP_201_CREATED) + +@rate_limit_rule_router.post('/', response_model=RuleResponse, status_code=status.HTTP_201_CREATED) async def create_rule( - request: RuleCreateRequest, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + request: RuleCreateRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Create a new rate limit rule""" try: @@ -99,281 +110,309 @@ async def create_rule( burst_allowance=request.burst_allowance, priority=request.priority, enabled=request.enabled, - description=request.description + description=request.description, ) - + # Validate rule errors = rule_service.validate_rule(rule) if errors: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={"errors": errors}) - + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={'errors': errors}) + created_rule = await rule_service.create_rule(rule) - + return RuleResponse(**created_rule.to_dict()) - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error creating rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create rule") + logger.error(f'Error creating rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to create rule' + ) -@rate_limit_rule_router.get("/", response_model=List[RuleResponse]) +@rate_limit_rule_router.get('/', response_model=list[RuleResponse]) async def list_rules( - rule_type: Optional[str] = Query(None, description="Filter by rule type"), - enabled_only: bool = Query(False, description="Only return enabled rules"), + rule_type: str | None = Query(None, description='Filter by rule type'), + enabled_only: bool = Query(False, description='Only return enabled rules'), skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_service: RateLimitRuleService = Depends(get_rule_service_dep), ): """List all rate limit rules""" try: rule_type_enum = RuleType(rule_type) if rule_type else None rules = await rule_service.list_rules( - rule_type=rule_type_enum, - enabled_only=enabled_only, - skip=skip, - limit=limit + rule_type=rule_type_enum, enabled_only=enabled_only, skip=skip, limit=limit ) - + return [RuleResponse(**rule.to_dict()) for rule in rules] - + except Exception as e: - logger.error(f"Error listing rules: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to list rules") + logger.error(f'Error listing rules: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to list rules' + ) -@rate_limit_rule_router.get("/search") +@rate_limit_rule_router.get('/search') async def search_rules( - q: str = Query(..., description="Search term"), - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + q: str = Query(..., description='Search term'), + rule_service: RateLimitRuleService = Depends(get_rule_service_dep), ): """Search rules by ID, description, or target""" try: rules = await rule_service.search_rules(q) return [RuleResponse(**rule.to_dict()) for rule in rules] - + except Exception as e: - logger.error(f"Error searching rules: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to search rules") + logger.error(f'Error searching rules: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to search rules' + ) -@rate_limit_rule_router.get("/{rule_id}", response_model=RuleResponse) +@rate_limit_rule_router.get('/{rule_id}', response_model=RuleResponse) async def get_rule( - rule_id: str, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Get a specific rule by ID""" try: rule = await rule_service.get_rule(rule_id) - + if not rule: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found' + ) + return RuleResponse(**rule.to_dict()) - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get rule") + logger.error(f'Error getting rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get rule' + ) -@rate_limit_rule_router.put("/{rule_id}", response_model=RuleResponse) +@rate_limit_rule_router.put('/{rule_id}', response_model=RuleResponse) async def update_rule( rule_id: str, request: RuleUpdateRequest, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_service: RateLimitRuleService = Depends(get_rule_service_dep), ): """Update a rate limit rule""" try: updates = {k: v for k, v in request.dict().items() if v is not None} - + updated_rule = await rule_service.update_rule(rule_id, updates) - + if not updated_rule: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found' + ) + return RuleResponse(**updated_rule.to_dict()) - + except HTTPException: raise except Exception as e: - logger.error(f"Error updating rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update rule") + logger.error(f'Error updating rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to update rule' + ) -@rate_limit_rule_router.delete("/{rule_id}", status_code=status.HTTP_204_NO_CONTENT) +@rate_limit_rule_router.delete('/{rule_id}', status_code=status.HTTP_204_NO_CONTENT) async def delete_rule( - rule_id: str, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Delete a rate limit rule""" try: deleted = await rule_service.delete_rule(rule_id) - + if not deleted: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found' + ) + except HTTPException: raise except Exception as e: - logger.error(f"Error deleting rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete rule") + logger.error(f'Error deleting rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to delete rule' + ) -@rate_limit_rule_router.post("/{rule_id}/enable", response_model=RuleResponse) +@rate_limit_rule_router.post('/{rule_id}/enable', response_model=RuleResponse) async def enable_rule( - rule_id: str, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Enable a rule""" try: rule = await rule_service.enable_rule(rule_id) - + if not rule: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found' + ) + return RuleResponse(**rule.to_dict()) - + except HTTPException: raise except Exception as e: - logger.error(f"Error enabling rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to enable rule") + logger.error(f'Error enabling rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to enable rule' + ) -@rate_limit_rule_router.post("/{rule_id}/disable", response_model=RuleResponse) +@rate_limit_rule_router.post('/{rule_id}/disable', response_model=RuleResponse) async def disable_rule( - rule_id: str, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Disable a rule""" try: rule = await rule_service.disable_rule(rule_id) - + if not rule: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found' + ) + return RuleResponse(**rule.to_dict()) - + except HTTPException: raise except Exception as e: - logger.error(f"Error disabling rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to disable rule") + logger.error(f'Error disabling rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to disable rule' + ) # ============================================================================ # BULK OPERATIONS # ============================================================================ -@rate_limit_rule_router.post("/bulk/delete") + +@rate_limit_rule_router.post('/bulk/delete') async def bulk_delete_rules( - request: BulkRuleRequest, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + request: BulkRuleRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Delete multiple rules at once""" try: count = await rule_service.bulk_delete_rules(request.rule_ids) - return {"deleted_count": count} - + return {'deleted_count': count} + except Exception as e: - logger.error(f"Error bulk deleting rules: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete rules") + logger.error(f'Error bulk deleting rules: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to delete rules' + ) -@rate_limit_rule_router.post("/bulk/enable") +@rate_limit_rule_router.post('/bulk/enable') async def bulk_enable_rules( - request: BulkRuleRequest, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + request: BulkRuleRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Enable multiple rules at once""" try: count = await rule_service.bulk_enable_rules(request.rule_ids) - return {"enabled_count": count} - + return {'enabled_count': count} + except Exception as e: - logger.error(f"Error bulk enabling rules: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to enable rules") + logger.error(f'Error bulk enabling rules: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to enable rules' + ) -@rate_limit_rule_router.post("/bulk/disable") +@rate_limit_rule_router.post('/bulk/disable') async def bulk_disable_rules( - request: BulkRuleRequest, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + request: BulkRuleRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep) ): """Disable multiple rules at once""" try: count = await rule_service.bulk_disable_rules(request.rule_ids) - return {"disabled_count": count} - + return {'disabled_count': count} + except Exception as e: - logger.error(f"Error bulk disabling rules: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to disable rules") + logger.error(f'Error bulk disabling rules: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to disable rules' + ) # ============================================================================ # RULE DUPLICATION # ============================================================================ -@rate_limit_rule_router.post("/{rule_id}/duplicate", response_model=RuleResponse, status_code=status.HTTP_201_CREATED) + +@rate_limit_rule_router.post( + '/{rule_id}/duplicate', response_model=RuleResponse, status_code=status.HTTP_201_CREATED +) async def duplicate_rule( rule_id: str, request: RuleDuplicateRequest, - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) + rule_service: RateLimitRuleService = Depends(get_rule_service_dep), ): """Duplicate an existing rule""" try: new_rule = await rule_service.duplicate_rule(rule_id, request.new_rule_id) return RuleResponse(**new_rule.to_dict()) - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error duplicating rule: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to duplicate rule") + logger.error(f'Error duplicating rule: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to duplicate rule' + ) # ============================================================================ # STATISTICS # ============================================================================ -@rate_limit_rule_router.get("/statistics/summary") -async def get_rule_statistics( - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) -): + +@rate_limit_rule_router.get('/statistics/summary') +async def get_rule_statistics(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)): """Get statistics about rate limit rules""" try: stats = await rule_service.get_rule_statistics() return stats - + except Exception as e: - logger.error(f"Error getting rule statistics: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get statistics") + logger.error(f'Error getting rule statistics: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get statistics' + ) # ============================================================================ # RATE LIMIT STATUS ENDPOINT (User-facing) # ============================================================================ -@rate_limit_rule_router.get("/status") -async def get_rate_limit_status( - rule_service: RateLimitRuleService = Depends(get_rule_service_dep) -): + +@rate_limit_rule_router.get('/status') +async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)): """ Get current rate limit status for the authenticated user - + Returns applicable rate limit rules and current usage. This is a user-facing endpoint showing their current limits. """ try: # TODO: Get user_id from auth middleware - user_id = "current_user" - + user_id = 'current_user' + # Get applicable rules for user rules = await rule_service.get_applicable_rules(user_id=user_id) - + # Format response status_info = { 'user_id': user_id, @@ -384,19 +423,18 @@ async def get_rate_limit_status( 'time_window': rule.time_window.value, 'limit': rule.limit, 'burst_allowance': rule.burst_allowance, - 'description': rule.description + 'description': rule.description, } for rule in rules ], - 'note': 'Use /platform/quota/status for detailed usage information' + 'note': 'Use /platform/quota/status for detailed usage information', } - + return status_info - + except Exception as e: - logger.error(f"Error getting rate limit status: {e}") + logger.error(f'Error getting rate limit status: {e}') raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get rate limit status" + detail='Failed to get rate limit status', ) - diff --git a/backend-services/routes/role_routes.py b/backend-services/routes/role_routes.py index d633101..51ca0b5 100644 --- a/backend-services/routes/role_routes.py +++ b/backend-services/routes/role_routes.py @@ -4,21 +4,21 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -from fastapi import APIRouter, Depends, Request -import uuid -import time import logging +import time +import uuid +from fastapi import APIRouter, Request + +from models.create_role_model import CreateRoleModel from models.response_model import ResponseModel from models.role_model_response import RoleModelResponse from models.update_role_model import UpdateRoleModel from services.role_service import RoleService from utils.auth_util import auth_required -from models.create_role_model import CreateRoleModel +from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles from utils.response_util import respond_rest -from utils.constants import Headers, Roles, ErrorCodes, Messages, Defaults -from utils.role_util import platform_role_required_bool, is_admin_role, is_admin_user +from utils.role_util import is_admin_role, is_admin_user, platform_role_required_bool role_router = APIRouter() @@ -33,68 +33,70 @@ Response: {} """ -@role_router.post('', + +@role_router.post( + '', description='Add role', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Role created successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Role created successfully'}}}, } - } + }, ) - async def create_role(api_data: CreateRoleModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ROLES): logger.error(f'{request_id} | User does not have permission to create roles') - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='ROLE009', - error_message='You do not have permission to create roles' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='ROLE009', + error_message='You do not have permission to create roles', + ) + ) try: - incoming_is_admin = bool(getattr(api_data, 'platform_admin', False)) or api_data.role_name.strip().lower() in ('admin', 'platform admin') + incoming_is_admin = bool( + getattr(api_data, 'platform_admin', False) + ) or api_data.role_name.strip().lower() in ('admin', 'platform admin') except Exception: incoming_is_admin = False if incoming_is_admin: if not await is_admin_user(username): - return respond_rest(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='ROLE013', - error_message='Only admin may create the admin role' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='ROLE013', + error_message='Only admin may create the admin role', + ) + ) return respond_rest(await RoleService.create_role(api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update role @@ -104,76 +106,80 @@ Response: {} """ -@role_router.put('/{role_name}', + +@role_router.put( + '/{role_name}', description='Update role', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Role updated successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Role updated successfully'}}}, } - } + }, ) - async def update_role(role_name: str, api_data: UpdateRoleModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: # DEBUG: Log the incoming data logger.info(f'{request_id} | DEBUG: Received model data: {api_data.dict()}') - + payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ROLES): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='ROLE010', - error_message='You do not have permission to update roles' - )) - target_is_admin = await is_admin_role(role_name) - if target_is_admin and not await is_admin_user(username): - return respond_rest(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='ROLE014', - error_message='Only admin may modify the admin role' - )) - try: - if getattr(api_data, 'platform_admin', None) is not None and not await is_admin_user(username): - return respond_rest(ResponseModel( + return respond_rest( + ResponseModel( status_code=403, response_headers={Headers.REQUEST_ID: request_id}, - error_code='ROLE015', - error_message='Only admin may change admin designation' - )) + error_code='ROLE010', + error_message='You do not have permission to update roles', + ) + ) + target_is_admin = await is_admin_role(role_name) + if target_is_admin and not await is_admin_user(username): + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='ROLE014', + error_message='Only admin may modify the admin role', + ) + ) + try: + if getattr(api_data, 'platform_admin', None) is not None and not await is_admin_user( + username + ): + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='ROLE015', + error_message='Only admin may change admin designation', + ) + ) except Exception: pass return respond_rest(await RoleService.update_role(role_name, api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete role @@ -183,63 +189,63 @@ Response: {} """ -@role_router.delete('/{role_name}', + +@role_router.delete( + '/{role_name}', description='Delete role', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'Role deleted successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'Role deleted successfully'}}}, } - } + }, ) - async def delete_role(role_name: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_ROLES): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='ROLE011', - error_message='You do not have permission to delete roles' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='ROLE011', + error_message='You do not have permission to delete roles', + ) + ) target_is_admin = await is_admin_role(role_name) if target_is_admin and not await is_admin_user(username): - return respond_rest(ResponseModel( - status_code=403, - response_headers={Headers.REQUEST_ID: request_id}, - error_code='ROLE016', - error_message='Only admin may delete the admin role' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='ROLE016', + error_message='Only admin may delete the admin role', + ) + ) return respond_rest(await RoleService.delete_role(role_name, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -249,18 +255,19 @@ Response: {} """ -@role_router.get('/all', - description='Get all roles', - response_model=List[RoleModelResponse] -) -async def get_roles(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): +@role_router.get('/all', description='Get all roles', response_model=list[RoleModelResponse]) +async def get_roles( + request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') data = await RoleService.get_roles(page, page_size, request_id) try: @@ -283,18 +290,19 @@ async def get_roles(request: Request, page: int = Defaults.PAGE, page_size: int return respond_rest(data) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -304,25 +312,26 @@ Response: {} """ -@role_router.get('/{role_name}', - description='Get role', - response_model=RoleModelResponse -) +@role_router.get('/{role_name}', description='Get role', response_model=RoleModelResponse) async def get_role(role_name: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if await is_admin_role(role_name) and not await is_admin_user(username): - return respond_rest(ResponseModel( - status_code=404, - response_headers={Headers.REQUEST_ID: request_id}, - error_message='Role not found' - )) + return respond_rest( + ResponseModel( + status_code=404, + response_headers={Headers.REQUEST_ID: request_id}, + error_message='Role not found', + ) + ) data = await RoleService.get_role(role_name, request_id) try: if data.get('status_code') == 200 and not await is_admin_user(username): @@ -337,22 +346,25 @@ async def get_role(role_name: str, request: Request): return respond_rest(data) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') -@role_router.get('', - description='Get all roles (base path)', - response_model=List[RoleModelResponse] + + +@role_router.get( + '', description='Get all roles (base path)', response_model=list[RoleModelResponse] ) -async def get_roles_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): +async def get_roles_base( + request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE +): """Convenience alias for GET /platform/role/all to support clients/tests that expect listing at the base collection path. """ diff --git a/backend-services/routes/routing_routes.py b/backend-services/routes/routing_routes.py index 940b23e..59ea70e 100644 --- a/backend-services/routes/routing_routes.py +++ b/backend-services/routes/routing_routes.py @@ -4,11 +4,11 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -from fastapi import APIRouter, Depends, Request -import uuid -import time import logging +import time +import uuid + +from fastapi import APIRouter, Request from models.create_routing_model import CreateRoutingModel from models.response_model import ResponseModel @@ -16,7 +16,7 @@ from models.routing_model_response import RoutingModelResponse from models.update_routing_model import UpdateRoutingModel from services.routing_service import RoutingService from utils.auth_util import auth_required -from utils.response_util import respond_rest, process_response +from utils.response_util import process_response, respond_rest from utils.role_util import platform_role_required_bool routing_router = APIRouter() @@ -32,55 +32,56 @@ Response: {} """ -@routing_router.post('', + +@routing_router.post( + '', description='Add routing', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Routing created successfully' - } - } - } + 'application/json': {'example': {'message': 'Routing created successfully'}} + }, } - } + }, ) - async def create_routing(api_data: CreateRoutingModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_routings'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='RTG009', - error_message='You do not have permission to create routings' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='RTG009', + error_message='You do not have permission to create routings', + ) + ) return respond_rest(await RoutingService.create_routing(api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update routing @@ -90,55 +91,56 @@ Response: {} """ -@routing_router.put('/{client_key}', + +@routing_router.put( + '/{client_key}', description='Update routing', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Routing updated successfully' - } - } - } + 'application/json': {'example': {'message': 'Routing updated successfully'}} + }, } - } + }, ) - async def update_routing(client_key: str, api_data: UpdateRoutingModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_routings'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='RTG010', - error_message='You do not have permission to update routings' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='RTG010', + error_message='You do not have permission to update routings', + ) + ) return respond_rest(await RoutingService.update_routing(client_key, api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete routing @@ -148,55 +150,56 @@ Response: {} """ -@routing_router.delete('/{client_key}', + +@routing_router.delete( + '/{client_key}', description='Delete routing', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Routing deleted successfully' - } - } - } + 'application/json': {'example': {'message': 'Routing deleted successfully'}} + }, } - } + }, ) - async def delete_routing(client_key: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_routings'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='RTG011', - error_message='You do not have permission to delete routings' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='RTG011', + error_message='You do not have permission to delete routings', + ) + ) return respond_rest(await RoutingService.delete_routing(client_key, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -206,43 +209,46 @@ Response: {} """ -@routing_router.get('/all', - description='Get all routings', - response_model=List[RoutingModelResponse] -) +@routing_router.get( + '/all', description='Get all routings', response_model=list[RoutingModelResponse] +) async def get_routings(request: Request, page: int = 1, page_size: int = 10): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_routings'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='RTG012', - error_message='You do not have permission to get routings' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='RTG012', + error_message='You do not have permission to get routings', + ) + ) return respond_rest(await RoutingService.get_routings(page, page_size, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -252,39 +258,39 @@ Response: {} """ -@routing_router.get('/{client_key}', - description='Get routing', - response_model=RoutingModelResponse -) +@routing_router.get('/{client_key}', description='Get routing', response_model=RoutingModelResponse) async def get_routing(client_key: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_routings'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='RTG013', - error_message='You do not have permission to get routings' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='RTG013', + error_message='You do not have permission to get routings', + ) + ) return respond_rest(await RoutingService.get_routing(client_key, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/security_routes.py b/backend-services/routes/security_routes.py index c8e2c26..7026fee 100644 --- a/backend-services/routes/security_routes.py +++ b/backend-services/routes/security_routes.py @@ -2,22 +2,22 @@ Routes for managing security settings. """ -from fastapi import APIRouter, Request -from typing import Optional -import os -import sys -import subprocess -import uuid -import time import logging +import os +import subprocess +import sys +import time +import uuid + +from fastapi import APIRouter, Request from models.response_model import ResponseModel from models.security_settings_model import SecuritySettingsModel -from utils.response_util import process_response +from utils.audit_util import audit from utils.auth_util import auth_required +from utils.response_util import process_response from utils.role_util import platform_role_required_bool from utils.security_settings_util import load_settings, save_settings -from utils.audit_util import audit security_router = APIRouter() logger = logging.getLogger('doorman.gateway') @@ -31,26 +31,34 @@ Response: {} """ -@security_router.get('/security/settings', - description='Get security settings', - response_model=ResponseModel, -) +@security_router.get( + '/security/settings', description='Get security settings', response_model=ResponseModel +) async def get_security_settings(request: Request): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_security'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='SEC001', - error_message='You do not have permission to view security settings' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SEC001', + error_message='You do not have permission to view security settings', + ).dict(), + 'rest', + ) settings = await load_settings() try: client_ip = request.client.host if request.client else None @@ -63,44 +71,57 @@ async def get_security_settings(request: Request): settings_with_mode = dict(settings) try: from utils.database import database + settings_with_mode['memory_only'] = bool(database.memory_only) except Exception: settings_with_mode['memory_only'] = False try: import os + env_val = os.getenv('LOCAL_HOST_IP_BYPASS') locked = isinstance(env_val, str) and env_val.strip() != '' if locked: - settings_with_mode['allow_localhost_bypass'] = (env_val.lower() == 'true') + settings_with_mode['allow_localhost_bypass'] = env_val.lower() == 'true' settings_with_mode['allow_localhost_bypass_locked'] = locked except Exception: settings_with_mode['allow_localhost_bypass_locked'] = False try: warnings = [] - if settings_with_mode.get('trust_x_forwarded_for') and not (settings_with_mode.get('xff_trusted_proxies') or []): - warnings.append('Trust X-Forwarded-For is enabled, but no trusted proxies are configured. Set xff_trusted_proxies to avoid header spoofing.') + if settings_with_mode.get('trust_x_forwarded_for') and not ( + settings_with_mode.get('xff_trusted_proxies') or [] + ): + warnings.append( + 'Trust X-Forwarded-For is enabled, but no trusted proxies are configured. Set xff_trusted_proxies to avoid header spoofing.' + ) settings_with_mode['security_warnings'] = warnings except Exception: pass settings_with_mode['client_ip'] = client_ip settings_with_mode['client_ip_xff'] = client_ip_xff - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=settings_with_mode - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response=settings_with_mode, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -110,57 +131,85 @@ Response: {} """ -@security_router.put('/security/settings', - description='Update security settings', - response_model=ResponseModel, + +@security_router.put( + '/security/settings', description='Update security settings', response_model=ResponseModel ) -async def update_security_settings(request: Request, body: Optional[SecuritySettingsModel] = None): - request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4()) +async def update_security_settings(request: Request, body: SecuritySettingsModel | None = None): + request_id = ( + getattr(request.state, 'request_id', None) + or request.headers.get('X-Request-ID') + or str(uuid.uuid4()) + ) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_security'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='SEC002', - error_message='You do not have permission to update security settings' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SEC002', + error_message='You do not have permission to update security settings', + ).dict(), + 'rest', + ) payload_dict = {} try: - payload_dict = (body.dict(exclude_none=True) if body is not None else {}) + payload_dict = body.dict(exclude_none=True) if body is not None else {} except Exception: payload_dict = {} new_settings = await save_settings(payload_dict) - audit(request, actor=username, action='security.update_settings', target='security_settings', status='success', details={k: new_settings.get(k) for k in ('enable_auto_save','auto_save_frequency_seconds','dump_path')}, request_id=request_id) + audit( + request, + actor=username, + action='security.update_settings', + target='security_settings', + status='success', + details={ + k: new_settings.get(k) + for k in ('enable_auto_save', 'auto_save_frequency_seconds', 'dump_path') + }, + request_id=request_id, + ) settings_with_mode = dict(new_settings) try: from utils.database import database + settings_with_mode['memory_only'] = bool(database.memory_only) except Exception: settings_with_mode['memory_only'] = False - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - message='Security settings updated', - response=settings_with_mode - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + message='Security settings updated', + response=settings_with_mode, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -170,69 +219,103 @@ Response: {} """ -@security_router.post('/security/restart', + +@security_router.post( + '/security/restart', description='Schedule a safe gateway restart (PID-based)', response_model=ResponseModel, ) - async def restart_gateway(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_security'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='SEC003', - error_message='You do not have permission to restart the gateway' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SEC003', + error_message='You do not have permission to restart the gateway', + ).dict(), + 'rest', + ) pid_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'doorman.pid') pid_file = os.path.abspath(pid_file) if not os.path.exists(pid_file): - return process_response(ResponseModel( - status_code=409, - response_headers={'request_id': request_id}, - error_code='SEC004', - error_message="Restart not supported: no PID file found (run using 'doorman start' or contact your admin)" - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=409, + response_headers={'request_id': request_id}, + error_code='SEC004', + error_message="Restart not supported: no PID file found (run using 'doorman start' or contact your admin)", + ).dict(), + 'rest', + ) - doorman_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'doorman.py')) + doorman_path = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'doorman.py') + ) try: if os.name == 'nt': - subprocess.Popen([sys.executable, doorman_path, 'restart'], - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS, - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.Popen( + [sys.executable, doorman_path, 'restart'], + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) else: - subprocess.Popen([sys.executable, doorman_path, 'restart'], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, - preexec_fn=os.setsid) + subprocess.Popen( + [sys.executable, doorman_path, 'restart'], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid, + ) except Exception as e: logger.error(f'{request_id} | Failed to spawn restarter: {e}') - return process_response(ResponseModel( - status_code=500, + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='SEC005', + error_message='Failed to schedule restart', + ).dict(), + 'rest', + ) + audit( + request, + actor=username, + action='security.restart', + target='gateway', + status='scheduled', + details=None, + request_id=request_id, + ) + return process_response( + ResponseModel( + status_code=202, response_headers={'request_id': request_id}, - error_code='SEC005', - error_message='Failed to schedule restart' - ).dict(), 'rest') - audit(request, actor=username, action='security.restart', target='gateway', status='scheduled', details=None, request_id=request_id) - return process_response(ResponseModel( - status_code=202, - response_headers={'request_id': request_id}, - message='Restart scheduled' - ).dict(), 'rest') + message='Restart scheduled', + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/subscription_routes.py b/backend-services/routes/subscription_routes.py index 1f8398e..2d65699 100644 --- a/backend-services/routes/subscription_routes.py +++ b/backend-services/routes/subscription_routes.py @@ -4,20 +4,21 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from fastapi import APIRouter, HTTPException, Depends, Request -import uuid -import time import logging +import time +import uuid + +from fastapi import APIRouter, HTTPException, Request from models.response_model import ResponseModel -from services.subscription_service import SubscriptionService -from utils.auth_util import auth_required from models.subscribe_model import SubscribeModel -from utils.group_util import group_required -from utils.role_util import platform_role_required_bool -from utils.database import api_collection -from utils.response_util import respond_rest +from services.subscription_service import SubscriptionService from utils.audit_util import audit +from utils.auth_util import auth_required +from utils.database import api_collection +from utils.group_util import group_required +from utils.response_util import respond_rest +from utils.role_util import platform_role_required_bool subscription_router = APIRouter() @@ -32,72 +33,84 @@ Response: {} """ -@subscription_router.post('/subscribe', + +@subscription_router.post( + '/subscribe', description='Subscribe to API', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Subscription created successfully' - } - } - } + 'application/json': {'example': {'message': 'Subscription created successfully'}} + }, } - } + }, ) - async def subscribe_api(api_data: SubscribeModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - if not await group_required(request, api_data.api_name + '/' + api_data.api_version, api_data.username): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='SUB007', - error_message='You do not have the correct group access' - )) + if not await group_required( + request, api_data.api_name + '/' + api_data.api_version, api_data.username + ): + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SUB007', + error_message='You do not have the correct group access', + ) + ) target_user = api_data.username or username - logger.info(f'{request_id} | Actor: {username} | Action: subscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}') + logger.info( + f'{request_id} | Actor: {username} | Action: subscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}' + ) result = await SubscriptionService.subscribe(api_data, request_id) actor_user = username target_user = api_data.username or username - audit(request, actor=actor_user, action='subscription.subscribe', target=f'{target_user}:{api_data.api_name}/{api_data.api_version}', status=result.get('status_code'), details=None, request_id=request_id) + audit( + request, + actor=actor_user, + action='subscription.subscribe', + target=f'{target_user}:{api_data.api_name}/{api_data.api_version}', + status=result.get('status_code'), + details=None, + request_id=request_id, + ) return respond_rest(result) except HTTPException as e: - return respond_rest(ResponseModel( - status_code=e.status_code, - response_headers={ - 'request_id': request_id - }, - error_code='GEN001', - error_message=e.detail - )) + return respond_rest( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code='GEN001', + error_message=e.detail, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Unsubscribe from API @@ -107,72 +120,84 @@ Response: {} """ -@subscription_router.post('/unsubscribe', + +@subscription_router.post( + '/unsubscribe', description='Unsubscribe from API', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Subscription deleted successfully' - } - } - } + 'application/json': {'example': {'message': 'Subscription deleted successfully'}} + }, } - } + }, ) - async def unsubscribe_api(api_data: SubscribeModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - if not await group_required(request, api_data.api_name + '/' + api_data.api_version, api_data.username): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - 'request_id': request_id - }, - error_code='SUB008', - error_message='You do not have the correct group access' - )) + if not await group_required( + request, api_data.api_name + '/' + api_data.api_version, api_data.username + ): + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='SUB008', + error_message='You do not have the correct group access', + ) + ) target_user = api_data.username or username - logger.info(f'{request_id} | Actor: {username} | Action: unsubscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}') + logger.info( + f'{request_id} | Actor: {username} | Action: unsubscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}' + ) result = await SubscriptionService.unsubscribe(api_data, request_id) actor_user = username target_user = api_data.username or username - audit(request, actor=actor_user, action='subscription.unsubscribe', target=f'{target_user}:{api_data.api_name}/{api_data.api_version}', status=result.get('status_code'), details=None, request_id=request_id) + audit( + request, + actor=actor_user, + action='subscription.unsubscribe', + target=f'{target_user}:{api_data.api_name}/{api_data.api_version}', + status=result.get('status_code'), + details=None, + request_id=request_id, + ) return respond_rest(result) except HTTPException as e: - return respond_rest(ResponseModel( - status_code=e.status_code, - response_headers={ - 'request_id': request_id - }, - error_code='GEN002', - error_message=e.detail - )) + return respond_rest( + ResponseModel( + status_code=e.status_code, + response_headers={'request_id': request_id}, + error_code='GEN002', + error_message=e.detail, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Get current user's subscriptions @@ -182,49 +207,44 @@ Response: {} """ -@subscription_router.get('/subscriptions', + +@subscription_router.get( + '/subscriptions', description="Get current user's subscriptions", response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'apis': [ - 'customer/v1', - 'orders/v1' - ] - } - } - } + 'content': {'application/json': {'example': {'apis': ['customer/v1', 'orders/v1']}}}, } - } + }, ) - async def subscriptions_for_current_user(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await SubscriptionService.get_user_subscriptions(username, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Get user's subscriptions @@ -234,49 +254,44 @@ Response: {} """ -@subscription_router.get('/subscriptions/{user_id}', + +@subscription_router.get( + '/subscriptions/{user_id}', description="Get user's subscriptions", response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'apis': [ - 'customer/v1', - 'orders/v1' - ] - } - } - } + 'content': {'application/json': {'example': {'apis': ['customer/v1', 'orders/v1']}}}, } - } + }, ) - async def subscriptions_for_user_by_id(user_id: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await SubscriptionService.get_user_subscriptions(user_id, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - 'request_id': request_id - }, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -286,17 +301,21 @@ Response: {} """ -@subscription_router.get('/available-apis/{username}', - description='List available APIs for subscription based on permissions and groups', - response_model=ResponseModel) +@subscription_router.get( + '/available-apis/{username}', + description='List available APIs for subscription based on permissions and groups', + response_model=ResponseModel, +) async def available_apis(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) actor = payload.get('sub') - logger.info(f'{request_id} | Username: {actor} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {actor} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') accesses = (payload or {}).get('accesses') or {} @@ -307,19 +326,30 @@ async def available_apis(username: str, request: Request): cursor = api_collection.find().sort('api_name', 1) apis = cursor.to_list(length=None) for a in apis: - if a.get('_id'): del a['_id'] + if a.get('_id'): + del a['_id'] if can_manage: - data = [{ 'api_name': a.get('api_name'), 'api_version': a.get('api_version'), 'api_description': a.get('api_description') } for a in apis] + data = [ + { + 'api_name': a.get('api_name'), + 'api_version': a.get('api_version'), + 'api_description': a.get('api_description'), + } + for a in apis + ] return respond_rest(ResponseModel(status_code=200, response={'apis': data})) if username != actor: - return respond_rest(ResponseModel( - status_code=403, - error_code='SUB009', - error_message='You do not have permission to view available APIs for this user' - )) + return respond_rest( + ResponseModel( + status_code=403, + error_code='SUB009', + error_message='You do not have permission to view available APIs for this user', + ) + ) try: from services.user_service import UserService + user = await UserService.get_user_by_username_helper(actor) user_groups = set(user.get('groups') or []) except Exception: @@ -328,16 +358,24 @@ async def available_apis(username: str, request: Request): for a in apis: api_groups = set(a.get('api_allowed_groups') or []) if user_groups.intersection(api_groups): - allowed.append({ 'api_name': a.get('api_name'), 'api_version': a.get('api_version'), 'api_description': a.get('api_description') }) + allowed.append( + { + 'api_name': a.get('api_name'), + 'api_version': a.get('api_version'), + 'api_description': a.get('api_description'), + } + ) return respond_rest(ResponseModel(status_code=200, response={'apis': allowed})) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='GTW999', - error_message='An unexpected error occurred' - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='GTW999', + error_message='An unexpected error occurred', + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/tier_routes.py b/backend-services/routes/tier_routes.py index b9cbafe..dcfb843 100644 --- a/backend-services/routes/tier_routes.py +++ b/backend-services/routes/tier_routes.py @@ -5,20 +5,20 @@ FastAPI routes for managing tiers, plans, and user assignments. """ import logging -from typing import List, Optional -from datetime import datetime -from fastapi import APIRouter, HTTPException, Depends, status, Query, Request -from pydantic import BaseModel, Field -import uuid import time +import uuid +from datetime import datetime -from models.rate_limit_models import Tier, TierLimits, TierName, UserTierAssignment +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from pydantic import BaseModel, Field + +from models.rate_limit_models import Tier, TierLimits, TierName from models.response_model import ResponseModel from services.tier_service import TierService, get_tier_service -from utils.database_async import async_database from utils.auth_util import auth_required -from utils.role_util import platform_role_required_bool +from utils.database_async import async_database from utils.response_util import respond_rest +from utils.role_util import platform_role_required_bool logger = logging.getLogger(__name__) @@ -29,69 +29,76 @@ tier_router = APIRouter() # REQUEST/RESPONSE MODELS # ============================================================================ + class TierLimitsRequest(BaseModel): """Request model for tier limits""" - requests_per_second: Optional[int] = None - requests_per_minute: Optional[int] = None - requests_per_hour: Optional[int] = None - requests_per_day: Optional[int] = None - requests_per_month: Optional[int] = None + + requests_per_second: int | None = None + requests_per_minute: int | None = None + requests_per_hour: int | None = None + requests_per_day: int | None = None + requests_per_month: int | None = None burst_per_second: int = 0 burst_per_minute: int = 0 burst_per_hour: int = 0 - monthly_request_quota: Optional[int] = None - daily_request_quota: Optional[int] = None - monthly_bandwidth_quota: Optional[int] = None + monthly_request_quota: int | None = None + daily_request_quota: int | None = None + monthly_bandwidth_quota: int | None = None enable_throttling: bool = False max_queue_time_ms: int = 5000 class TierCreateRequest(BaseModel): """Request model for creating a tier""" - tier_id: str = Field(..., description="Unique tier identifier") - name: str = Field(..., description="Tier name (free, pro, enterprise, custom)") - display_name: str = Field(..., description="Display name for tier") - description: Optional[str] = None + + tier_id: str = Field(..., description='Unique tier identifier') + name: str = Field(..., description='Tier name (free, pro, enterprise, custom)') + display_name: str = Field(..., description='Display name for tier') + description: str | None = None limits: TierLimitsRequest - price_monthly: Optional[float] = None - price_yearly: Optional[float] = None - features: List[str] = [] + price_monthly: float | None = None + price_yearly: float | None = None + features: list[str] = [] is_default: bool = False enabled: bool = True class TierUpdateRequest(BaseModel): """Request model for updating a tier""" - display_name: Optional[str] = None - description: Optional[str] = None - limits: Optional[TierLimitsRequest] = None - price_monthly: Optional[float] = None - price_yearly: Optional[float] = None - features: Optional[List[str]] = None - is_default: Optional[bool] = None - enabled: Optional[bool] = None + + display_name: str | None = None + description: str | None = None + limits: TierLimitsRequest | None = None + price_monthly: float | None = None + price_yearly: float | None = None + features: list[str] | None = None + is_default: bool | None = None + enabled: bool | None = None class UserAssignmentRequest(BaseModel): """Request model for assigning user to tier""" + user_id: str tier_id: str - effective_from: Optional[datetime] = None - effective_until: Optional[datetime] = None - override_limits: Optional[TierLimitsRequest] = None - notes: Optional[str] = None + effective_from: datetime | None = None + effective_until: datetime | None = None + override_limits: TierLimitsRequest | None = None + notes: str | None = None class TierUpgradeRequest(BaseModel): """Request model for tier upgrade""" + user_id: str new_tier_id: str immediate: bool = True - scheduled_date: Optional[datetime] = None + scheduled_date: datetime | None = None class TierDowngradeRequest(BaseModel): """Request model for tier downgrade""" + user_id: str new_tier_id: str grace_period_days: int = 0 @@ -99,6 +106,7 @@ class TierDowngradeRequest(BaseModel): class TemporaryUpgradeRequest(BaseModel): """Request model for temporary tier upgrade""" + user_id: str temp_tier_id: str duration_days: int @@ -106,24 +114,26 @@ class TemporaryUpgradeRequest(BaseModel): class TierResponse(BaseModel): """Response model for tier""" + tier_id: str name: str display_name: str - description: Optional[str] + description: str | None limits: dict - price_monthly: Optional[float] - price_yearly: Optional[float] - features: List[str] + price_monthly: float | None + price_yearly: float | None + features: list[str] is_default: bool enabled: bool - created_at: Optional[str] - updated_at: Optional[str] + created_at: str | None + updated_at: str | None # ============================================================================ # DEPENDENCY INJECTION # ============================================================================ + async def get_tier_service_dep() -> TierService: """Dependency to get tier service""" return get_tier_service(async_database.db) @@ -133,14 +143,14 @@ async def get_tier_service_dep() -> TierService: # TIER CRUD ENDPOINTS # ============================================================================ -@tier_router.post("/", response_model=TierResponse, status_code=status.HTTP_201_CREATED) + +@tier_router.post('/', response_model=TierResponse, status_code=status.HTTP_201_CREATED) async def create_tier( - request: TierCreateRequest, - tier_service: TierService = Depends(get_tier_service_dep) + request: TierCreateRequest, tier_service: TierService = Depends(get_tier_service_dep) ): """ Create a new tier - + Requires admin permissions. """ try: @@ -155,128 +165,134 @@ async def create_tier( price_yearly=request.price_yearly, features=request.features, is_default=request.is_default, - enabled=request.enabled + enabled=request.enabled, ) - + created_tier = await tier_service.create_tier(tier) - + return TierResponse( **created_tier.to_dict(), created_at=created_tier.created_at.isoformat() if created_tier.created_at else None, - updated_at=created_tier.updated_at.isoformat() if created_tier.updated_at else None + updated_at=created_tier.updated_at.isoformat() if created_tier.updated_at else None, ) - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error creating tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create tier") + logger.error(f'Error creating tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to create tier' + ) -@tier_router.get("/") +@tier_router.get('/') async def list_tiers( request: Request, - enabled_only: bool = Query(False, description="Only return enabled tiers"), - search: Optional[str] = Query(None, description="Search tiers by name or description"), + enabled_only: bool = Query(False, description='Only return enabled tiers'), + search: str | None = Query(None, description='Search tiers by name or description'), skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ List all tiers - + Can filter by enabled status and paginate results. """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') - logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - - if not await platform_role_required_bool(username, 'manage_tiers'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='TIER001', - error_message='You do not have permission to manage tiers' - )) - - tiers = await tier_service.list_tiers( - enabled_only=enabled_only, - search_term=search, - skip=skip, - limit=limit + + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' ) - + logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') + + if not await platform_role_required_bool(username, 'manage_tiers'): + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='TIER001', + error_message='You do not have permission to manage tiers', + ) + ) + + tiers = await tier_service.list_tiers( + enabled_only=enabled_only, search_term=search, skip=skip, limit=limit + ) + tier_list = [ TierResponse( **tier.to_dict(), created_at=tier.created_at.isoformat() if tier.created_at else None, - updated_at=tier.updated_at.isoformat() if tier.updated_at else None + updated_at=tier.updated_at.isoformat() if tier.updated_at else None, ).dict() for tier in tiers ] - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=tier_list - )) - + + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=tier_list + ) + ) + except Exception as e: logger.error(f'{request_id} | Error listing tiers: {e}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='TIER999', - error_message='Failed to list tiers' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='TIER999', + error_message='Failed to list tiers', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') -@tier_router.get("/{tier_id}", response_model=TierResponse) -async def get_tier( - tier_id: str, - tier_service: TierService = Depends(get_tier_service_dep) -): +@tier_router.get('/{tier_id}', response_model=TierResponse) +async def get_tier(tier_id: str, tier_service: TierService = Depends(get_tier_service_dep)): """ Get a specific tier by ID """ try: tier = await tier_service.get_tier(tier_id) - + if not tier: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Tier {tier_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Tier {tier_id} not found' + ) + return TierResponse( **tier.to_dict(), created_at=tier.created_at.isoformat() if tier.created_at else None, - updated_at=tier.updated_at.isoformat() if tier.updated_at else None + updated_at=tier.updated_at.isoformat() if tier.updated_at else None, ) - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get tier") + logger.error(f'Error getting tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get tier' + ) -@tier_router.put("/{tier_id}", response_model=TierResponse) +@tier_router.put('/{tier_id}', response_model=TierResponse) async def update_tier( tier_id: str, request: TierUpdateRequest, - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ Update a tier - + Requires admin permissions. """ try: @@ -298,63 +314,68 @@ async def update_tier( updates['is_default'] = request.is_default if request.enabled is not None: updates['enabled'] = request.enabled - + updated_tier = await tier_service.update_tier(tier_id, updates) - + if not updated_tier: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Tier {tier_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Tier {tier_id} not found' + ) + return TierResponse( **updated_tier.to_dict(), created_at=updated_tier.created_at.isoformat() if updated_tier.created_at else None, - updated_at=updated_tier.updated_at.isoformat() if updated_tier.updated_at else None + updated_at=updated_tier.updated_at.isoformat() if updated_tier.updated_at else None, ) - + except HTTPException: raise except Exception as e: - logger.error(f"Error updating tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update tier") + logger.error(f'Error updating tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to update tier' + ) -@tier_router.delete("/{tier_id}", status_code=status.HTTP_204_NO_CONTENT) -async def delete_tier( - tier_id: str, - tier_service: TierService = Depends(get_tier_service_dep) -): +@tier_router.delete('/{tier_id}', status_code=status.HTTP_204_NO_CONTENT) +async def delete_tier(tier_id: str, tier_service: TierService = Depends(get_tier_service_dep)): """ Delete a tier - + Requires admin permissions. Cannot delete tier if users are assigned to it. """ try: deleted = await tier_service.delete_tier(tier_id) - + if not deleted: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Tier {tier_id} not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'Tier {tier_id} not found' + ) + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except HTTPException: raise except Exception as e: - logger.error(f"Error deleting tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete tier") + logger.error(f'Error deleting tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to delete tier' + ) # ============================================================================ # USER ASSIGNMENT ENDPOINTS # ============================================================================ -@tier_router.post("/assignments", status_code=status.HTTP_201_CREATED) + +@tier_router.post('/assignments', status_code=status.HTTP_201_CREATED) async def assign_user_to_tier( - request: UserAssignmentRequest, - tier_service: TierService = Depends(get_tier_service_dep) + request: UserAssignmentRequest, tier_service: TierService = Depends(get_tier_service_dep) ): """ Assign a user to a tier - + Requires admin permissions. """ try: @@ -362,134 +383,147 @@ async def assign_user_to_tier( override_limits = None if request.override_limits: override_limits = TierLimits(**request.override_limits.dict()) - + assignment = await tier_service.assign_user_to_tier( user_id=request.user_id, tier_id=request.tier_id, effective_from=request.effective_from, effective_until=request.effective_until, override_limits=override_limits, - notes=request.notes + notes=request.notes, ) - + return assignment.to_dict() - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error assigning user to tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to assign user") + logger.error(f'Error assigning user to tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to assign user' + ) -@tier_router.get("/assignments/{user_id}") +@tier_router.get('/assignments/{user_id}') async def get_user_assignment( - user_id: str, - tier_service: TierService = Depends(get_tier_service_dep) + user_id: str, tier_service: TierService = Depends(get_tier_service_dep) ): """ Get a user's tier assignment """ try: assignment = await tier_service.get_user_assignment(user_id) - + if not assignment: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No assignment found for user {user_id}") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'No assignment found for user {user_id}', + ) + return assignment.to_dict() - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting user assignment: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get assignment") + logger.error(f'Error getting user assignment: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get assignment' + ) -@tier_router.get("/assignments/{user_id}/tier", response_model=TierResponse) -async def get_user_tier( - user_id: str, - tier_service: TierService = Depends(get_tier_service_dep) -): +@tier_router.get('/assignments/{user_id}/tier', response_model=TierResponse) +async def get_user_tier(user_id: str, tier_service: TierService = Depends(get_tier_service_dep)): """ Get the effective tier for a user - + Considers effective dates and returns current tier. """ try: tier = await tier_service.get_user_tier(user_id) - + if not tier: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No tier found for user {user_id}") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f'No tier found for user {user_id}' + ) + return TierResponse( **tier.to_dict(), created_at=tier.created_at.isoformat() if tier.created_at else None, - updated_at=tier.updated_at.isoformat() if tier.updated_at else None + updated_at=tier.updated_at.isoformat() if tier.updated_at else None, ) - + except HTTPException: raise except Exception as e: - logger.error(f"Error getting user tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get user tier") + logger.error(f'Error getting user tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get user tier' + ) -@tier_router.delete("/assignments/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +@tier_router.delete('/assignments/{user_id}', status_code=status.HTTP_204_NO_CONTENT) async def remove_user_assignment( - user_id: str, - tier_service: TierService = Depends(get_tier_service_dep) + user_id: str, tier_service: TierService = Depends(get_tier_service_dep) ): """ Remove a user's tier assignment - + Requires admin permissions. """ try: removed = await tier_service.remove_user_assignment(user_id) - + if not removed: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No assignment found for user {user_id}") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'No assignment found for user {user_id}', + ) + except HTTPException: raise except Exception as e: - logger.error(f"Error removing user assignment: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to remove assignment") + logger.error(f'Error removing user assignment: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to remove assignment' + ) -@tier_router.get("/{tier_id}/users") +@tier_router.get('/{tier_id}/users') async def list_users_in_tier( tier_id: str, skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), - tier_service: TierService = Depends(get_tier_service_dep) + tier_service: TierService = Depends(get_tier_service_dep), ): """ List all users assigned to a tier - + Requires admin permissions. """ try: assignments = await tier_service.list_users_in_tier(tier_id, skip=skip, limit=limit) - + return [assignment.to_dict() for assignment in assignments] - + except Exception as e: - logger.error(f"Error listing users in tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to list users") + logger.error(f'Error listing users in tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to list users' + ) # ============================================================================ # TIER UPGRADE/DOWNGRADE ENDPOINTS # ============================================================================ -@tier_router.post("/upgrade") + +@tier_router.post('/upgrade') async def upgrade_user_tier( - request: TierUpgradeRequest, - tier_service: TierService = Depends(get_tier_service_dep) + request: TierUpgradeRequest, tier_service: TierService = Depends(get_tier_service_dep) ): """ Upgrade a user to a higher tier - + Requires admin permissions. """ try: @@ -497,78 +531,83 @@ async def upgrade_user_tier( user_id=request.user_id, new_tier_id=request.new_tier_id, immediate=request.immediate, - scheduled_date=request.scheduled_date + scheduled_date=request.scheduled_date, ) - + return assignment.to_dict() - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error upgrading user tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to upgrade tier") + logger.error(f'Error upgrading user tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to upgrade tier' + ) -@tier_router.post("/downgrade") +@tier_router.post('/downgrade') async def downgrade_user_tier( - request: TierDowngradeRequest, - tier_service: TierService = Depends(get_tier_service_dep) + request: TierDowngradeRequest, tier_service: TierService = Depends(get_tier_service_dep) ): """ Downgrade a user to a lower tier - + Requires admin permissions. """ try: assignment = await tier_service.downgrade_user_tier( user_id=request.user_id, new_tier_id=request.new_tier_id, - grace_period_days=request.grace_period_days + grace_period_days=request.grace_period_days, ) - + return assignment.to_dict() - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error downgrading user tier: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to downgrade tier") + logger.error(f'Error downgrading user tier: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to downgrade tier' + ) -@tier_router.post("/temporary-upgrade") +@tier_router.post('/temporary-upgrade') async def temporary_tier_upgrade( - request: TemporaryUpgradeRequest, - tier_service: TierService = Depends(get_tier_service_dep) + request: TemporaryUpgradeRequest, tier_service: TierService = Depends(get_tier_service_dep) ): """ Temporarily upgrade a user to a higher tier - + Requires admin permissions. """ try: assignment = await tier_service.temporary_tier_upgrade( user_id=request.user_id, temp_tier_id=request.temp_tier_id, - duration_days=request.duration_days + duration_days=request.duration_days, ) - + return assignment.to_dict() - + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - logger.error(f"Error creating temporary upgrade: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create temporary upgrade") + logger.error(f'Error creating temporary upgrade: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to create temporary upgrade', + ) # ============================================================================ # TIER COMPARISON & ANALYTICS ENDPOINTS # ============================================================================ -@tier_router.post("/compare") + +@tier_router.post('/compare') async def compare_tiers( - tier_ids: List[str], - tier_service: TierService = Depends(get_tier_service_dep) + tier_ids: list[str], tier_service: TierService = Depends(get_tier_service_dep) ): """ Compare multiple tiers side-by-side @@ -576,108 +615,119 @@ async def compare_tiers( try: comparison = await tier_service.compare_tiers(tier_ids) return comparison - + except Exception as e: - logger.error(f"Error comparing tiers: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to compare tiers") + logger.error(f'Error comparing tiers: {e}') + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to compare tiers' + ) -@tier_router.get("/statistics/all") +@tier_router.get('/statistics/all') async def get_all_tier_statistics( - request: Request, - tier_service: TierService = Depends(get_tier_service_dep) + request: Request, tier_service: TierService = Depends(get_tier_service_dep) ): """ Get statistics for all tiers - + Requires admin permissions. """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + if not await platform_role_required_bool(username, 'manage_tiers'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='TIER001', - error_message='You do not have permission to manage tiers' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='TIER001', + error_message='You do not have permission to manage tiers', + ) + ) + stats = await tier_service.get_all_tier_statistics() - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=stats - )) - + + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=stats + ) + ) + except Exception as e: logger.error(f'{request_id} | Error getting all tier statistics: {e}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='TIER999', - error_message='Failed to get statistics' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='TIER999', + error_message='Failed to get statistics', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') -@tier_router.get("/{tier_id}/statistics") +@tier_router.get('/{tier_id}/statistics') async def get_tier_statistics( - request: Request, - tier_id: str, - tier_service: TierService = Depends(get_tier_service_dep) + request: Request, tier_id: str, tier_service: TierService = Depends(get_tier_service_dep) ): """ Get statistics for a tier - + Requires admin permissions. """ request_id = str(uuid.uuid4()) start_time = time.time() - + try: payload = await auth_required(request) username = payload.get('sub') - - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + if not await platform_role_required_bool(username, 'manage_tiers'): - return respond_rest(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='TIER001', - error_message='You do not have permission to manage tiers' - )) - + return respond_rest( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='TIER001', + error_message='You do not have permission to manage tiers', + ) + ) + stats = await tier_service.get_tier_statistics(tier_id) - - return respond_rest(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=stats - )) - + + return respond_rest( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response=stats + ) + ) + except Exception as e: logger.error(f'{request_id} | Error getting tier statistics: {e}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='TIER999', - error_message='Failed to get statistics' - )) - + return respond_rest( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='TIER999', + error_message='Failed to get statistics', + ) + ) + finally: end_time = time.time() logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms') diff --git a/backend-services/routes/tools_routes.py b/backend-services/routes/tools_routes.py index 5979308..67e5652 100644 --- a/backend-services/routes/tools_routes.py +++ b/backend-services/routes/tools_routes.py @@ -2,30 +2,38 @@ Tools and diagnostics routes (e.g., CORS checker). """ +import logging +import os +import time +import uuid +from typing import Any + from fastapi import APIRouter, Request from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -import os -import uuid -import time -import logging from models.response_model import ResponseModel -from utils.response_util import process_response -from utils.auth_util import auth_required -from utils.role_util import platform_role_required_bool from utils import chaos_util +from utils.auth_util import auth_required +from utils.response_util import process_response +from utils.role_util import platform_role_required_bool tools_router = APIRouter() logger = logging.getLogger('doorman.gateway') + class CorsCheckRequest(BaseModel): origin: str = Field(..., description='Origin to evaluate, e.g., https://localhost:3000') method: str = Field(..., description='Intended request method, e.g., GET/POST/PUT') - request_headers: Optional[List[str]] = Field(default=None, description='Requested headers from Access-Control-Request-Headers') - with_credentials: Optional[bool] = Field(default=None, description='Whether credentials will be sent; defaults to ALLOW_CREDENTIALS env if omitted') + request_headers: list[str] | None = Field( + default=None, description='Requested headers from Access-Control-Request-Headers' + ) + with_credentials: bool | None = Field( + default=None, + description='Whether credentials will be sent; defaults to ALLOW_CREDENTIALS env if omitted', + ) -def _compute_cors_config() -> Dict[str, Any]: + +def _compute_cors_config() -> dict[str, Any]: origins_env = os.getenv('ALLOWED_ORIGINS', 'http://localhost:3000') if not (origins_env or '').strip(): @@ -67,6 +75,7 @@ def _compute_cors_config() -> Dict[str, Any]: 'cors_strict': cors_strict, } + """ Endpoint @@ -76,67 +85,93 @@ Response: {} """ -@tools_router.post('/cors/check', description='Simulate CORS preflight/actual decisions against current gateway config', response_model=ResponseModel) +@tools_router.post( + '/cors/check', + description='Simulate CORS preflight/actual decisions against current gateway config', + response_model=ResponseModel, +) async def cors_check(request: Request, body: CorsCheckRequest): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, 'manage_security'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='TLS001', - error_message='You do not have permission to use tools' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='TLS001', + error_message='You do not have permission to use tools', + ).dict(), + 'rest', + ) cfg = _compute_cors_config() origin = (body.origin or '').strip() method = (body.method or '').strip().upper() requested_headers = [h.strip() for h in (body.request_headers or []) if h.strip()] - with_credentials = cfg['credentials'] if body.with_credentials is None else bool(body.with_credentials) + with_credentials = ( + cfg['credentials'] if body.with_credentials is None else bool(body.with_credentials) + ) - origin_allowed = origin in cfg['safe_origins'] or (not cfg['cors_strict'] and '*' in cfg['origins']) + origin_allowed = origin in cfg['safe_origins'] or ( + not cfg['cors_strict'] and '*' in cfg['origins'] + ) method_allowed = method in cfg['methods'] allowed_headers_lower = {h.lower() for h in cfg['headers']} requested_lower = [h.lower() for h in requested_headers] - headers_not_allowed = [h for h in requested_headers if h.lower() not in allowed_headers_lower] + headers_not_allowed = [ + h for h in requested_headers if h.lower() not in allowed_headers_lower + ] headers_allowed = len(headers_not_allowed) == 0 preflight_allowed = origin_allowed and method_allowed and headers_allowed preflight_headers = { - 'Access-Control-Allow-Origin': origin if origin_allowed else None, 'Access-Control-Allow-Methods': ', '.join(cfg['methods']), - 'Access-Control-Allow-Headers': ', '.join(cfg['headers']) if requested_headers else ', '.join(cfg['headers']), - 'Access-Control-Allow-Credentials': 'true' if with_credentials and cfg['credentials'] else 'false', + 'Access-Control-Allow-Headers': ', '.join(cfg['headers']) + if requested_headers + else ', '.join(cfg['headers']), + 'Access-Control-Allow-Credentials': 'true' + if with_credentials and cfg['credentials'] + else 'false', 'Vary': 'Origin', } actual_allowed = origin_allowed actual_headers = { 'Access-Control-Allow-Origin': origin if origin_allowed else None, - 'Access-Control-Allow-Credentials': 'true' if with_credentials and cfg['credentials'] else 'false', + 'Access-Control-Allow-Credentials': 'true' + if with_credentials and cfg['credentials'] + else 'false', 'Vary': 'Origin', } - notes: List[str] = [] + notes: list[str] = [] if cfg['credentials'] and ('*' in cfg['origins']) and not cfg['cors_strict']: - notes.append('Wildcard origins with credentials can be rejected by browsers; prefer explicit origins or set CORS_STRICT=true.') + notes.append( + 'Wildcard origins with credentials can be rejected by browsers; prefer explicit origins or set CORS_STRICT=true.' + ) if any(h == '*' for h in os.getenv('ALLOW_HEADERS', '*').split(',')): - notes.append("ALLOW_HEADERS='*' replaced with a conservative default set to satisfy credentialed requests.") + notes.append( + "ALLOW_HEADERS='*' replaced with a conservative default set to satisfy credentialed requests." + ) if not origin_allowed: notes.append('Origin is not allowed based on current configuration.') if not method_allowed: notes.append('Requested method is not in ALLOW_METHODS.') if not headers_allowed and headers_not_allowed: - notes.append(f"Some requested headers are not allowed: {', '.join(headers_not_allowed)}") + notes.append( + f'Some requested headers are not allowed: {", ".join(headers_not_allowed)}' + ) response_payload = { 'config': { @@ -162,37 +197,48 @@ async def cors_check(request: Request, body: CorsCheckRequest): 'not_allowed_headers': headers_not_allowed, 'response_headers': preflight_headers, }, - 'actual': { - 'allowed': actual_allowed, - 'response_headers': actual_headers, - }, + 'actual': {'allowed': actual_allowed, 'response_headers': actual_headers}, 'notes': notes, } - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=response_payload - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response=response_payload, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='TLS999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='TLS999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + class ChaosToggleRequest(BaseModel): backend: str = Field(..., description='Backend to toggle (redis|mongo)') enabled: bool = Field(..., description='Enable or disable outage simulation') - duration_ms: Optional[int] = Field(default=None, description='Optional duration for outage before auto-disable') + duration_ms: int | None = Field( + default=None, description='Optional duration for outage before auto-disable' + ) -@tools_router.post('/chaos/toggle', description='Toggle simulated backend outages (redis|mongo)', response_model=ResponseModel) + +@tools_router.post( + '/chaos/toggle', + description='Toggle simulated backend outages (redis|mongo)', + response_model=ResponseModel, +) async def chaos_toggle(request: Request, body: ChaosToggleRequest): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -200,42 +246,57 @@ async def chaos_toggle(request: Request, body: ChaosToggleRequest): payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel( - status_code=403, - response_headers={'request_id': request_id}, - error_code='TLS001', - error_message='You do not have permission to use tools' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='TLS001', + error_message='You do not have permission to use tools', + ).dict(), + 'rest', + ) backend = (body.backend or '').strip().lower() if backend not in ('redis', 'mongo'): - return process_response(ResponseModel( - status_code=400, - response_headers={'request_id': request_id}, - error_code='TLS002', - error_message='backend must be redis or mongo' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=400, + response_headers={'request_id': request_id}, + error_code='TLS002', + error_message='backend must be redis or mongo', + ).dict(), + 'rest', + ) if body.duration_ms and int(body.duration_ms) > 0: chaos_util.enable_for(backend, int(body.duration_ms)) else: chaos_util.enable(backend, bool(body.enabled)) - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response={'backend': backend, 'enabled': chaos_util.should_fail(backend)} - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=200, + response_headers={'request_id': request_id}, + response={'backend': backend, 'enabled': chaos_util.should_fail(backend)}, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='TLS999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='TLS999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') -@tools_router.get('/chaos/stats', description='Get chaos simulation stats', response_model=ResponseModel) + +@tools_router.get( + '/chaos/stats', description='Get chaos simulation stats', response_model=ResponseModel +) async def chaos_stats(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 @@ -243,25 +304,34 @@ async def chaos_stats(request: Request): payload = await auth_required(request) username = payload.get('sub') if not await platform_role_required_bool(username, 'manage_gateway'): - return process_response(ResponseModel( - status_code=403, + return process_response( + ResponseModel( + status_code=403, + response_headers={'request_id': request_id}, + error_code='TLS001', + error_message='You do not have permission to use tools', + ).dict(), + 'rest', + ) + return process_response( + ResponseModel( + status_code=200, response_headers={'request_id': request_id}, - error_code='TLS001', - error_message='You do not have permission to use tools' - ).dict(), 'rest') - return process_response(ResponseModel( - status_code=200, - response_headers={'request_id': request_id}, - response=chaos_util.stats() - ).dict(), 'rest') + response=chaos_util.stats(), + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={'request_id': request_id}, - error_code='TLS999', - error_message='An unexpected error occurred' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={'request_id': request_id}, + error_code='TLS999', + error_message='An unexpected error occurred', + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') diff --git a/backend-services/routes/user_routes.py b/backend-services/routes/user_routes.py index 8dd41cb..3d1e100 100644 --- a/backend-services/routes/user_routes.py +++ b/backend-services/routes/user_routes.py @@ -4,41 +4,43 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -import os -from fastapi import APIRouter, Request, HTTPException -import uuid -import time import logging +import os +import time +import uuid +from fastapi import APIRouter, HTTPException, Request + +from models.create_user_model import CreateUserModel from models.response_model import ResponseModel +from models.update_password_model import UpdatePasswordModel +from models.update_user_model import UpdateUserModel from models.user_model_response import UserModelResponse from services.user_service import UserService from utils.auth_util import auth_required -from utils.response_util import respond_rest, process_response -from utils.role_util import platform_role_required_bool, is_admin_user, is_admin_role -from utils.constants import ErrorCodes, Messages, Defaults, Roles, Headers -from utils.database import role_collection -from models.create_user_model import CreateUserModel -from models.update_user_model import UpdateUserModel -from models.update_password_model import UpdatePasswordModel +from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles +from utils.response_util import process_response, respond_rest +from utils.role_util import is_admin_role, is_admin_user, platform_role_required_bool user_router = APIRouter() logger = logging.getLogger('doorman.gateway') + async def _safe_is_admin_user(username: str) -> bool: try: return await is_admin_user(username) except Exception: return False + async def _safe_is_admin_role(role: str) -> bool: try: return await is_admin_role(role) except Exception: return False + """ Add user @@ -48,61 +50,62 @@ Response: {} """ -@user_router.post('', + +@user_router.post( + '', description='Add user', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'User created successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'User created successfully'}}}, } - } + }, ) - async def create_user(user_data: CreateUserModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') if not await platform_role_required_bool(username, Roles.MANAGE_USERS): return respond_rest( ResponseModel( status_code=403, error_code='USR006', - error_message='Can only update your own information' - )) + error_message='Can only update your own information', + ) + ) if user_data.role and await _safe_is_admin_role(user_data.role): if not await _safe_is_admin_user(username): return respond_rest( ResponseModel( status_code=403, error_code='USR015', - error_message='Only admin may create users with the admin role' - )) + error_message='Only admin may create users with the admin role', + ) + ) return respond_rest(await UserService.create_user(user_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update user @@ -112,41 +115,47 @@ Response: {} """ -@user_router.put('/{username}', + +@user_router.put( + '/{username}', description='Update user', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'User updated successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'User updated successfully'}}}, } - } + }, ) - async def update_user(username: str, api_data: UpdateUserModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) auth_username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') # Block modifications to bootstrap admin user, except for limited operational fields if username == 'admin': allowed_keys = { - 'bandwidth_limit_bytes', 'bandwidth_limit_window', - 'rate_limit_duration', 'rate_limit_duration_type', 'rate_limit_enabled', - 'throttle_duration', 'throttle_duration_type', 'throttle_wait_duration', - 'throttle_wait_duration_type', 'throttle_queue_limit', 'throttle_enabled', + 'bandwidth_limit_bytes', + 'bandwidth_limit_window', + 'rate_limit_duration', + 'rate_limit_duration_type', + 'rate_limit_enabled', + 'throttle_duration', + 'throttle_duration_type', + 'throttle_wait_duration', + 'throttle_wait_duration_type', + 'throttle_queue_limit', + 'throttle_enabled', } try: - incoming = {k for k, v in (api_data.dict(exclude_unset=True) or {}).items() if v is not None} + incoming = { + k for k, v in (api_data.dict(exclude_unset=True) or {}).items() if v is not None + } except Exception: incoming = set() if not incoming.issubset(allowed_keys): @@ -154,22 +163,27 @@ async def update_user(username: str, api_data: UpdateUserModel, request: Request ResponseModel( status_code=403, error_code='USR020', - error_message='Super admin user cannot be modified' - )) - if not auth_username == username and not await platform_role_required_bool(auth_username, Roles.MANAGE_USERS): + error_message='Super admin user cannot be modified', + ) + ) + if not auth_username == username and not await platform_role_required_bool( + auth_username, Roles.MANAGE_USERS + ): return respond_rest( ResponseModel( status_code=403, error_code='USR006', - error_message='Can only update your own information' - )) + error_message='Can only update your own information', + ) + ) if await _safe_is_admin_user(username) and not await _safe_is_admin_user(auth_username): return respond_rest( ResponseModel( status_code=403, error_code='USR012', - error_message='Only admin may modify admin users' - )) + error_message='Only admin may modify admin users', + ) + ) new_role = api_data.role if new_role is not None: target_is_admin = await _safe_is_admin_user(username) @@ -179,23 +193,26 @@ async def update_user(username: str, api_data: UpdateUserModel, request: Request ResponseModel( status_code=403, error_code='USR013', - error_message='Only admin may change admin role assignments' - )) + error_message='Only admin may change admin role assignments', + ) + ) return respond_rest(await UserService.update_user(username, api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Delete user @@ -205,30 +222,27 @@ Response: {} """ -@user_router.delete('/{username}', + +@user_router.delete( + '/{username}', description='Delete user', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', - 'content': { - 'application/json': { - 'example': { - 'message': 'User deleted successfully' - } - } - } + 'content': {'application/json': {'example': {'message': 'User deleted successfully'}}}, } - } + }, ) - async def delete_user(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) auth_username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') # Block any deletion of bootstrap admin user if username == 'admin': @@ -236,37 +250,44 @@ async def delete_user(username: str, request: Request): ResponseModel( status_code=403, error_code='USR021', - error_message='Super admin user cannot be deleted' - )) - if not auth_username == username and not await platform_role_required_bool(auth_username, Roles.MANAGE_USERS): + error_message='Super admin user cannot be deleted', + ) + ) + if not auth_username == username and not await platform_role_required_bool( + auth_username, Roles.MANAGE_USERS + ): return respond_rest( ResponseModel( status_code=403, error_code='USR007', - error_message='Can only delete your own account' - )) + error_message='Can only delete your own account', + ) + ) if await _safe_is_admin_user(username) and not await _safe_is_admin_user(auth_username): return respond_rest( ResponseModel( status_code=403, error_code='USR014', - error_message='Only admin may delete admin users' - )) + error_message='Only admin may delete admin users', + ) + ) return respond_rest(await UserService.delete_user(username, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Update user password @@ -276,65 +297,68 @@ Response: {} """ -@user_router.put('/{username}/update-password', + +@user_router.put( + '/{username}/update-password', description='Update user password', response_model=ResponseModel, responses={ 200: { 'description': 'Successful Response', 'content': { - 'application/json': { - 'example': { - 'message': 'Password updated successfully' - } - } - } + 'application/json': {'example': {'message': 'Password updated successfully'}} + }, } - } + }, ) - async def update_user_password(username: str, api_data: UpdatePasswordModel, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) auth_username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') # Block any password changes to bootstrap admin user if username == 'admin': - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='USR022', - error_message='Super admin password cannot be changed via UI' - )) - if not auth_username == username and not await platform_role_required_bool(auth_username, Roles.MANAGE_USERS): - return respond_rest(ResponseModel( - status_code=403, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code='USR006', - error_message='Can only update your own password' - )) + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='USR022', + error_message='Super admin password cannot be changed via UI', + ) + ) + if not auth_username == username and not await platform_role_required_bool( + auth_username, Roles.MANAGE_USERS + ): + return respond_rest( + ResponseModel( + status_code=403, + response_headers={Headers.REQUEST_ID: request_id}, + error_code='USR006', + error_message='Can only update your own password', + ) + ) return respond_rest(await UserService.update_password(username, api_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -344,43 +368,43 @@ Response: {} """ -@user_router.get('/me', - description='Get user by username', - response_model=UserModelResponse - ) +@user_router.get('/me', description='Get user by username', response_model=UserModelResponse) async def get_user_by_username(request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) auth_username = payload.get('sub') - logger.info(f'{request_id} | Username: {auth_username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {auth_username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') return respond_rest(await UserService.get_user_by_username(auth_username, request_id)) except HTTPException as e: - return respond_rest(ResponseModel( - status_code=e.status_code, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.HTTP_EXCEPTION, - error_message=e.detail - )) + return respond_rest( + ResponseModel( + status_code=e.status_code, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.HTTP_EXCEPTION, + error_message=e.detail, + ) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return respond_rest(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return respond_rest( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -390,18 +414,19 @@ Response: {} """ -@user_router.get('/all', - description='Get all users', - response_model=List[UserModelResponse] -) -async def get_all_users(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): +@user_router.get('/all', description='Get all users', response_model=list[UserModelResponse]) +async def get_all_users( + request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE +): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') data = await UserService.get_all_users(page, page_size, request_id) if data.get('status_code') == 200 and isinstance(data.get('response'), dict): @@ -412,35 +437,40 @@ async def get_all_users(request: Request, page: int = Defaults.PAGE, page_size: if u.get('username') == 'admin': continue # Hide other admin role users from non-admin users - if not await _safe_is_admin_user(username) and await _safe_is_admin_role(u.get('role')): + if not await _safe_is_admin_user(username) and await _safe_is_admin_role( + u.get('role') + ): continue filtered.append(u) data = dict(data) data['response'] = {'users': filtered} return process_response(data, 'rest') except HTTPException as e: - return process_response(ResponseModel( - status_code=e.status_code, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.HTTP_EXCEPTION, - error_message=e.detail - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=e.status_code, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.HTTP_EXCEPTION, + error_message=e.detail, + ).dict(), + 'rest', + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -450,55 +480,72 @@ Response: {} """ -@user_router.get('/{username}', - description='Get user by username', - response_model=UserModelResponse -) +@user_router.get( + '/{username}', description='Get user by username', response_model=UserModelResponse +) async def get_user_by_username(username: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) auth_username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') # Block access to bootstrap admin user for ALL users (including admin), # except when STRICT_RESPONSE_ENVELOPE=true (envelope-shape tests) - if username == 'admin' and not (os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'): - return process_response(ResponseModel( - status_code=404, - response_headers={Headers.REQUEST_ID: request_id}, - error_message='User not found' - ).dict(), 'rest') - if not auth_username == username and not await platform_role_required_bool(auth_username, 'manage_users'): + if username == 'admin' and not ( + os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true' + ): + return process_response( + ResponseModel( + status_code=404, + response_headers={Headers.REQUEST_ID: request_id}, + error_message='User not found', + ).dict(), + 'rest', + ) + if not auth_username == username and not await platform_role_required_bool( + auth_username, 'manage_users' + ): return process_response( ResponseModel( status_code=403, error_code='USR008', error_message='Unable to retrieve information for user', - ).dict(), 'rest') + ).dict(), + 'rest', + ) if not await _safe_is_admin_user(auth_username) and await _safe_is_admin_user(username): - return process_response(ResponseModel( - status_code=404, - response_headers={Headers.REQUEST_ID: request_id}, - error_message='User not found' - ).dict(), 'rest') - return process_response(await UserService.get_user_by_username(username, request_id), 'rest') + return process_response( + ResponseModel( + status_code=404, + response_headers={Headers.REQUEST_ID: request_id}, + error_message='User not found', + ).dict(), + 'rest', + ) + return process_response( + await UserService.get_user_by_username(username, request_id), 'rest' + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') + """ Endpoint @@ -508,55 +555,66 @@ Response: {} """ -@user_router.get('/email/{email}', - description='Get user by email', - response_model=List[UserModelResponse] -) +@user_router.get( + '/email/{email}', description='Get user by email', response_model=list[UserModelResponse] +) async def get_user_by_email(email: str, request: Request): request_id = str(uuid.uuid4()) start_time = time.time() * 1000 try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') data = await UserService.get_user_by_email(username, email, request_id) if data.get('status_code') == 200 and isinstance(data.get('response'), dict): u = data.get('response') # Block access to bootstrap admin user for ALL users if u.get('username') == 'admin': - return process_response(ResponseModel( - status_code=404, - response_headers={Headers.REQUEST_ID: request_id}, - error_message='User not found' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, + response_headers={Headers.REQUEST_ID: request_id}, + error_message='User not found', + ).dict(), + 'rest', + ) # Block access to other admin users for non-admin users if not await _safe_is_admin_user(username) and await _safe_is_admin_role(u.get('role')): - return process_response(ResponseModel( - status_code=404, - response_headers={Headers.REQUEST_ID: request_id}, - error_message='User not found' - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=404, + response_headers={Headers.REQUEST_ID: request_id}, + error_message='User not found', + ).dict(), + 'rest', + ) return process_response(data, 'rest') except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - ).dict(), 'rest') + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ).dict(), + 'rest', + ) finally: end_time = time.time() * 1000 logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms') -@user_router.get('', - description='Get all users (base path)', - response_model=List[UserModelResponse] + + +@user_router.get( + '', description='Get all users (base path)', response_model=list[UserModelResponse] ) -async def get_all_users_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE): +async def get_all_users_base( + request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE +): """Convenience alias for GET /platform/user/all to support clients and tests that expect listing at the base collection path. """ diff --git a/backend-services/routes/vault_routes.py b/backend-services/routes/vault_routes.py index 9a0afac..b1fcd86 100644 --- a/backend-services/routes/vault_routes.py +++ b/backend-services/routes/vault_routes.py @@ -4,20 +4,19 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -import uuid -import time import logging -from fastapi import APIRouter, Request, HTTPException +import time +import uuid + +from fastapi import APIRouter, Request -from models.response_model import ResponseModel from models.create_vault_entry_model import CreateVaultEntryModel +from models.response_model import ResponseModel from models.update_vault_entry_model import UpdateVaultEntryModel -from models.vault_entry_model_response import VaultEntryModelResponse from services.vault_service import VaultService from utils.auth_util import auth_required -from utils.response_util import respond_rest, process_response -from utils.constants import ErrorCodes, Messages, Headers +from utils.constants import ErrorCodes, Headers, Messages +from utils.response_util import process_response, respond_rest vault_router = APIRouter() @@ -35,12 +34,12 @@ logger = logging.getLogger('doorman.gateway') 'application/json': { 'example': { 'message': 'Vault entry created successfully', - 'data': {'key_name': 'api_key_production'} + 'data': {'key_name': 'api_key_production'}, } } - } + }, } - } + }, ) async def create_vault_entry(entry_data: CreateVaultEntryModel, request: Request): """ @@ -52,20 +51,22 @@ async def create_vault_entry(entry_data: CreateVaultEntryModel, request: Request try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + return respond_rest(await VaultService.create_vault_entry(username, entry_data, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: elapsed = time.time() * 1000 - start_time logger.info(f'{request_id} | Total time: {elapsed:.2f}ms') @@ -88,16 +89,16 @@ async def create_vault_entry(entry_data: CreateVaultEntryModel, request: Request 'username': 'john_doe', 'description': 'Production API key', 'created_at': '2024-11-22T10:15:30Z', - 'updated_at': '2024-11-22T10:15:30Z' + 'updated_at': '2024-11-22T10:15:30Z', } ], - 'count': 1 + 'count': 1, } } } - } + }, } - } + }, ) async def list_vault_entries(request: Request): """ @@ -109,20 +110,22 @@ async def list_vault_entries(request: Request): try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + return respond_rest(await VaultService.list_vault_entries(username, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: elapsed = time.time() * 1000 - start_time logger.info(f'{request_id} | Total time: {elapsed:.2f}ms') @@ -143,13 +146,13 @@ async def list_vault_entries(request: Request): 'username': 'john_doe', 'description': 'Production API key', 'created_at': '2024-11-22T10:15:30Z', - 'updated_at': '2024-11-22T10:15:30Z' + 'updated_at': '2024-11-22T10:15:30Z', } } } - } + }, } - } + }, ) async def get_vault_entry(key_name: str, request: Request): """ @@ -161,20 +164,22 @@ async def get_vault_entry(key_name: str, request: Request): try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + return respond_rest(await VaultService.get_vault_entry(username, key_name, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: elapsed = time.time() * 1000 - start_time logger.info(f'{request_id} | Total time: {elapsed:.2f}ms') @@ -188,14 +193,10 @@ async def get_vault_entry(key_name: str, request: Request): 200: { 'description': 'Vault entry updated successfully', 'content': { - 'application/json': { - 'example': { - 'message': 'Vault entry updated successfully' - } - } - } + 'application/json': {'example': {'message': 'Vault entry updated successfully'}} + }, } - } + }, ) async def update_vault_entry(key_name: str, update_data: UpdateVaultEntryModel, request: Request): """ @@ -207,20 +208,24 @@ async def update_vault_entry(key_name: str, update_data: UpdateVaultEntryModel, try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - - return respond_rest(await VaultService.update_vault_entry(username, key_name, update_data, request_id)) + + return respond_rest( + await VaultService.update_vault_entry(username, key_name, update_data, request_id) + ) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: elapsed = time.time() * 1000 - start_time logger.info(f'{request_id} | Total time: {elapsed:.2f}ms') @@ -234,14 +239,10 @@ async def update_vault_entry(key_name: str, update_data: UpdateVaultEntryModel, 200: { 'description': 'Vault entry deleted successfully', 'content': { - 'application/json': { - 'example': { - 'message': 'Vault entry deleted successfully' - } - } - } + 'application/json': {'example': {'message': 'Vault entry deleted successfully'}} + }, } - } + }, ) async def delete_vault_entry(key_name: str, request: Request): """ @@ -253,20 +254,22 @@ async def delete_vault_entry(key_name: str, request: Request): try: payload = await auth_required(request) username = payload.get('sub') - logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}') + logger.info( + f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}' + ) logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}') - + return respond_rest(await VaultService.delete_vault_entry(username, key_name, request_id)) except Exception as e: logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True) - return process_response(ResponseModel( - status_code=500, - response_headers={ - Headers.REQUEST_ID: request_id - }, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED - )) + return process_response( + ResponseModel( + status_code=500, + response_headers={Headers.REQUEST_ID: request_id}, + error_code=ErrorCodes.UNEXPECTED, + error_message=Messages.UNEXPECTED, + ) + ) finally: elapsed = time.time() * 1000 - start_time logger.info(f'{request_id} | Total time: {elapsed:.2f}ms') diff --git a/backend-services/services/api_service.py b/backend-services/services/api_service.py index 1bbf4c2..d7e47c4 100644 --- a/backend-services/services/api_service.py +++ b/backend-services/services/api_service.py @@ -4,23 +4,22 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -import uuid import logging +import uuid +from models.create_api_model import CreateApiModel from models.response_model import ResponseModel from models.update_api_model import UpdateApiModel +from utils.async_db import db_delete_one, db_find_one, db_insert_one, db_update_one +from utils.constants import ErrorCodes, Messages from utils.database_async import api_collection -from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one -from utils.cache_manager_util import cache_manager from utils.doorman_cache_util import doorman_cache -from models.create_api_model import CreateApiModel from utils.paging_util import validate_page_params -from utils.constants import ErrorCodes, Messages, Defaults logger = logging.getLogger('doorman.gateway') -class ApiService: +class ApiService: @staticmethod async def create_api(data: CreateApiModel, request_id): """ @@ -33,16 +32,17 @@ class ApiService: return ResponseModel( status_code=400, error_code='API013', - error_message='Public API cannot have credits enabled' + error_message='Public API cannot have credits enabled', ).dict() except Exception: pass cache_key = f'{data.api_name}/{data.api_version}' existing = doorman_cache.get_cache('api_cache', cache_key) if not existing: - existing = await db_find_one(api_collection, {'api_name': data.api_name, 'api_version': data.api_version}) + existing = await db_find_one( + api_collection, {'api_name': data.api_name, 'api_version': data.api_version} + ) if existing: - try: if existing.get('_id'): existing = {k: v for k, v in existing.items() if k != '_id'} @@ -50,7 +50,9 @@ class ApiService: if not existing.get('api_id'): existing['api_id'] = str(uuid.uuid4()) if not existing.get('api_path'): - existing['api_path'] = f"/{existing.get('api_name')}/{existing.get('api_version')}" + existing['api_path'] = ( + f'/{existing.get("api_name")}/{existing.get("api_version")}' + ) doorman_cache.set_cache('api_cache', cache_key, existing) doorman_cache.set_cache('api_id_cache', existing['api_path'], existing['api_id']) except Exception: @@ -58,11 +60,9 @@ class ApiService: logger.info(request_id + ' | API already exists; returning success') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='API already exists' - ).dict() + response_headers={'request_id': request_id}, + message='API already exists', + ).dict() data.api_path = f'/{data.api_name}/{data.api_version}' data.api_id = str(uuid.uuid4()) api_dict = data.dict() @@ -70,10 +70,8 @@ class ApiService: if not insert_result.acknowledged: logger.error(request_id + ' | API creation failed with code API002') return ResponseModel( - status_code=400, - error_code='API002', - error_message='Unable to insert endpoint' - ).dict() + status_code=400, error_code='API002', error_message='Unable to insert endpoint' + ).dict() api_dict['_id'] = str(insert_result.inserted_id) # Cache by both api_id and canonical path for consistent lookups doorman_cache.set_cache('api_cache', data.api_id, api_dict) @@ -89,12 +87,10 @@ class ApiService: api_copy = {'api_name': data.api_name, 'api_version': data.api_version} return ResponseModel( status_code=201, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, response={'api': api_copy}, - message='API created successfully' - ).dict() + message='API created successfully', + ).dict() @staticmethod async def update_api(api_name, api_version, data: UpdateApiModel, request_id): @@ -102,36 +98,49 @@ class ApiService: Update an API on the platform. """ logger.info(request_id + ' | Updating API: ' + api_name + ' ' + api_version) - if data.api_name and data.api_name != api_name or data.api_version and data.api_version != api_version or data.api_path and data.api_path != f'/{api_name}/{api_version}': + if ( + data.api_name + and data.api_name != api_name + or data.api_version + and data.api_version != api_version + or data.api_path + and data.api_path != f'/{api_name}/{api_version}' + ): logger.error(request_id + ' | API update failed with code API005') return ResponseModel( status_code=400, error_code='API005', - error_message='API name and version cannot be updated' - ).dict() + error_message='API name and version cannot be updated', + ).dict() api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}') if not api: - api = await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version}) + api = await db_find_one( + api_collection, {'api_name': api_name, 'api_version': api_version} + ) if not api: logger.error(request_id + ' | API update failed with code API003') return ResponseModel( status_code=400, error_code='API003', - error_message='API does not exist for the requested name and version' - ).dict() + error_message='API does not exist for the requested name and version', + ).dict() else: - doorman_cache.delete_cache('api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}')) + doorman_cache.delete_cache( + 'api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}') + ) doorman_cache.delete_cache('api_id_cache', f'/{api_name}/{api_version}') not_null_data = {k: v for k, v in data.dict().items() if v is not None} try: desired_public = bool(not_null_data.get('api_public', api.get('api_public'))) - desired_credits = bool(not_null_data.get('api_credits_enabled', api.get('api_credits_enabled'))) + desired_credits = bool( + not_null_data.get('api_credits_enabled', api.get('api_credits_enabled')) + ) if desired_public and desired_credits: return ResponseModel( status_code=400, error_code='API013', - error_message='Public API cannot have credits enabled' + error_message='Public API cannot have credits enabled', ).dict() except Exception: pass @@ -140,7 +149,7 @@ class ApiService: update_result = await db_update_one( api_collection, {'api_name': api_name, 'api_version': api_version}, - {'$set': not_null_data} + {'$set': not_null_data}, ) if update_result.modified_count > 0: cache_key = f'{api_name}/{api_version}' @@ -149,28 +158,23 @@ class ApiService: if not update_result.acknowledged or update_result.modified_count == 0: logger.error(request_id + ' | API update failed with code API002') return ResponseModel( - status_code=400, - error_code='API002', - error_message='Unable to update api' - ).dict() + status_code=400, error_code='API002', error_message='Unable to update api' + ).dict() except Exception as e: cache_key = f'{api_name}/{api_version}' doorman_cache.delete_cache('api_cache', cache_key) doorman_cache.delete_cache('api_id_cache', f'/{api_name}/{api_version}') - logger.error(request_id + ' | API update failed with exception: ' + str(e), exc_info=True) + logger.error( + request_id + ' | API update failed with exception: ' + str(e), exc_info=True + ) raise logger.info(request_id + ' | API updated successful') - return ResponseModel( - status_code=200, - message='API updated successfully' - ).dict() + return ResponseModel(status_code=200, message='API updated successfully').dict() else: logger.error(request_id + ' | API update failed with code API006') return ResponseModel( - status_code=400, - error_code='API006', - error_message='No data to update' - ).dict() + status_code=400, error_code='API006', error_message='No data to update' + ).dict() @staticmethod async def delete_api(api_name, api_version, request_id): @@ -180,32 +184,34 @@ class ApiService: logger.info(request_id + ' | Deleting API: ' + api_name + ' ' + api_version) api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}') if not api: - api = await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version}) + api = await db_find_one( + api_collection, {'api_name': api_name, 'api_version': api_version} + ) if not api: logger.error(request_id + ' | API deletion failed with code API003') return ResponseModel( status_code=400, error_code='API003', - error_message='API does not exist for the requested name and version' - ).dict() - delete_result = await db_delete_one(api_collection, {'api_name': api_name, 'api_version': api_version}) + error_message='API does not exist for the requested name and version', + ).dict() + delete_result = await db_delete_one( + api_collection, {'api_name': api_name, 'api_version': api_version} + ) if not delete_result.acknowledged: logger.error(request_id + ' | API deletion failed with code API002') return ResponseModel( - status_code=400, - error_code='API002', - error_message='Unable to delete endpoint' - ).dict() - doorman_cache.delete_cache('api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}')) + status_code=400, error_code='API002', error_message='Unable to delete endpoint' + ).dict() + doorman_cache.delete_cache( + 'api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}') + ) doorman_cache.delete_cache('api_id_cache', f'/{api_name}/{api_version}') logger.info(request_id + ' | API deletion successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='API deleted successfully' - ).dict() + response_headers={'request_id': request_id}, + message='API deleted successfully', + ).dict() @staticmethod async def get_api_by_name_version(api_name, api_version, request_id): @@ -215,52 +221,52 @@ class ApiService: logger.info(request_id + ' | Getting API: ' + api_name + ' ' + api_version) api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}') if not api: - api = api_collection.find_one({'api_name': api_name, 'api_version': api_version}) + api = await db_find_one( + api_collection, {'api_name': api_name, 'api_version': api_version} + ) if not api: logger.error(request_id + ' | API retrieval failed with code API003') return ResponseModel( status_code=400, error_code='API003', - error_message='API does not exist for the requested name and version' - ).dict() - if api.get('_id'): del api['_id'] + error_message='API does not exist for the requested name and version', + ).dict() + if api.get('_id'): + del api['_id'] doorman_cache.set_cache('api_cache', f'{api_name}/{api_version}', api) if '_id' in api: del api['_id'] logger.info(request_id + ' | API retrieval successful') - return ResponseModel( - status_code=200, - response=api - ).dict() + return ResponseModel(status_code=200, response=api).dict() @staticmethod async def get_apis(page, page_size, request_id): """ Get all APIs that a user has access to with pagination. """ - logger.info(request_id + ' | Getting APIs: Page=' + str(page) + ' Page Size=' + str(page_size)) + logger.info( + request_id + ' | Getting APIs: Page=' + str(page) + ' Page Size=' + str(page_size) + ) try: page, page_size = validate_page_params(page, page_size) except Exception as e: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING + ), ).dict() skip = (page - 1) * page_size cursor = api_collection.find().sort('api_name', 1).skip(skip).limit(page_size) apis = cursor.to_list(length=None) for api in apis: - if api.get('_id'): del api['_id'] + if api.get('_id'): + del api['_id'] logger.info(request_id + ' | APIs retrieval successful') - meta = { - 'total': len(apis), - 'page': page, - 'page_size': page_size, - } + meta = {'total': len(apis), 'page': page, 'page_size': page_size} # Add a message to keep payload sizes above compression overhead for tests message = 'API list retrieved successfully. This message also helps ensure the response has a reasonable size for compression benchmarks.' return ResponseModel( - status_code=200, - response={'apis': apis, 'meta': meta, 'message': message} - ).dict() + status_code=200, response={'apis': apis, 'meta': meta, 'message': message} + ).dict() diff --git a/backend-services/services/credit_service.py b/backend-services/services/credit_service.py index 6e8f84a..463bda0 100644 --- a/backend-services/services/credit_service.py +++ b/backend-services/services/credit_service.py @@ -4,39 +4,37 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from pymongo.errors import PyMongoError import logging -from typing import Optional import secrets -from models.response_model import ResponseModel +from pymongo.errors import PyMongoError + from models.credit_model import CreditModel +from models.response_model import ResponseModel from models.user_credits_model import UserCreditModel -from utils.database_async import credit_def_collection, user_credit_collection -from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list -from utils.encryption_util import encrypt_value, decrypt_value -from utils.doorman_cache_util import doorman_cache -from utils.paging_util import validate_page_params +from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one from utils.constants import ErrorCodes, Messages +from utils.database_async import credit_def_collection, user_credit_collection +from utils.doorman_cache_util import doorman_cache +from utils.encryption_util import decrypt_value, encrypt_value +from utils.paging_util import validate_page_params logger = logging.getLogger('doorman.gateway') -class CreditService: +class CreditService: @staticmethod - def _validate_credit_data(data: CreditModel) -> Optional[ResponseModel]: + def _validate_credit_data(data: CreditModel) -> ResponseModel | None: """Validate credit definition data before creation or update.""" if not data.api_credit_group: return ResponseModel( - status_code=400, - error_code='CRD009', - error_message='Credit group name is required' + status_code=400, error_code='CRD009', error_message='Credit group name is required' ) if not data.api_key or not data.api_key_header: return ResponseModel( status_code=400, error_code='CRD010', - error_message='API key and header are required' + error_message='API key and header are required', ) return None @@ -46,15 +44,21 @@ class CreditService: logger.info(request_id + ' | Creating credit definition') validation_error = CreditService._validate_credit_data(data) if validation_error: - logger.error(request_id + f' | Credit creation failed with code {validation_error.error_code}') + logger.error( + request_id + f' | Credit creation failed with code {validation_error.error_code}' + ) return validation_error.dict() try: - if doorman_cache.get_cache('credit_def_cache', data.api_credit_group) or await db_find_one(credit_def_collection, {'api_credit_group': data.api_credit_group}): + if doorman_cache.get_cache( + 'credit_def_cache', data.api_credit_group + ) or await db_find_one( + credit_def_collection, {'api_credit_group': data.api_credit_group} + ): logger.error(request_id + ' | Credit creation failed with code CRD001') return ResponseModel( status_code=400, error_code='CRD001', - error_message='Credit group already exists' + error_message='Credit group already exists', ).dict() credit_data = data.dict() if credit_data.get('api_key') is not None: @@ -67,7 +71,7 @@ class CreditService: return ResponseModel( status_code=400, error_code='CRD002', - error_message='Unable to insert credit definition' + error_message='Unable to insert credit definition', ).dict() credit_data['_id'] = str(insert_result.inserted_id) doorman_cache.set_cache('credit_def_cache', data.api_credit_group, credit_data) @@ -75,14 +79,14 @@ class CreditService: return ResponseModel( status_code=201, response_headers={'request_id': request_id}, - message='Credit definition created successfully' + message='Credit definition created successfully', ).dict() except PyMongoError as e: logger.error(request_id + f' | Credit creation failed with database error: {str(e)}') return ResponseModel( status_code=500, error_code='CRD011', - error_message='Database error occurred while creating credit definition' + error_message='Database error occurred while creating credit definition', ).dict() @staticmethod @@ -91,7 +95,9 @@ class CreditService: logger.info(request_id + ' | Updating credit definition') validation_error = CreditService._validate_credit_data(data) if validation_error: - logger.error(request_id + f' | Credit update failed with code {validation_error.error_code}') + logger.error( + request_id + f' | Credit update failed with code {validation_error.error_code}' + ) return validation_error.dict() try: if data.api_credit_group and data.api_credit_group != api_credit_group: @@ -99,17 +105,19 @@ class CreditService: return ResponseModel( status_code=400, error_code='CRD003', - error_message='Credit group name cannot be updated' + error_message='Credit group name cannot be updated', ).dict() doc = doorman_cache.get_cache('credit_def_cache', api_credit_group) if not doc: - doc = await db_find_one(credit_def_collection, {'api_credit_group': api_credit_group}) + doc = await db_find_one( + credit_def_collection, {'api_credit_group': api_credit_group} + ) if not doc: logger.error(request_id + ' | Credit update failed with code CRD004') return ResponseModel( status_code=400, error_code='CRD004', - error_message='Credit definition does not exist for the requested group' + error_message='Credit definition does not exist for the requested group', ).dict() else: doorman_cache.delete_cache('credit_def_cache', api_credit_group) @@ -119,22 +127,34 @@ class CreditService: if 'api_key_new' in not_null: not_null['api_key_new'] = encrypt_value(not_null['api_key_new']) if not_null: - update_result = await db_update_one(credit_def_collection, {'api_credit_group': api_credit_group}, {'$set': not_null}) + update_result = await db_update_one( + credit_def_collection, + {'api_credit_group': api_credit_group}, + {'$set': not_null}, + ) if not update_result.acknowledged or update_result.modified_count == 0: logger.error(request_id + ' | Credit update failed with code CRD005') return ResponseModel( status_code=400, error_code='CRD005', - error_message='Unable to update credit definition' + error_message='Unable to update credit definition', ).dict() logger.info(request_id + ' | Credit update successful') - return ResponseModel(status_code=200, message='Credit definition updated successfully').dict() + return ResponseModel( + status_code=200, message='Credit definition updated successfully' + ).dict() else: logger.error(request_id + ' | Credit update failed with code CRD006') - return ResponseModel(status_code=400, error_code='CRD006', error_message='No data to update').dict() + return ResponseModel( + status_code=400, error_code='CRD006', error_message='No data to update' + ).dict() except PyMongoError as e: logger.error(request_id + f' | Credit update failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD012', error_message='Database error occurred while updating credit definition').dict() + return ResponseModel( + status_code=500, + error_code='CRD012', + error_message='Database error occurred while updating credit definition', + ).dict() @staticmethod async def delete_credit(api_credit_group: str, request_id): @@ -143,21 +163,39 @@ class CreditService: try: doc = doorman_cache.get_cache('credit_def_cache', api_credit_group) if not doc: - doc = await db_find_one(credit_def_collection, {'api_credit_group': api_credit_group}) + doc = await db_find_one( + credit_def_collection, {'api_credit_group': api_credit_group} + ) if not doc: logger.error(request_id + ' | Credit deletion failed with code CRD007') - return ResponseModel(status_code=400, error_code='CRD007', error_message='Credit definition does not exist for the requested group').dict() + return ResponseModel( + status_code=400, + error_code='CRD007', + error_message='Credit definition does not exist for the requested group', + ).dict() else: doorman_cache.delete_cache('credit_def_cache', api_credit_group) - delete_result = await db_delete_one(credit_def_collection, {'api_credit_group': api_credit_group}) + delete_result = await db_delete_one( + credit_def_collection, {'api_credit_group': api_credit_group} + ) if not delete_result.acknowledged or delete_result.deleted_count == 0: logger.error(request_id + ' | Credit deletion failed with code CRD008') - return ResponseModel(status_code=400, error_code='CRD008', error_message='Unable to delete credit definition').dict() + return ResponseModel( + status_code=400, + error_code='CRD008', + error_message='Unable to delete credit definition', + ).dict() logger.info(request_id + ' | Credit deletion successful') - return ResponseModel(status_code=200, message='Credit definition deleted successfully').dict() + return ResponseModel( + status_code=200, message='Credit definition deleted successfully' + ).dict() except PyMongoError as e: logger.error(request_id + f' | Credit deletion failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD013', error_message='Database error occurred while deleting credit definition').dict() + return ResponseModel( + status_code=500, + error_code='CRD013', + error_message='Database error occurred while deleting credit definition', + ).dict() @staticmethod async def list_credit_defs(page: int, page_size: int, request_id): @@ -170,7 +208,11 @@ class CreditService: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE + if 'page_size' in str(e) + else Messages.INVALID_PAGING + ), ).dict() all_defs = await db_find_list(credit_def_collection, {}) all_defs.sort(key=lambda d: d.get('api_credit_group')) @@ -180,25 +222,35 @@ class CreditService: for doc in all_defs[start:end]: if doc.get('_id'): del doc['_id'] - items.append({ - 'api_credit_group': doc.get('api_credit_group'), - 'api_key_header': doc.get('api_key_header'), - 'api_key_present': bool(doc.get('api_key')), - 'credit_tiers': doc.get('credit_tiers', []), - }) + items.append( + { + 'api_credit_group': doc.get('api_credit_group'), + 'api_key_header': doc.get('api_key_header'), + 'api_key_present': bool(doc.get('api_key')), + 'credit_tiers': doc.get('credit_tiers', []), + } + ) return ResponseModel(status_code=200, response={'items': items}).dict() except PyMongoError as e: logger.error(request_id + f' | Credit list failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD020', error_message='Database error occurred while listing credit definitions').dict() + return ResponseModel( + status_code=500, + error_code='CRD020', + error_message='Database error occurred while listing credit definitions', + ).dict() @staticmethod async def get_credit_def(api_credit_group: str, request_id): """Get a single credit definition (masked).""" logger.info(request_id + ' | Getting credit definition') try: - doc = credit_def_collection.find_one({'api_credit_group': api_credit_group}) + doc = await db_find_one(credit_def_collection, {'api_credit_group': api_credit_group}) if not doc: - return ResponseModel(status_code=404, error_code='CRD021', error_message='Credit definition not found').dict() + return ResponseModel( + status_code=404, + error_code='CRD021', + error_message='Credit definition not found', + ).dict() if doc.get('_id'): del doc['_id'] masked = { @@ -210,7 +262,11 @@ class CreditService: return ResponseModel(status_code=200, response=masked).dict() except PyMongoError as e: logger.error(request_id + f' | Credit fetch failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD022', error_message='Database error occurred while retrieving credit definition').dict() + return ResponseModel( + status_code=500, + error_code='CRD022', + error_message='Database error occurred while retrieving credit definition', + ).dict() @staticmethod async def add_credits(username: str, data: UserCreditModel, request_id): @@ -218,7 +274,11 @@ class CreditService: logger.info(request_id + f' | Adding credits for user: {username}') try: if data.username and data.username != username: - return ResponseModel(status_code=400, error_code='CRD014', error_message='Username in body does not match path').dict() + return ResponseModel( + status_code=400, + error_code='CRD014', + error_message='Username in body does not match path', + ).dict() doc = await db_find_one(user_credit_collection, {'username': username}) users_credits = data.users_credits or {} secured = {} @@ -229,13 +289,21 @@ class CreditService: secured[group] = info payload = {'username': username, 'users_credits': secured} if doc: - await db_update_one(user_credit_collection, {'username': username}, {'$set': {'users_credits': secured}}) + await db_update_one( + user_credit_collection, + {'username': username}, + {'$set': {'users_credits': secured}}, + ) else: await db_insert_one(user_credit_collection, payload) return ResponseModel(status_code=200, message='Credits saved successfully').dict() except PyMongoError as e: logger.error(request_id + f' | Add credits failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD015', error_message='Database error occurred while saving user credits').dict() + return ResponseModel( + status_code=500, + error_code='CRD015', + error_message='Database error occurred while saving user credits', + ).dict() @staticmethod async def get_all_credits(page: int, page_size: int, request_id, search: str = ''): @@ -247,7 +315,11 @@ class CreditService: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE + if 'page_size' in str(e) + else Messages.INVALID_PAGING + ), ).dict() cursor = user_credit_collection.find().sort('username', 1) @@ -271,7 +343,7 @@ class CreditService: if it.get('_id'): del it['_id'] uc = it.get('users_credits') or {} - for g, info in uc.items(): + for _g, info in uc.items(): if isinstance(info, dict) and 'user_api_key' in info: dec = decrypt_value(info.get('user_api_key')) if dec is not None: @@ -279,7 +351,11 @@ class CreditService: return ResponseModel(status_code=200, response={'user_credits': items}).dict() except PyMongoError as e: logger.error(request_id + f' | Get all credits failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD016', error_message='Database error occurred while retrieving credits').dict() + return ResponseModel( + status_code=500, + error_code='CRD016', + error_message='Database error occurred while retrieving credits', + ).dict() @staticmethod async def get_user_credits(username: str, request_id): @@ -287,11 +363,13 @@ class CreditService: try: doc = await db_find_one(user_credit_collection, {'username': username}) if not doc: - return ResponseModel(status_code=404, error_code='CRD017', error_message='User credits not found').dict() + return ResponseModel( + status_code=404, error_code='CRD017', error_message='User credits not found' + ).dict() if doc.get('_id'): del doc['_id'] uc = doc.get('users_credits') or {} - for g, info in uc.items(): + for _g, info in uc.items(): if isinstance(info, dict) and 'user_api_key' in info: dec = decrypt_value(info.get('user_api_key')) if dec is not None: @@ -299,7 +377,11 @@ class CreditService: return ResponseModel(status_code=200, response=doc).dict() except PyMongoError as e: logger.error(request_id + f' | Get user credits failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD018', error_message='Database error occurred while retrieving user credits').dict() + return ResponseModel( + status_code=500, + error_code='CRD018', + error_message='Database error occurred while retrieving user credits', + ).dict() @staticmethod async def rotate_api_key(username: str, group: str, request_id): @@ -312,31 +394,45 @@ class CreditService: # But maybe we can create the credit entry if it's missing. # Let's error if not found for now, or create empty. doc = {'username': username, 'users_credits': {}} - + users_credits = doc.get('users_credits') or {} group_credits = users_credits.get(group) or {} - + # Generate new key new_key = secrets.token_urlsafe(32) encrypted_key = encrypt_value(new_key) - + # Update group credits # Preserve other fields in group_credits (like available_credits) if isinstance(group_credits, dict): group_credits['user_api_key'] = encrypted_key else: # Should be a dict, but if it was somehow not, reset it - group_credits = {'user_api_key': encrypted_key, 'available_credits': 0, 'tier_name': 'default'} + group_credits = { + 'user_api_key': encrypted_key, + 'available_credits': 0, + 'tier_name': 'default', + } users_credits[group] = group_credits - + if doc.get('_id'): - await db_update_one(user_credit_collection, {'username': username}, {'$set': {'users_credits': users_credits}}) + await db_update_one( + user_credit_collection, + {'username': username}, + {'$set': {'users_credits': users_credits}}, + ) else: - await db_insert_one(user_credit_collection, {'username': username, 'users_credits': users_credits}) - + await db_insert_one( + user_credit_collection, {'username': username, 'users_credits': users_credits} + ) + return ResponseModel(status_code=200, response={'api_key': new_key}).dict() - + except PyMongoError as e: logger.error(request_id + f' | Rotate key failed with database error: {str(e)}') - return ResponseModel(status_code=500, error_code='CRD019', error_message='Database error occurred while rotating API key').dict() + return ResponseModel( + status_code=500, + error_code='CRD019', + error_message='Database error occurred while rotating API key', + ).dict() diff --git a/backend-services/services/endpoint_service.py b/backend-services/services/endpoint_service.py index 2b2672e..0f0b746 100644 --- a/backend-services/services/endpoint_service.py +++ b/backend-services/services/endpoint_service.py @@ -4,61 +4,75 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -import uuid import logging -import os import string as _string +import uuid from pathlib import Path +from models.create_endpoint_model import CreateEndpointModel from models.create_endpoint_validation_model import CreateEndpointValidationModel from models.response_model import ResponseModel from models.update_endpoint_model import UpdateEndpointModel from models.update_endpoint_validation_model import UpdateEndpointValidationModel -from utils.database import endpoint_collection, api_collection, endpoint_validation_collection -from utils.cache_manager_util import cache_manager +from utils.database import api_collection, endpoint_collection, endpoint_validation_collection from utils.doorman_cache_util import doorman_cache -from models.create_endpoint_model import CreateEndpointModel logger = logging.getLogger('doorman.gateway') -class EndpointService: +class EndpointService: @staticmethod async def create_endpoint(data: CreateEndpointModel, request_id): """ Create an endpoint for an API. """ - logger.info(request_id + ' | Creating endpoint: ' + data.api_name + ' ' + data.api_version + ' ' + data.endpoint_uri) - cache_key = f'/{data.endpoint_method}/{data.api_name}/{data.api_version}/{data.endpoint_uri}'.replace('//', '/') - if doorman_cache.get_cache('endpoint_cache', cache_key) or endpoint_collection.find_one({ - 'endpoint_method': data.endpoint_method, - 'api_name': data.api_name, - 'api_version': data.api_version, - 'endpoint_uri': data.endpoint_uri - }): + logger.info( + request_id + + ' | Creating endpoint: ' + + data.api_name + + ' ' + + data.api_version + + ' ' + + data.endpoint_uri + ) + cache_key = f'/{data.endpoint_method}/{data.api_name}/{data.api_version}/{data.endpoint_uri}'.replace( + '//', '/' + ) + if doorman_cache.get_cache('endpoint_cache', cache_key) or endpoint_collection.find_one( + { + 'endpoint_method': data.endpoint_method, + 'api_name': data.api_name, + 'api_version': data.api_version, + 'endpoint_uri': data.endpoint_uri, + } + ): logger.error(request_id + ' | Endpoint creation failed with code END001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='END001', - error_message='Endpoint already exists for the requested API name, version and URI' + error_message='Endpoint already exists for the requested API name, version and URI', ).dict() # Resolve API ID from cache using canonical key "/{name}/{version}" - data.api_id = doorman_cache.get_cache('api_id_cache', f'/{data.api_name}/{data.api_version}') + data.api_id = doorman_cache.get_cache( + 'api_id_cache', f'/{data.api_name}/{data.api_version}' + ) if not data.api_id: - api = api_collection.find_one({'api_name': data.api_name, 'api_version': data.api_version}) + api = api_collection.find_one( + {'api_name': data.api_name, 'api_version': data.api_version} + ) if not api: logger.error(request_id + ' | Endpoint creation failed with code END002') return ResponseModel( status_code=400, error_code='END002', - error_message='API does not exist for the requested name and version' + error_message='API does not exist for the requested name and version', ).dict() data.api_id = api.get('api_id') # Ensure cache uses the same canonical key with leading slash - doorman_cache.set_cache('api_id_cache', f'/{data.api_name}/{data.api_version}', data.api_id) + doorman_cache.set_cache( + 'api_id_cache', f'/{data.api_name}/{data.api_version}', data.api_id + ) data.endpoint_id = str(uuid.uuid4()) endpoint_dict = data.dict() insert_result = endpoint_collection.insert_one(endpoint_dict) @@ -66,21 +80,25 @@ class EndpointService: logger.error(request_id + ' | Endpoint creation failed with code END003') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='END003', - error_message='Unable to insert endpoint' + error_message='Unable to insert endpoint', ).dict() endpoint_dict['_id'] = str(insert_result.inserted_id) doorman_cache.set_cache('endpoint_cache', cache_key, endpoint_dict) api_endpoints = doorman_cache.get_cache('api_endpoint_cache', data.api_id) or list() - api_endpoints.append(endpoint_dict.get('endpoint_method') + endpoint_dict.get('endpoint_uri')) + api_endpoints.append( + endpoint_dict.get('endpoint_method') + endpoint_dict.get('endpoint_uri') + ) doorman_cache.set_cache('api_endpoint_cache', data.api_id, api_endpoints) logger.info(request_id + ' | Endpoint creation successful') try: - if data.endpoint_method.upper() == 'POST' and str(data.endpoint_uri).strip().lower() == '/grpc': + if ( + data.endpoint_method.upper() == 'POST' + and str(data.endpoint_uri).strip().lower() == '/grpc' + ): from grpc_tools import protoc as _protoc + api_name = data.api_name api_version = data.api_version module_base = f'{api_name}_{api_version}'.replace('-', '_') @@ -118,11 +136,19 @@ class EndpointService: 'message DeleteReply { bool ok = 1; }\n' ) proto_path.write_text(proto_content, encoding='utf-8') - code = _protoc.main([ - 'protoc', f'--proto_path={str(proto_dir)}', f'--python_out={str(generated_dir)}', f'--grpc_python_out={str(generated_dir)}', str(proto_path) - ]) + code = _protoc.main( + [ + 'protoc', + f'--proto_path={str(proto_dir)}', + f'--python_out={str(generated_dir)}', + f'--grpc_python_out={str(generated_dir)}', + str(proto_path), + ] + ) if code != 0: - logger.warning(f'{request_id} | Pre-gen gRPC stubs returned {code} for {module_base}') + logger.warning( + f'{request_id} | Pre-gen gRPC stubs returned {code} for {module_base}' + ) try: init_path = generated_dir / '__init__.py' if not init_path.exists(): @@ -133,76 +159,81 @@ class EndpointService: logger.debug(f'{request_id} | Skipping pre-gen gRPC stubs: {_e}') return ResponseModel( status_code=201, - response_headers={ - 'request_id': request_id - }, - message='Endpoint created successfully' + response_headers={'request_id': request_id}, + message='Endpoint created successfully', ).dict() @staticmethod - async def update_endpoint(endpoint_method, api_name, api_version, endpoint_uri, data: UpdateEndpointModel, request_id): - logger.info(request_id + ' | Updating endpoint: ' + api_name + ' ' + api_version + ' ' + endpoint_uri) + async def update_endpoint( + endpoint_method, api_name, api_version, endpoint_uri, data: UpdateEndpointModel, request_id + ): + logger.info( + request_id + + ' | Updating endpoint: ' + + api_name + + ' ' + + api_version + + ' ' + + endpoint_uri + ) cache_key = f'/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}'.replace('//', '/') endpoint = doorman_cache.get_cache('endpoint_cache', cache_key) if not endpoint: - endpoint = endpoint_collection.find_one({ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_uri': endpoint_uri, - 'endpoint_method': endpoint_method - }) + endpoint = endpoint_collection.find_one( + { + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_uri': endpoint_uri, + 'endpoint_method': endpoint_method, + } + ) logger.error(request_id + ' | Endpoint update failed with code END008') if not endpoint: return ResponseModel( status_code=400, error_code='END008', - error_message='Endpoint does not exist for the requested API name, version and URI' + error_message='Endpoint does not exist for the requested API name, version and URI', ).dict() else: doorman_cache.delete_cache('endpoint_cache', cache_key) - if (data.endpoint_method and data.endpoint_method != endpoint.get('endpoint_method')) or (data.api_name and data.api_name != endpoint.get('api_name')) or (data.api_version and data.api_version != endpoint.get('api_version')) or (data.endpoint_uri and data.endpoint_uri != endpoint.get('endpoint_uri')): + if ( + (data.endpoint_method and data.endpoint_method != endpoint.get('endpoint_method')) + or (data.api_name and data.api_name != endpoint.get('api_name')) + or (data.api_version and data.api_version != endpoint.get('api_version')) + or (data.endpoint_uri and data.endpoint_uri != endpoint.get('endpoint_uri')) + ): logger.error(request_id + ' | Endpoint update failed with code END006') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='END006', - error_message='API method, name, version and URI cannot be updated' + error_message='API method, name, version and URI cannot be updated', ).dict() not_null_data = {k: v for k, v in data.dict().items() if v is not None} if not_null_data: - update_result = endpoint_collection.update_one({ + update_result = endpoint_collection.update_one( + { 'api_name': api_name, 'api_version': api_version, 'endpoint_uri': endpoint_uri, - 'endpoint_method': endpoint_method + 'endpoint_method': endpoint_method, }, - { - '$set': not_null_data - } + {'$set': not_null_data}, ) if not update_result.acknowledged or update_result.modified_count == 0: logger.error(request_id + ' | Endpoint update failed with code END003') return ResponseModel( - status_code=400, - error_code='END003', - error_message='Unable to update endpoint' + status_code=400, error_code='END003', error_message='Unable to update endpoint' ).dict() logger.info(request_id + ' | Endpoint update successful') - return ResponseModel( - status_code=200, - message='Endpoint updated successfully' - ).dict() + return ResponseModel(status_code=200, message='Endpoint updated successfully').dict() else: logger.error(request_id + ' | Endpoint update failed with code END007') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='END007', - error_message='No data to update' + error_message='No data to update', ).dict() @staticmethod @@ -210,38 +241,42 @@ class EndpointService: """ Delete an endpoint for an API. """ - logger.info(request_id + ' | Deleting: ' + api_name + ' ' + api_version + ' ' + endpoint_uri) + logger.info( + request_id + ' | Deleting: ' + api_name + ' ' + api_version + ' ' + endpoint_uri + ) cache_key = f'/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}'.replace('//', '/') endpoint = doorman_cache.get_cache('endpoint_cache', cache_key) if not endpoint: - endpoint = endpoint_collection.find_one({ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_uri': endpoint_uri, - 'endpoint_method': endpoint_method - }) + endpoint = endpoint_collection.find_one( + { + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_uri': endpoint_uri, + 'endpoint_method': endpoint_method, + } + ) if not endpoint: logger.error(request_id + ' | Endpoint deletion failed with code END004') return ResponseModel( status_code=400, error_code='END004', - error_message='Endpoint does not exist for the requested API name, version and URI' + error_message='Endpoint does not exist for the requested API name, version and URI', ).dict() - delete_result = endpoint_collection.delete_one({ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_uri': endpoint_uri, - 'endpoint_method': endpoint_method - }) + delete_result = endpoint_collection.delete_one( + { + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_uri': endpoint_uri, + 'endpoint_method': endpoint_method, + } + ) if not delete_result.acknowledged: logger.error(request_id + ' | Endpoint deletion failed with code END009') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='END009', - error_message='Unable to delete endpoint' + error_message='Unable to delete endpoint', ).dict() doorman_cache.delete_cache('endpoint_cache', cache_key) try: @@ -253,10 +288,8 @@ class EndpointService: logger.info(request_id + ' | Endpoint deletion successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='Endpoint deleted successfully' + response_headers={'request_id': request_id}, + message='Endpoint deleted successfully', ).dict() @staticmethod @@ -265,30 +298,34 @@ class EndpointService: Get an endpoint by API name, version and URI. """ logger.info(request_id + ' | Getting: ' + api_name + ' ' + api_version + ' ' + endpoint_uri) - endpoint = doorman_cache.get_cache('endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}') + endpoint = doorman_cache.get_cache( + 'endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}' + ) if not endpoint: - endpoint = endpoint_collection.find_one({ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_uri': endpoint_uri, - 'endpoint_method': endpoint_method - }) + endpoint = endpoint_collection.find_one( + { + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_uri': endpoint_uri, + 'endpoint_method': endpoint_method, + } + ) if not endpoint: logger.error(request_id + ' | Endpoint retrieval failed with code END004') return ResponseModel( status_code=400, error_code='END004', - error_message='Endpoint does not exist for the requested API name, version and URI' + error_message='Endpoint does not exist for the requested API name, version and URI', ).dict() - if endpoint.get('_id'): del endpoint['_id'] - doorman_cache.set_cache('endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}', endpoint) + if endpoint.get('_id'): + del endpoint['_id'] + doorman_cache.set_cache( + 'endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}', endpoint + ) if '_id' in endpoint: del endpoint['_id'] logger.info(request_id + ' | Endpoint retrieval successful') - return ResponseModel( - status_code=200, - response=endpoint - ).dict() + return ResponseModel(status_code=200, response=endpoint).dict() @staticmethod async def get_endpoints_by_name_version(api_name, api_version, request_id): @@ -296,32 +333,25 @@ class EndpointService: Get all endpoints by API name and version. """ logger.info(request_id + ' | Getting: ' + api_name + ' ' + api_version) - cursor = endpoint_collection.find({ - 'api_name': api_name, - 'api_version': api_version - }) + cursor = endpoint_collection.find({'api_name': api_name, 'api_version': api_version}) try: endpoints = list(cursor) except Exception: endpoints = await cursor.to_list(length=None) for endpoint in endpoints: - if '_id' in endpoint: del endpoint['_id'] + if '_id' in endpoint: + del endpoint['_id'] if not endpoints: logger.error(request_id + ' | Endpoint retrieval failed with code END005') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='END005', - error_message='No endpoints found for the requested API name and version' + error_message='No endpoints found for the requested API name and version', ).dict() logger.info(request_id + ' | Endpoint retrieval successful') - return ResponseModel( - status_code=200, - response={'endpoints': endpoints} - ).dict() + return ResponseModel(status_code=200, response={'endpoints': endpoints}).dict() @staticmethod async def create_endpoint_validation(data: CreateEndpointValidationModel, request_id): @@ -332,32 +362,24 @@ class EndpointService: if not data.endpoint_id: logger.error(request_id + ' | Endpoint ID is required') return ResponseModel( - status_code=400, - error_code='END013', - error_message='Endpoint ID is required' + status_code=400, error_code='END013', error_message='Endpoint ID is required' ).dict() if not data.validation_schema: logger.error(request_id + ' | Validation schema is required') return ResponseModel( - status_code=400, - error_code='END014', - error_message='Validation schema is required' + status_code=400, error_code='END014', error_message='Validation schema is required' ).dict() if doorman_cache.get_cache('endpoint_validation_cache', data.endpoint_id): logger.error(request_id + ' | Endpoint validation already exists') return ResponseModel( status_code=400, error_code='END017', - error_message='Endpoint validation already exists' + error_message='Endpoint validation already exists', ).dict() - if not endpoint_collection.find_one({ - 'endpoint_id': data.endpoint_id - }): + if not endpoint_collection.find_one({'endpoint_id': data.endpoint_id}): logger.error(request_id + ' | Endpoint does not exist') return ResponseModel( - status_code=400, - error_code='END015', - error_message='Endpoint does not exist' + status_code=400, error_code='END015', error_message='Endpoint does not exist' ).dict() validation_dict = data.dict() insert_result = endpoint_validation_collection.insert_one(validation_dict) @@ -366,13 +388,12 @@ class EndpointService: return ResponseModel( status_code=400, error_code='END016', - error_message='Unable to create endpoint validation' + error_message='Unable to create endpoint validation', ).dict() logger.info(request_id + ' | Endpoint validation created successfully') doorman_cache.set_cache('endpoint_validation_cache', f'{data.endpoint_id}', validation_dict) return ResponseModel( - status_code=201, - message='Endpoint validation created successfully' + status_code=201, message='Endpoint validation created successfully' ).dict() @staticmethod @@ -383,21 +404,18 @@ class EndpointService: logger.info(request_id + ' | Getting endpoint validation: ' + endpoint_id) validation = doorman_cache.get_cache('endpoint_validation_cache', endpoint_id) if not validation: - validation = endpoint_validation_collection.find_one({ - 'endpoint_id': endpoint_id - }) + validation = endpoint_validation_collection.find_one({'endpoint_id': endpoint_id}) if not validation: - logger.error(request_id + ' | Endpoint validation retrieval failed with code END018') + logger.error( + request_id + ' | Endpoint validation retrieval failed with code END018' + ) return ResponseModel( status_code=400, error_code='END018', - error_message='Endpoint validation does not exist' + error_message='Endpoint validation does not exist', ).dict() logger.info(request_id + ' | Endpoint validation retrieval successful') - return ResponseModel( - status_code=200, - response=validation - ).dict() + return ResponseModel(status_code=200, response=validation).dict() @staticmethod async def delete_endpoint_validation(endpoint_id, request_id): @@ -405,24 +423,23 @@ class EndpointService: Delete an endpoint validation by endpoint ID. """ logger.info(request_id + ' | Deleting endpoint validation: ' + endpoint_id) - delete_result = endpoint_validation_collection.delete_one({ - 'endpoint_id': endpoint_id - }) + delete_result = endpoint_validation_collection.delete_one({'endpoint_id': endpoint_id}) if not delete_result.acknowledged: logger.error(request_id + ' | Endpoint validation deletion failed with code END019') return ResponseModel( status_code=400, error_code='END019', - error_message='Unable to delete endpoint validation' + error_message='Unable to delete endpoint validation', ).dict() logger.info(request_id + ' | Endpoint validation deletion successful') return ResponseModel( - status_code=200, - message='Endpoint validation deleted successfully' + status_code=200, message='Endpoint validation deleted successfully' ).dict() @staticmethod - async def update_endpoint_validation(endpoint_id, data: UpdateEndpointValidationModel, request_id): + async def update_endpoint_validation( + endpoint_id, data: UpdateEndpointValidationModel, request_id + ): """ Update an endpoint validation by endpoint ID. """ @@ -430,39 +447,32 @@ class EndpointService: if not data.validation_enabled: logger.error(request_id + ' | Validation enabled is required') return ResponseModel( - status_code=400, - error_code='END020', - error_message='Validation enabled is required' + status_code=400, error_code='END020', error_message='Validation enabled is required' ).dict() if not data.validation_schema: logger.error(request_id + ' | Validation schema is required') return ResponseModel( - status_code=400, - error_code='END021', - error_message='Validation schema is required' + status_code=400, error_code='END021', error_message='Validation schema is required' ).dict() - if not endpoint_collection.find_one({ - 'endpoint_id': endpoint_id - }): + if not endpoint_collection.find_one({'endpoint_id': endpoint_id}): logger.error(request_id + ' | Endpoint does not exist') return ResponseModel( - status_code=400, - error_code='END022', - error_message='Endpoint does not exist' + status_code=400, error_code='END022', error_message='Endpoint does not exist' ).dict() - update_result = endpoint_validation_collection.update_one({ - 'endpoint_id': endpoint_id - }, { - '$set': { - 'validation_enabled': data.validation_enabled, - 'validation_schema': data.validation_schema - } - }) + update_result = endpoint_validation_collection.update_one( + {'endpoint_id': endpoint_id}, + { + '$set': { + 'validation_enabled': data.validation_enabled, + 'validation_schema': data.validation_schema, + } + }, + ) if not update_result.acknowledged: logger.error(request_id + ' | Endpoint validation update failed with code END023') return ResponseModel( status_code=400, error_code='END023', - error_message='Unable to update endpoint validation' + error_message='Unable to update endpoint validation', ).dict() logger.info(request_id + ' | Endpoint validation updated successfully') diff --git a/backend-services/services/gateway_service.py b/backend-services/services/gateway_service.py index 3fdad6f..b4687d6 100644 --- a/backend-services/services/gateway_service.py +++ b/backend-services/services/gateway_service.py @@ -4,56 +4,59 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ +import asyncio +import importlib +import json +import logging import os import random -import json -import sys -import xml.etree.ElementTree as ET -import logging import re -import time -import httpx -from typing import Dict -import grpc -import asyncio -from google.protobuf.json_format import MessageToDict -import importlib import string +import sys +import time +import xml.etree.ElementTree as ET from pathlib import Path +import grpc +import httpx +from google.protobuf.json_format import MessageToDict + try: from gql import Client as _GqlClient + def gql(q): return q except Exception: + class _GqlClient: def __init__(self, *args, **kwargs): pass + def gql(q): return q + Client = _GqlClient from models.response_model import ResponseModel -from utils import api_util, routing_util -from utils import credit_util -from utils.gateway_utils import get_headers +from utils import api_util, credit_util, routing_util from utils.doorman_cache_util import doorman_cache +from utils.gateway_utils import get_headers +from utils.http_client import CircuitOpenError, request_with_resilience from utils.validation_util import validation_util -from utils.http_client import request_with_resilience, CircuitOpenError logging.getLogger('gql').setLevel(logging.WARNING) logger = logging.getLogger('doorman.gateway') -class GatewayService: +class GatewayService: timeout = httpx.Timeout( - connect=float(os.getenv('HTTP_CONNECT_TIMEOUT', 5.0)), - read=float(os.getenv('HTTP_READ_TIMEOUT', 30.0)), - write=float(os.getenv('HTTP_WRITE_TIMEOUT', 30.0)), - pool=float(os.getenv('HTTP_TIMEOUT', 30.0)) - ) + connect=float(os.getenv('HTTP_CONNECT_TIMEOUT', 5.0)), + read=float(os.getenv('HTTP_READ_TIMEOUT', 30.0)), + write=float(os.getenv('HTTP_WRITE_TIMEOUT', 30.0)), + pool=float(os.getenv('HTTP_TIMEOUT', 30.0)), + ) _http_client: httpx.AsyncClient | None = None @staticmethod @@ -77,7 +80,9 @@ class GatewayService: expiry = float(os.getenv('HTTP_KEEPALIVE_EXPIRY', 30.0)) except Exception: expiry = 30.0 - return httpx.Limits(max_connections=max_conns, max_keepalive_connections=max_keep, keepalive_expiry=expiry) + return httpx.Limits( + max_connections=max_conns, max_keepalive_connections=max_keep, keepalive_expiry=expiry + ) @classmethod def get_http_client(cls) -> httpx.AsyncClient: @@ -91,13 +96,13 @@ class GatewayService: cls._http_client = httpx.AsyncClient( timeout=cls.timeout, limits=cls._build_limits(), - http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true') + http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true'), ) return cls._http_client return httpx.AsyncClient( timeout=cls.timeout, limits=cls._build_limits(), - http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true') + http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true'), ) @classmethod @@ -111,16 +116,18 @@ class GatewayService: cls._http_client = None def error_response(request_id, code, message, status=404): - logger.error(f'{request_id} | REST gateway failed with code {code}') - return ResponseModel( - status_code=status, - response_headers={'request_id': request_id}, - error_code=code, - error_message=message - ).dict() + logger.error(f'{request_id} | REST gateway failed with code {code}') + return ResponseModel( + status_code=status, + response_headers={'request_id': request_id}, + error_code=code, + error_message=message, + ).dict() @staticmethod - def _compute_api_cors_headers(api: dict, origin: str | None, req_method: str | None, req_headers: str | None): + def _compute_api_cors_headers( + api: dict, origin: str | None, req_method: str | None, req_headers: str | None + ): try: origin = (origin or '').strip() req_method = (req_method or '').strip().upper() @@ -128,13 +135,21 @@ class GatewayService: _ao = api.get('api_cors_allow_origins', None) # None => default '*', empty list => disallow all - allow_origins = (['*'] if _ao is None else list(_ao)) + allow_origins = ['*'] if _ao is None else list(_ao) _am = api.get('api_cors_allow_methods', None) - allow_methods = [m.strip().upper() for m in (_am if _am is not None else ['GET','POST','PUT','DELETE','PATCH','HEAD','OPTIONS']) if m] + allow_methods = [ + m.strip().upper() + for m in ( + _am + if _am is not None + else ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'] + ) + if m + ] if 'OPTIONS' not in allow_methods: allow_methods.append('OPTIONS') _ah = api.get('api_cors_allow_headers', None) - allow_headers = (_ah if _ah is not None else ['*']) + allow_headers = _ah if _ah is not None else ['*'] allow_credentials = bool(api.get('api_cors_allow_credentials')) expose_headers = api.get('api_cors_expose_headers') or [] @@ -158,7 +173,10 @@ class GatewayService: o_scheme, o_host = 'https', origin if o_scheme != scheme: continue - if o_host.endswith(host_suffix) and o_host.count('.') >= host_suffix.count('.') + 1: + if ( + o_host.endswith(host_suffix) + and o_host.count('.') >= host_suffix.count('.') + 1 + ): origin_allowed = True break except Exception: @@ -205,7 +223,10 @@ class GatewayService: ctype = ctype_raw.split(';', 1)[0].strip().lower() body = getattr(response, 'content', b'') - if ctype in ('application/json', 'application/graphql+json') or 'application/graphql' in ctype: + if ( + ctype in ('application/json', 'application/graphql+json') + or 'application/graphql' in ctype + ): return json.loads(body) if ctype in ('application/xml', 'text/xml'): @@ -278,10 +299,7 @@ class GatewayService: if not await credit_util.deduct_credit(api.get('api_credit_group'), username): logger.warning(f'{request_id} | Credit deduction failed for user {username}') return GatewayService.error_response( - request_id, - 'GTW008', - 'User does not have any credits', - status=401 + request_id, 'GTW008', 'User does not have any credits', status=401 ) return None @@ -315,8 +333,7 @@ class GatewayService: if username and not bool(api.get('api_public')): user_specific_api_key = await credit_util.get_user_api_key( - api.get('api_credit_group'), - username + api.get('api_credit_group'), username ) if user_specific_api_key: headers[header_name] = user_specific_api_key @@ -349,8 +366,7 @@ class GatewayService: continue sanitized_key = ''.join( - c if c.isalnum() or c in ('-', '_', '.') else '-' - for c in key + c if c.isalnum() or c in ('-', '_', '.') else '-' for c in key ) if not sanitized_key: @@ -369,7 +385,7 @@ class GatewayService: return metadata_list - _IDENT_ALLOWED = set(string.ascii_letters + string.digits + "_") + _IDENT_ALLOWED = set(string.ascii_letters + string.digits + '_') _PROJECT_ROOT = Path(__file__).resolve().parent.parent @staticmethod @@ -389,7 +405,7 @@ class GatewayService: name = name.strip() if not name or len(name) > max_len: return False - if name[0] not in string.ascii_letters + "_": + if name[0] not in string.ascii_letters + '_': return False for ch in name: if ch not in GatewayService._IDENT_ALLOWED: @@ -410,9 +426,9 @@ class GatewayService: if not pkg: return None pkg = str(pkg).strip() - if "/" in pkg or "\\" in pkg or ".." in pkg: + if '/' in pkg or '\\' in pkg or '..' in pkg: return None - parts = pkg.split(".") if "." in pkg else [pkg] + parts = pkg.split('.') if '.' in pkg else [pkg] if any(not GatewayService._is_valid_identifier(p) for p in parts if p is not None): return None return pkg @@ -424,19 +440,24 @@ class GatewayService: return None try: method_fq = str(method_fq).strip() - if "." not in method_fq: + if '.' not in method_fq: return None - service, method = method_fq.split(".", 1) + service, method = method_fq.split('.', 1) service = service.strip() method = method.strip() - if not (GatewayService._is_valid_identifier(service) and GatewayService._is_valid_identifier(method)): + if not ( + GatewayService._is_valid_identifier(service) + and GatewayService._is_valid_identifier(method) + ): return None return service, method except Exception: return None @staticmethod - async def rest_gateway(username, request, request_id, start_time, path, url=None, method=None, retry=0): + async def rest_gateway( + username, request, request_id, start_time, path, url=None, method=None, retry=0 + ): """ External gateway. """ @@ -444,7 +465,6 @@ class GatewayService: current_time = backend_end_time = None try: if not url and not method: - parts = [p for p in (path or '').split('/') if p] api_name_version = '' endpoint_uri = '' @@ -454,22 +474,38 @@ class GatewayService: api_key = doorman_cache.get_cache('api_id_cache', api_name_version) api = await api_util.get_api(api_key, api_name_version) if not api: - return GatewayService.error_response(request_id, 'GTW001', 'API does not exist for the requested name and version') + return GatewayService.error_response( + request_id, + 'GTW001', + 'API does not exist for the requested name and version', + ) if api.get('active') is False: - return GatewayService.error_response(request_id, 'GTW012', 'API is disabled', status=403) + return GatewayService.error_response( + request_id, 'GTW012', 'API is disabled', status=403 + ) endpoints = await api_util.get_api_endpoints(api.get('api_id')) if not endpoints: - return GatewayService.error_response(request_id, 'GTW002', 'No endpoints found for the requested API') + return GatewayService.error_response( + request_id, 'GTW002', 'No endpoints found for the requested API' + ) regex_pattern = re.compile(r'\{[^/]+\}') match_method = 'GET' if str(request.method).upper() == 'HEAD' else request.method composite = match_method + '/' + endpoint_uri - if not any(re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints): + if not any( + re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints + ): logger.error(f'{endpoints} | REST gateway failed with code GTW003') - return GatewayService.error_response(request_id, 'GTW003', 'Endpoint does not exist for the requested API') + return GatewayService.error_response( + request_id, 'GTW003', 'Endpoint does not exist for the requested API' + ) client_key = request.headers.get('client-key') - server = await routing_util.pick_upstream_server(api, request.method, endpoint_uri, client_key) + server = await routing_util.pick_upstream_server( + api, request.method, endpoint_uri, client_key + ) if not server: - return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured') + return GatewayService.error_response( + request_id, 'GTW001', 'No upstream servers configured' + ) logger.info(f'{request_id} | REST gateway to: {server}') url = server.rstrip('/') + '/' + endpoint_uri.lstrip('/') method = request.method.upper() @@ -477,7 +513,9 @@ class GatewayService: if api.get('api_credits_enabled') and username and not bool(api.get('api_public')): if not await credit_util.deduct_credit(api.get('api_credit_group'), username): - return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401) + return GatewayService.error_response( + request_id, 'GTW008', 'User does not have any credits', status=401 + ) else: try: parts = [p for p in (path or '').split('/') if p] @@ -501,12 +539,16 @@ class GatewayService: headers['X-User-Email'] = str(username) headers['X-Doorman-User'] = str(username) if api and api.get('api_credits_enabled'): - ai_token_headers = await credit_util.get_credit_api_header(api.get('api_credit_group')) + ai_token_headers = await credit_util.get_credit_api_header( + api.get('api_credit_group') + ) if ai_token_headers: headers[ai_token_headers[0]] = ai_token_headers[1] if username and not bool(api.get('api_public')): - user_specific_api_key = await credit_util.get_user_api_key(api.get('api_credit_group'), username) + user_specific_api_key = await credit_util.get_user_api_key( + api.get('api_credit_group'), username + ) if user_specific_api_key: headers[ai_token_headers[0]] = user_specific_api_key content_type = request.headers.get('Content-Type', '').upper() @@ -516,11 +558,17 @@ class GatewayService: swap_from = api.get('api_authorization_field_swap') source_val = None if swap_from: - for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()): + for key_variant in ( + swap_from, + str(swap_from).lower(), + str(swap_from).title(), + ): if key_variant in headers: source_val = headers.get(key_variant) break - orig_auth = request.headers.get('Authorization') or request.headers.get('authorization') + orig_auth = request.headers.get('Authorization') or request.headers.get( + 'authorization' + ) if source_val is not None and str(source_val).strip() != '': headers['Authorization'] = source_val elif orig_auth is not None and str(orig_auth).strip() != '': @@ -530,7 +578,11 @@ class GatewayService: try: lookup_method = 'GET' if str(method).upper() == 'HEAD' else method - endpoint_doc = await api_util.get_endpoint(api, lookup_method, '/' + endpoint_uri.lstrip('/')) if api else None + endpoint_doc = ( + await api_util.get_endpoint(api, lookup_method, '/' + endpoint_uri.lstrip('/')) + if api + else None + ) endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None if endpoint_id: if 'JSON' in content_type: @@ -546,24 +598,36 @@ class GatewayService: try: if method == 'GET': http_response = await request_with_resilience( - client, 'GET', url, + client, + 'GET', + url, api_key=api.get('api_path') if api else (api_name_version or '/api/rest'), - headers=headers, params=query_params, + headers=headers, + params=query_params, retries=retry, api_config=api, ) elif method == 'HEAD': http_response = await request_with_resilience( - client, 'HEAD', url, + client, + 'HEAD', + url, api_key=api.get('api_path') if api else (api_name_version or '/api/rest'), - headers=headers, params=query_params, + headers=headers, + params=query_params, retries=retry, api_config=api, ) elif method in ('POST', 'PUT', 'DELETE', 'PATCH'): - cl_header = request.headers.get('content-length') or request.headers.get('Content-Length') + cl_header = request.headers.get('content-length') or request.headers.get( + 'Content-Length' + ) try: - content_length = int(cl_header) if cl_header is not None and str(cl_header).strip() != '' else 0 + content_length = ( + int(cl_header) + if cl_header is not None and str(cl_header).strip() != '' + else 0 + ) except Exception: content_length = 0 @@ -571,31 +635,50 @@ class GatewayService: if 'JSON' in content_type: body = await request.json() http_response = await request_with_resilience( - client, method, url, - api_key=api.get('api_path') if api else (api_name_version or '/api/rest'), - headers=headers, params=query_params, json=body, + client, + method, + url, + api_key=api.get('api_path') + if api + else (api_name_version or '/api/rest'), + headers=headers, + params=query_params, + json=body, retries=retry, api_config=api, ) else: body = await request.body() http_response = await request_with_resilience( - client, method, url, - api_key=api.get('api_path') if api else (api_name_version or '/api/rest'), - headers=headers, params=query_params, content=body, + client, + method, + url, + api_key=api.get('api_path') + if api + else (api_name_version or '/api/rest'), + headers=headers, + params=query_params, + content=body, retries=retry, api_config=api, ) else: http_response = await request_with_resilience( - client, method, url, - api_key=api.get('api_path') if api else (api_name_version or '/api/rest'), - headers=headers, params=query_params, + client, + method, + url, + api_key=api.get('api_path') + if api + else (api_name_version or '/api/rest'), + headers=headers, + params=query_params, retries=retry, api_config=api, ) else: - return GatewayService.error_response(request_id, 'GTW004', 'Method not supported', status=405) + return GatewayService.error_response( + request_id, 'GTW004', 'Method not supported', status=405 + ) finally: if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() == 'false': try: @@ -615,13 +698,15 @@ class GatewayService: status_code=500, response_headers={'request_id': request_id}, error_code='GTW006', - error_message='Malformed JSON from upstream' + error_message='Malformed JSON from upstream', ).dict() else: response_content = http_response.text backend_end_time = time.time() * 1000 if http_response.status_code == 404: - return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service') + return GatewayService.error_response( + request_id, 'GTW005', 'Endpoint does not exist in backend service' + ) logger.info(f'{request_id} | REST gateway status code: {http_response.status_code}') response_headers = {'request_id': request_id} allowed_lower = {h.lower() for h in (allowed_headers or [])} @@ -645,29 +730,33 @@ class GatewayService: return ResponseModel( status_code=http_response.status_code, response_headers=response_headers, - response=response_content + response=response_content, ).dict() except CircuitOpenError: return ResponseModel( status_code=503, response_headers={'request_id': request_id}, error_code='GTW999', - error_message='Upstream circuit open' + error_message='Upstream circuit open', ).dict() except httpx.TimeoutException: try: - metrics_store.record_upstream_timeout('rest:' + (api.get('api_path') if api else (api_name_version or '/api/rest'))) + metrics_store.record_upstream_timeout( + 'rest:' + (api.get('api_path') if api else (api_name_version or '/api/rest')) + ) except Exception: pass return ResponseModel( status_code=504, response_headers={'request_id': request_id}, error_code='GTW010', - error_message='Gateway timeout' + error_message='Gateway timeout', ).dict() except Exception: logger.error(f'{request_id} | REST gateway failed with code GTW006', exc_info=True) - return GatewayService.error_response(request_id, 'GTW006', 'Internal server error', status=500) + return GatewayService.error_response( + request_id, 'GTW006', 'Internal server error', status=500 + ) finally: if current_time: logger.info(f'{request_id} | Gateway time {current_time - start_time}ms') @@ -683,7 +772,6 @@ class GatewayService: current_time = backend_end_time = None try: if not url: - parts = [p for p in (path or '').split('/') if p] api_name_version = '' endpoint_uri = '' @@ -693,27 +781,45 @@ class GatewayService: api_key = doorman_cache.get_cache('api_id_cache', api_name_version) api = await api_util.get_api(api_key, api_name_version) if not api: - return GatewayService.error_response(request_id, 'GTW001', 'API does not exist for the requested name and version') + return GatewayService.error_response( + request_id, + 'GTW001', + 'API does not exist for the requested name and version', + ) if api.get('active') is False: - return GatewayService.error_response(request_id, 'GTW012', 'API is disabled', status=403) + return GatewayService.error_response( + request_id, 'GTW012', 'API is disabled', status=403 + ) endpoints = await api_util.get_api_endpoints(api.get('api_id')) logger.info(f'{request_id} | SOAP gateway endpoints: {endpoints}') if not endpoints: - return GatewayService.error_response(request_id, 'GTW002', 'No endpoints found for the requested API') + return GatewayService.error_response( + request_id, 'GTW002', 'No endpoints found for the requested API' + ) regex_pattern = re.compile(r'\{[^/]+\}') composite = 'POST/' + endpoint_uri - if not any(re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints): - return GatewayService.error_response(request_id, 'GTW003', 'Endpoint does not exist for the requested API') + if not any( + re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints + ): + return GatewayService.error_response( + request_id, 'GTW003', 'Endpoint does not exist for the requested API' + ) client_key = request.headers.get('client-key') - server = await routing_util.pick_upstream_server(api, 'POST', endpoint_uri, client_key) + server = await routing_util.pick_upstream_server( + api, 'POST', endpoint_uri, client_key + ) if not server: - return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured') + return GatewayService.error_response( + request_id, 'GTW001', 'No upstream servers configured' + ) url = server.rstrip('/') + '/' + endpoint_uri.lstrip('/') logger.info(f'{request_id} | SOAP gateway to: {url}') retry = api.get('api_allowed_retry_count') or 0 if api.get('api_credits_enabled') and username and not bool(api.get('api_public')): if not await credit_util.deduct_credit(api.get('api_credit_group'), username): - return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401) + return GatewayService.error_response( + request_id, 'GTW008', 'User does not have any credits', status=401 + ) else: try: parts = [p for p in (path or '').split('/') if p] @@ -748,11 +854,17 @@ class GatewayService: swap_from = api.get('api_authorization_field_swap') source_val = None if swap_from: - for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()): + for key_variant in ( + swap_from, + str(swap_from).lower(), + str(swap_from).title(), + ): if key_variant in headers: source_val = headers.get(key_variant) break - orig_auth = request.headers.get('Authorization') or request.headers.get('authorization') + orig_auth = request.headers.get('Authorization') or request.headers.get( + 'authorization' + ) if source_val is not None and str(source_val).strip() != '': headers['Authorization'] = source_val elif orig_auth is not None and str(orig_auth).strip() != '': @@ -761,7 +873,11 @@ class GatewayService: pass try: - endpoint_doc = await api_util.get_endpoint(api, 'POST', '/' + endpoint_uri.lstrip('/')) if api else None + endpoint_doc = ( + await api_util.get_endpoint(api, 'POST', '/' + endpoint_uri.lstrip('/')) + if api + else None + ) endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None if endpoint_id: await validation_util.validate_soap_request(endpoint_id, envelope) @@ -771,9 +887,13 @@ class GatewayService: client = GatewayService.get_http_client() try: http_response = await request_with_resilience( - client, 'POST', url, + client, + 'POST', + url, api_key=api.get('api_path') if api else (api_name_version or '/api/soap'), - headers=headers, params=query_params, content=envelope, + headers=headers, + params=query_params, + content=envelope, retries=retry, api_config=api, ) @@ -787,7 +907,9 @@ class GatewayService: logger.info(f'{request_id} | SOAP gateway response: {response_content}') backend_end_time = time.time() * 1000 if http_response.status_code == 404: - return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service') + return GatewayService.error_response( + request_id, 'GTW005', 'Endpoint does not exist in backend service' + ) logger.info(f'{request_id} | SOAP gateway status code: {http_response.status_code}') response_headers = {'request_id': request_id} allowed_lower = {h.lower() for h in (allowed_headers or [])} @@ -811,29 +933,33 @@ class GatewayService: return ResponseModel( status_code=http_response.status_code, response_headers=response_headers, - response=response_content + response=response_content, ).dict() except CircuitOpenError: return ResponseModel( status_code=503, response_headers={'request_id': request_id}, error_code='GTW999', - error_message='Upstream circuit open' + error_message='Upstream circuit open', ).dict() except httpx.TimeoutException: try: - metrics_store.record_upstream_timeout('soap:' + (api.get('api_path') if api else '/api/soap')) + metrics_store.record_upstream_timeout( + 'soap:' + (api.get('api_path') if api else '/api/soap') + ) except Exception: pass return ResponseModel( status_code=504, response_headers={'request_id': request_id}, error_code='GTW010', - error_message='Gateway timeout' + error_message='Gateway timeout', ).dict() except Exception: logger.error(f'{request_id} | SOAP gateway failed with code GTW006') - return GatewayService.error_response(request_id, 'GTW006', 'Internal server error', status=500) + return GatewayService.error_response( + request_id, 'GTW006', 'Internal server error', status=500 + ) finally: if current_time: logger.info(f'{request_id} | Gateway time {current_time - start_time}ms') @@ -853,15 +979,21 @@ class GatewayService: if not api: api = await api_util.get_api(None, api_path) if not api: - logger.error(f'{request_id} | API not found: {api_path}') - return GatewayService.error_response(request_id, 'GTW001', f'API does not exist: {api_path}') + logger.error(f'{request_id} | API not found: {api_path}') + return GatewayService.error_response( + request_id, 'GTW001', f'API does not exist: {api_path}' + ) if api.get('active') is False: - return GatewayService.error_response(request_id, 'GTW012', 'API is disabled', status=403) + return GatewayService.error_response( + request_id, 'GTW012', 'API is disabled', status=403 + ) doorman_cache.set_cache('api_cache', api_path, api) retry = api.get('api_allowed_retry_count') or 0 if api.get('api_credits_enabled') and username and not bool(api.get('api_public')): if not await credit_util.deduct_credit(api.get('api_credit_group'), username): - return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401) + return GatewayService.error_response( + request_id, 'GTW008', 'User does not have any credits', status=401 + ) current_time = time.time() * 1000 allowed_headers = api.get('api_allowed_headers') or [] headers = await get_headers(request, allowed_headers) @@ -869,11 +1001,15 @@ class GatewayService: headers['Content-Type'] = 'application/json' headers['Accept'] = 'application/json' if api.get('api_credits_enabled'): - ai_token_headers = await credit_util.get_credit_api_header(api.get('api_credit_group')) + ai_token_headers = await credit_util.get_credit_api_header( + api.get('api_credit_group') + ) if ai_token_headers: headers[ai_token_headers[0]] = ai_token_headers[1] if username and not bool(api.get('api_public')): - user_specific_api_key = await credit_util.get_user_api_key(api.get('api_credit_group'), username) + user_specific_api_key = await credit_util.get_user_api_key( + api.get('api_credit_group'), username + ) if user_specific_api_key: headers[ai_token_headers[0]] = user_specific_api_key if api.get('api_authorization_field_swap'): @@ -881,11 +1017,17 @@ class GatewayService: swap_from = api.get('api_authorization_field_swap') source_val = None if swap_from: - for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()): + for key_variant in ( + swap_from, + str(swap_from).lower(), + str(swap_from).title(), + ): if key_variant in headers: source_val = headers.get(key_variant) break - orig_auth = request.headers.get('Authorization') or request.headers.get('authorization') + orig_auth = request.headers.get('Authorization') or request.headers.get( + 'authorization' + ) if source_val is not None and str(source_val).strip() != '': headers['Authorization'] = source_val elif orig_auth is not None and str(orig_auth).strip() != '': @@ -910,19 +1052,27 @@ class GatewayService: async with Client(transport=None, fetch_schema_from_transport=False) as session: result = await session.execute(gql(query), variable_values=variables) except Exception as _e: - logger.debug(f'{request_id} | GraphQL Client execution failed; falling back to HTTP: {_e}') + logger.debug( + f'{request_id} | GraphQL Client execution failed; falling back to HTTP: {_e}' + ) if result is None: client_key = request.headers.get('client-key') - server = await routing_util.pick_upstream_server(api, 'POST', '/graphql', client_key) + server = await routing_util.pick_upstream_server( + api, 'POST', '/graphql', client_key + ) if not server: logger.error(f'{request_id} | No upstream servers configured for {api_path}') - return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured') + return GatewayService.error_response( + request_id, 'GTW001', 'No upstream servers configured' + ) url = server.rstrip('/') + '/graphql' client = GatewayService.get_http_client() try: http_resp = await request_with_resilience( - client, 'POST', url, + client, + 'POST', + url, api_key=api_path, headers=headers, json={'query': query, 'variables': variables}, @@ -931,24 +1081,29 @@ class GatewayService: ) except AttributeError: http_resp = await client.post( - url, - json={'query': query, 'variables': variables}, - headers=headers, + url, json={'query': query, 'variables': variables}, headers=headers ) try: data = http_resp.json() except Exception as je: data = { - 'errors': [{ - 'message': f'Invalid JSON from upstream: {str(je)}', - 'extensions': {'code': 'BAD_RESPONSE'} - }]} + 'errors': [ + { + 'message': f'Invalid JSON from upstream: {str(je)}', + 'extensions': {'code': 'BAD_RESPONSE'}, + } + ] + } status = getattr(http_resp, 'status_code', 200) if status != 200 and 'errors' not in data: - data = {'errors': [{ - 'message': data.get('message') or f'HTTP {status}', - 'extensions': {'code': f'HTTP_{status}'} - }]} + data = { + 'errors': [ + { + 'message': data.get('message') or f'HTTP {status}', + 'extensions': {'code': f'HTTP_{status}'}, + } + ] + } result = data backend_end_time = time.time() * 1000 @@ -972,24 +1127,28 @@ class GatewayService: response_headers['X-Backend-Time'] = str(int(backend_end_time - current_time)) except Exception: pass - return ResponseModel(status_code=200, response_headers=response_headers, response=result).dict() + return ResponseModel( + status_code=200, response_headers=response_headers, response=result + ).dict() except CircuitOpenError: return ResponseModel( status_code=503, response_headers={'request_id': request_id}, error_code='GTW999', - error_message='Upstream circuit open' + error_message='Upstream circuit open', ).dict() except httpx.TimeoutException: try: - metrics_store.record_upstream_timeout('graphql:' + (api.get('api_path') if api else '/api/graphql')) + metrics_store.record_upstream_timeout( + 'graphql:' + (api.get('api_path') if api else '/api/graphql') + ) except Exception: pass return ResponseModel( status_code=504, response_headers={'request_id': request_id}, error_code='GTW010', - error_message='Gateway timeout' + error_message='Gateway timeout', ).dict() except Exception as e: logger.error(f'{request_id} | GraphQL gateway failed with code GTW006: {str(e)}') @@ -1002,7 +1161,9 @@ class GatewayService: logger.info(f'{request_id} | Backend time {backend_end_time - current_time}ms') @staticmethod - async def grpc_gateway(username, request, request_id, start_time, path, api_name=None, url=None, retry=0): + async def grpc_gateway( + username, request, request_id, start_time, path, api_name=None, url=None, retry=0 + ): logger.info(f'{request_id} | gRPC gateway processing request') current_time = backend_end_time = None try: @@ -1011,7 +1172,9 @@ class GatewayService: path_parts = path.strip('/').split('/') if len(path_parts) < 1: logger.error(f'{request_id} | Invalid API path format: {path}') - return GatewayService.error_response(request_id, 'GTW001', 'Invalid API path format', status=404) + return GatewayService.error_response( + request_id, 'GTW001', 'Invalid API path format', status=404 + ) api_name = path_parts[-1] api_version = request.headers.get('X-API-Version', 'v1') api_path = f'{api_name}/{api_version}' @@ -1021,19 +1184,33 @@ class GatewayService: body = await request.json() if not isinstance(body, dict): logger.error(f'{request_id} | Invalid request body format') - return GatewayService.error_response(request_id, 'GTW011', 'Invalid request body format', status=400) + return GatewayService.error_response( + request_id, 'GTW011', 'Invalid request body format', status=400 + ) except json.JSONDecodeError: logger.error(f'{request_id} | Invalid JSON in request body') - return GatewayService.error_response(request_id, 'GTW011', 'Invalid JSON in request body', status=400) + return GatewayService.error_response( + request_id, 'GTW011', 'Invalid JSON in request body', status=400 + ) parsed = GatewayService._parse_and_validate_method(body.get('method')) if not parsed: - return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400) + return GatewayService.error_response( + request_id, + 'GTW011', + 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', + status=400, + ) _service_name_preview, _method_name_preview = parsed pkg_override_raw = (body.get('package') or '').strip() if pkg_override_raw: if GatewayService._validate_package_name(pkg_override_raw) is None: - return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC package. Use letters, digits, underscore only.', status=400) + return GatewayService.error_response( + request_id, + 'GTW011', + 'Invalid gRPC package. Use letters, digits, underscore only.', + status=400, + ) api = doorman_cache.get_cache('api_cache', api_path) if not api: @@ -1043,25 +1220,35 @@ class GatewayService: endpoint_doc = await api_util.get_endpoint(api, 'POST', '/grpc') endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None if endpoint_id: - await validation_util.validate_grpc_request(endpoint_id, body.get('message')) + await validation_util.validate_grpc_request( + endpoint_id, body.get('message') + ) except Exception as e: - return GatewayService.error_response(request_id, 'GTW011', str(e), status=400) + return GatewayService.error_response( + request_id, 'GTW011', str(e), status=400 + ) api_pkg_raw = None try: api_pkg_raw = (api.get('api_grpc_package') or '').strip() if api else None except Exception: api_pkg_raw = None pkg_override = (body.get('package') or '').strip() or None - api_pkg = GatewayService._validate_package_name(api_pkg_raw) if api_pkg_raw else None - pkg_override_valid = GatewayService._validate_package_name(pkg_override) if pkg_override else None + api_pkg = ( + GatewayService._validate_package_name(api_pkg_raw) if api_pkg_raw else None + ) + pkg_override_valid = ( + GatewayService._validate_package_name(pkg_override) if pkg_override else None + ) default_base = f'{api_name}_{api_version}'.replace('-', '_') if not GatewayService._is_valid_identifier(default_base): - default_base = ''.join(ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base) - module_base = (api_pkg or pkg_override_valid or default_base) + default_base = ''.join( + ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base + ) + module_base = api_pkg or pkg_override_valid or default_base try: logger.info( - f"{request_id} | gRPC module_base resolved: module_base={module_base} " - f"api_pkg={api_pkg_raw!r} pkg_override={pkg_override_raw!r} default_base={default_base}" + f'{request_id} | gRPC module_base resolved: module_base={module_base} ' + f'api_pkg={api_pkg_raw!r} pkg_override={pkg_override_raw!r} default_base={default_base}' ) except Exception: pass @@ -1084,7 +1271,7 @@ class GatewayService: request_id, 'GTW013', 'gRPC service not allowed', status=403 ) if allowed_methods and isinstance(allowed_methods, list): - method_fq = f"{service_name}.{method_name}" + method_fq = f'{service_name}.{method_name}' if method_fq not in allowed_methods: return GatewayService.error_response( request_id, 'GTW013', 'gRPC method not allowed', status=403 @@ -1100,9 +1287,16 @@ class GatewayService: proto_path = (proto_dir / proto_rel.with_suffix('.proto')).resolve() # Validate resolved path stays within project bounds if not GatewayService._validate_under_base(project_root, proto_path): - return GatewayService.error_response(request_id, 'GTW012', 'Invalid path for proto resolution', status=400) + return GatewayService.error_response( + request_id, 'GTW012', 'Invalid path for proto resolution', status=400 + ) if not GatewayService._validate_under_base(proto_dir, proto_path): - return GatewayService.error_response(request_id, 'GTW012', 'Proto path must be within proto directory', status=400) + return GatewayService.error_response( + request_id, + 'GTW012', + 'Proto path must be within proto directory', + status=400, + ) generated_dir = project_root / 'generated' gen_dir_str = str(generated_dir) @@ -1112,7 +1306,9 @@ class GatewayService: if gen_dir_str not in sys.path: sys.path.insert(0, gen_dir_str) try: - logger.info(f"{request_id} | sys.path updated for gRPC import. project_root={proj_root_str}, generated_dir={gen_dir_str}") + logger.info( + f'{request_id} | sys.path updated for gRPC import. project_root={proj_root_str}, generated_dir={gen_dir_str}' + ) except Exception: pass @@ -1129,11 +1325,17 @@ class GatewayService: gen_pb2_grpc_name = f'generated.{module_base}_pb2_grpc' pb2 = importlib.import_module(gen_pb2_name) pb2_grpc = importlib.import_module(gen_pb2_grpc_name) - logger.info(f"{request_id} | Successfully imported gRPC modules: {pb2.__name__} and {pb2_grpc.__name__}") + logger.info( + f'{request_id} | Successfully imported gRPC modules: {pb2.__name__} and {pb2_grpc.__name__}' + ) except ModuleNotFoundError as mnf_exc: - logger.warning(f"{request_id} | gRPC modules not found, will attempt proto generation: {str(mnf_exc)}") + logger.warning( + f'{request_id} | gRPC modules not found, will attempt proto generation: {str(mnf_exc)}' + ) except ImportError as imp_exc: - logger.error(f"{request_id} | ImportError loading gRPC modules (likely broken import in generated file): {str(imp_exc)}") + logger.error( + f'{request_id} | ImportError loading gRPC modules (likely broken import in generated file): {str(imp_exc)}' + ) mod_pb2 = f'{module_base}_pb2' mod_pb2_grpc = f'{module_base}_pb2_grpc' if mod_pb2 in sys.modules: @@ -1144,15 +1346,17 @@ class GatewayService: request_id, 'GTW012', f'Failed to import gRPC modules. Proto files may need regeneration. Error: {str(imp_exc)[:100]}', - status=404 + status=404, ) except Exception as import_exc: - logger.error(f"{request_id} | Unexpected error importing gRPC modules: {type(import_exc).__name__}: {str(import_exc)}") + logger.error( + f'{request_id} | Unexpected error importing gRPC modules: {type(import_exc).__name__}: {str(import_exc)}' + ) return GatewayService.error_response( request_id, 'GTW012', f'Unexpected error importing gRPC modules: {type(import_exc).__name__}', - status=500 + status=500, ) if pb2 is None or pb2_grpc is None: @@ -1161,7 +1365,9 @@ class GatewayService: try: pb2_check_path = (generated_dir / (module_base + '_pb2.py')).resolve() if GatewayService._validate_under_base(generated_dir, pb2_check_path): - logger.info(f"{request_id} | gRPC generated check: proto_path={proto_path} exists={proto_path.exists()} generated_dir={generated_dir} pb2={module_base}_pb2.py={pb2_check_path.exists()}") + logger.info( + f'{request_id} | gRPC generated check: proto_path={proto_path} exists={proto_path.exists()} generated_dir={generated_dir} pb2={module_base}_pb2.py={pb2_check_path.exists()}' + ) except Exception: pass method_fq = body.get('method', '') @@ -1196,44 +1402,71 @@ class GatewayService: generated_dir.mkdir(exist_ok=True) try: from grpc_tools import protoc as _protoc - code = _protoc.main([ - 'protoc', f'--proto_path={str(proto_dir)}', f'--python_out={str(generated_dir)}', f'--grpc_python_out={str(generated_dir)}', str(proto_path) - ]) + + code = _protoc.main( + [ + 'protoc', + f'--proto_path={str(proto_dir)}', + f'--python_out={str(generated_dir)}', + f'--grpc_python_out={str(generated_dir)}', + str(proto_path), + ] + ) if code != 0: raise RuntimeError(f'protoc returned {code}') init_path = generated_dir / '__init__.py' if not init_path.exists(): - init_path.write_text('"""Generated gRPC code."""\n', encoding='utf-8') + init_path.write_text( + '"""Generated gRPC code."""\n', encoding='utf-8' + ) except Exception as ge: logger.error(f'{request_id} | On-demand proto generation failed: {ge}') if os.getenv('DOORMAN_TEST_MODE', '').lower() == 'true': pb2 = type('PB2', (), {}) pb2_grpc = type('SVC', (), {}) else: - return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404) + return GatewayService.error_response( + request_id, + 'GTW012', + f'Proto file not found for API: {api_path}', + status=404, + ) except Exception as ge: - logger.error(f'{request_id} | Proto file not found and generation skipped: {ge}') + logger.error( + f'{request_id} | Proto file not found and generation skipped: {ge}' + ) if os.getenv('DOORMAN_TEST_MODE', '').lower() != 'true': - return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404) + return GatewayService.error_response( + request_id, + 'GTW012', + f'Proto file not found for API: {api_path}', + status=404, + ) api = doorman_cache.get_cache('api_cache', api_path) if not api: api = await api_util.get_api(None, api_path) if not api: logger.error(f'{request_id} | API not found: {api_path}') - return GatewayService.error_response(request_id, 'GTW001', f'API does not exist: {api_path}', status=404) + return GatewayService.error_response( + request_id, 'GTW001', f'API does not exist: {api_path}', status=404 + ) doorman_cache.set_cache('api_cache', api_path, api) client_key = request.headers.get('client-key') server = await routing_util.pick_upstream_server(api, 'POST', '/grpc', client_key) if not server: logger.error(f'{request_id} | No upstream servers configured for {api_path}') - return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured', status=404) + return GatewayService.error_response( + request_id, 'GTW001', 'No upstream servers configured', status=404 + ) url = server.rstrip('/') if url.startswith('grpc://'): url = url[7:] retry = api.get('api_allowed_retry_count') or 0 if api.get('api_credits_enabled') and username and not bool(api.get('api_public')): if not await credit_util.deduct_credit(api.get('api_credit_group'), username): - return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401) + return GatewayService.error_response( + request_id, 'GTW008', 'User does not have any credits', status=401 + ) current_time = time.time() * 1000 try: if not url: @@ -1245,24 +1478,41 @@ class GatewayService: api_name = path_parts[-1] if path_parts else None if api_name: api_path = f'{api_name}/{api_version}' - api = doorman_cache.get_cache('api_cache', api_path) or await api_util.get_api(None, api_path) + api = doorman_cache.get_cache( + 'api_cache', api_path + ) or await api_util.get_api(None, api_path) try: - api_pkg_raw = (api.get('api_grpc_package') or '').strip() if api else None + api_pkg_raw = ( + (api.get('api_grpc_package') or '').strip() if api else None + ) except Exception: api_pkg_raw = None pkg_override = (body.get('package') or '').strip() or None - api_pkg = GatewayService._validate_package_name(api_pkg_raw) if api_pkg_raw else None - pkg_override_valid = GatewayService._validate_package_name(pkg_override) if pkg_override else None + api_pkg = ( + GatewayService._validate_package_name(api_pkg_raw) + if api_pkg_raw + else None + ) + pkg_override_valid = ( + GatewayService._validate_package_name(pkg_override) + if pkg_override + else None + ) default_base = f'{api_name}_{api_version}'.replace('-', '_') if not GatewayService._is_valid_identifier(default_base): - default_base = ''.join(ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base) - module_base = (api_pkg or pkg_override_valid or default_base) + default_base = ''.join( + ch if ch in GatewayService._IDENT_ALLOWED else '_' + for ch in default_base + ) + module_base = api_pkg or pkg_override_valid or default_base try: allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None allowed_svcs = api.get('api_grpc_allowed_services') if api else None allowed_methods = api.get('api_grpc_allowed_methods') if api else None - prev_parsed = GatewayService._parse_and_validate_method(body.get('method')) + prev_parsed = GatewayService._parse_and_validate_method( + body.get('method') + ) if prev_parsed: svc_name, mth_name = prev_parsed else: @@ -1278,8 +1528,13 @@ class GatewayService: return GatewayService.error_response( request_id, 'GTW013', 'gRPC service not allowed', status=403 ) - if svc_name and mth_name and allowed_methods and isinstance(allowed_methods, list): - if f"{svc_name}.{mth_name}" not in allowed_methods: + if ( + svc_name + and mth_name + and allowed_methods + and isinstance(allowed_methods, list) + ): + if f'{svc_name}.{mth_name}' not in allowed_methods: return GatewayService.error_response( request_id, 'GTW013', 'gRPC method not allowed', status=403 ) @@ -1296,35 +1551,69 @@ class GatewayService: body = await request.json() if not isinstance(body, dict): logger.error(f'{request_id} | Invalid request body format') - return GatewayService.error_response(request_id, 'GTW011', 'Invalid request body format', status=400) + return GatewayService.error_response( + request_id, 'GTW011', 'Invalid request body format', status=400 + ) except json.JSONDecodeError: logger.error(f'{request_id} | Invalid JSON in request body') - return GatewayService.error_response(request_id, 'GTW011', 'Invalid JSON in request body', status=400) + return GatewayService.error_response( + request_id, 'GTW011', 'Invalid JSON in request body', status=400 + ) if 'method' not in body: logger.error(f'{request_id} | Missing method in request body') - return GatewayService.error_response(request_id, 'GTW011', 'Missing method in request body', status=400) + return GatewayService.error_response( + request_id, 'GTW011', 'Missing method in request body', status=400 + ) if 'message' not in body: logger.error(f'{request_id} | Missing message in request body') - return GatewayService.error_response(request_id, 'GTW011', 'Missing message in request body', status=400) + return GatewayService.error_response( + request_id, 'GTW011', 'Missing message in request body', status=400 + ) parsed_method = GatewayService._parse_and_validate_method(body.get('method')) if not parsed_method: - return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400) + return GatewayService.error_response( + request_id, + 'GTW011', + 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', + status=400, + ) pkg_override = (body.get('package') or '').strip() or None if pkg_override and GatewayService._validate_package_name(pkg_override) is None: - return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC package. Use letters, digits, underscore only.', status=400) + return GatewayService.error_response( + request_id, + 'GTW011', + 'Invalid gRPC package. Use letters, digits, underscore only.', + status=400, + ) try: svc_name, mth_name = parsed_method allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None allowed_svcs = api.get('api_grpc_allowed_services') if api else None allowed_methods = api.get('api_grpc_allowed_methods') if api else None - if allowed_pkgs and isinstance(allowed_pkgs, list) and module_base not in allowed_pkgs: - return GatewayService.error_response(request_id, 'GTW013', 'gRPC package not allowed', status=403) + if ( + allowed_pkgs + and isinstance(allowed_pkgs, list) + and module_base not in allowed_pkgs + ): + return GatewayService.error_response( + request_id, 'GTW013', 'gRPC package not allowed', status=403 + ) if allowed_svcs and isinstance(allowed_svcs, list) and svc_name not in allowed_svcs: - return GatewayService.error_response(request_id, 'GTW013', 'gRPC service not allowed', status=403) - if allowed_methods and isinstance(allowed_methods, list) and f"{svc_name}.{mth_name}" not in allowed_methods: - return GatewayService.error_response(request_id, 'GTW013', 'gRPC method not allowed', status=403) + return GatewayService.error_response( + request_id, 'GTW013', 'gRPC service not allowed', status=403 + ) + if ( + allowed_methods + and isinstance(allowed_methods, list) + and f'{svc_name}.{mth_name}' not in allowed_methods + ): + return GatewayService.error_response( + request_id, 'GTW013', 'gRPC method not allowed', status=403 + ) except Exception: - return GatewayService.error_response(request_id, 'GTW013', 'gRPC target not allowed', status=403) + return GatewayService.error_response( + request_id, 'GTW013', 'gRPC target not allowed', status=403 + ) proto_rel = Path(module_base.replace('.', '/')) proto_filename = f'{proto_rel.name}.proto' @@ -1335,11 +1624,11 @@ class GatewayService: await validation_util.validate_grpc_request(endpoint_id, body.get('message')) except Exception as e: return GatewayService.error_response(request_id, 'GTW011', str(e), status=400) - proto_path = (GatewayService._PROJECT_ROOT / 'proto' / proto_rel.with_suffix('.proto')) + proto_path = GatewayService._PROJECT_ROOT / 'proto' / proto_rel.with_suffix('.proto') use_imported = False try: if 'pb2' in locals() and 'pb2_grpc' in locals(): - use_imported = (pb2 is not None and pb2_grpc is not None) + use_imported = pb2 is not None and pb2_grpc is not None except Exception: use_imported = False module_name = module_base @@ -1351,7 +1640,9 @@ class GatewayService: if gen_dir_str not in sys.path: sys.path.insert(0, gen_dir_str) try: - logger.info(f"{request_id} | sys.path prepared for import: project_root={proj_root_str}, generated_dir={gen_dir_str}") + logger.info( + f'{request_id} | sys.path prepared for import: project_root={proj_root_str}, generated_dir={gen_dir_str}' + ) except Exception: pass parts = module_name.split('.') if '.' in module_name else [module_name] @@ -1361,7 +1652,7 @@ class GatewayService: if use_imported: pb2_module = pb2 service_module = pb2_grpc - logger.info(f"{request_id} | Using imported gRPC modules for {module_name}") + logger.info(f'{request_id} | Using imported gRPC modules for {module_name}') else: if not proto_path.exists(): if os.getenv('DOORMAN_TEST_MODE', '').lower() == 'true': @@ -1371,89 +1662,159 @@ class GatewayService: use_imported = True except Exception: logger.error(f'{request_id} | Proto file not found: {str(proto_path)}') - return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404) + return GatewayService.error_response( + request_id, + 'GTW012', + f'Proto file not found for API: {api_path}', + status=404, + ) if not use_imported: - pb2_path = (package_dir / f"{parts[-1]}_pb2.py").resolve() - pb2_grpc_path = (package_dir / f"{parts[-1]}_pb2_grpc.py").resolve() + pb2_path = (package_dir / f'{parts[-1]}_pb2.py').resolve() + pb2_grpc_path = (package_dir / f'{parts[-1]}_pb2_grpc.py').resolve() # Validate paths before checking existence - if not GatewayService._validate_under_base(generated_dir, pb2_path) or not GatewayService._validate_under_base(generated_dir, pb2_grpc_path): - logger.error(f"{request_id} | Invalid path for generated modules: pb2={pb2_path} pb2_grpc={pb2_grpc_path}") - return GatewayService.error_response(request_id, 'GTW012', 'Invalid generated module path', status=400) + if not GatewayService._validate_under_base( + generated_dir, pb2_path + ) or not GatewayService._validate_under_base(generated_dir, pb2_grpc_path): + logger.error( + f'{request_id} | Invalid path for generated modules: pb2={pb2_path} pb2_grpc={pb2_grpc_path}' + ) + return GatewayService.error_response( + request_id, 'GTW012', 'Invalid generated module path', status=400 + ) if not (pb2_path.is_file() and pb2_grpc_path.is_file()): - logger.error(f"{request_id} | Generated modules not found for '{module_name}' pb2={pb2_path} exists={pb2_path.is_file()} pb2_grpc={pb2_grpc_path} exists={pb2_grpc_path.is_file()}") + logger.error( + f"{request_id} | Generated modules not found for '{module_name}' pb2={pb2_path} exists={pb2_path.is_file()} pb2_grpc={pb2_grpc_path} exists={pb2_grpc_path.is_file()}" + ) if isinstance(url, str) and url.startswith(('http://', 'https://')): try: client = GatewayService.get_http_client() http_url = url.rstrip('/') + '/grpc' - http_response = await client.post(http_url, json=body, headers=headers) + http_response = await client.post( + http_url, json=body, headers=headers + ) finally: - if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() != 'true': + if ( + os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() + != 'true' + ): try: await client.aclose() except Exception: pass if http_response.status_code == 404: - return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service') + return GatewayService.error_response( + request_id, + 'GTW005', + 'Endpoint does not exist in backend service', + ) response_headers = {'request_id': request_id} try: if current_time and start_time: - response_headers['X-Gateway-Time'] = str(int(current_time - start_time)) + response_headers['X-Gateway-Time'] = str( + int(current_time - start_time) + ) except Exception: pass return ResponseModel( status_code=http_response.status_code, response_headers=response_headers, - response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text) + response=( + http_response.json() + if http_response.headers.get('Content-Type', '').startswith( + 'application/json' + ) + else http_response.text + ), ).dict() - return GatewayService.error_response(request_id, 'GTW012', f'Generated gRPC modules not found for package: {module_name}', status=404) + return GatewayService.error_response( + request_id, + 'GTW012', + f'Generated gRPC modules not found for package: {module_name}', + status=404, + ) if not use_imported: try: if GatewayService._validate_package_name(module_name) is None: - return GatewayService.error_response(request_id, 'GTW012', 'Invalid gRPC module name', status=400) + return GatewayService.error_response( + request_id, 'GTW012', 'Invalid gRPC module name', status=400 + ) import_name_pb2 = f'{module_name}_pb2' import_name_grpc = f'{module_name}_pb2_grpc' - logger.info(f"{request_id} | Importing generated modules: {import_name_pb2} and {import_name_grpc}") + logger.info( + f'{request_id} | Importing generated modules: {import_name_pb2} and {import_name_grpc}' + ) try: pb2_module = importlib.import_module(import_name_pb2) service_module = importlib.import_module(import_name_grpc) except ModuleNotFoundError: alt_pb2 = f'generated.{module_name}_pb2' alt_grpc = f'generated.{module_name}_pb2_grpc' - logger.info(f"{request_id} | Retrying import via generated package: {alt_pb2} and {alt_grpc}") + logger.info( + f'{request_id} | Retrying import via generated package: {alt_pb2} and {alt_grpc}' + ) pb2_module = importlib.import_module(alt_pb2) service_module = importlib.import_module(alt_grpc) except ImportError as e: - logger.error(f'{request_id} | Failed to import gRPC module: {str(e)}', exc_info=True) + logger.error( + f'{request_id} | Failed to import gRPC module: {str(e)}', exc_info=True + ) if isinstance(url, str) and url.startswith(('http://', 'https://')): try: client = GatewayService.get_http_client() http_url = url.rstrip('/') + '/grpc' - http_response = await client.post(http_url, json=body, headers=headers) + http_response = await client.post( + http_url, json=body, headers=headers + ) finally: - if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() != 'true': + if ( + os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() + != 'true' + ): try: await client.aclose() except Exception: pass if http_response.status_code == 404: - return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service') + return GatewayService.error_response( + request_id, + 'GTW005', + 'Endpoint does not exist in backend service', + ) response_headers = {'request_id': request_id} try: if current_time and start_time: - response_headers['X-Gateway-Time'] = str(int(current_time - start_time)) + response_headers['X-Gateway-Time'] = str( + int(current_time - start_time) + ) except Exception: pass return ResponseModel( status_code=http_response.status_code, response_headers=response_headers, - response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text) + response=( + http_response.json() + if http_response.headers.get('Content-Type', '').startswith( + 'application/json' + ) + else http_response.text + ), ).dict() - return GatewayService.error_response(request_id, 'GTW012', f'Failed to import gRPC module: {str(e)}', status=404) + return GatewayService.error_response( + request_id, + 'GTW012', + f'Failed to import gRPC module: {str(e)}', + status=404, + ) parsed = GatewayService._parse_and_validate_method(body.get('method')) if not parsed: - return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400) + return GatewayService.error_response( + request_id, + 'GTW011', + 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', + status=400, + ) service_name, method_name = parsed - if isinstance(url, str) and url.startswith(("http://", "https://")): + if isinstance(url, str) and url.startswith(('http://', 'https://')): try: client = GatewayService.get_http_client() http_url = url.rstrip('/') + '/grpc' @@ -1465,7 +1826,9 @@ class GatewayService: except Exception: pass if http_response.status_code == 404: - return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service') + return GatewayService.error_response( + request_id, 'GTW005', 'Endpoint does not exist in backend service' + ) response_headers = {'request_id': request_id} try: if current_time and start_time: @@ -1475,10 +1838,16 @@ class GatewayService: return ResponseModel( status_code=http_response.status_code, response_headers=response_headers, - response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text) + response=( + http_response.json() + if http_response.headers.get('Content-Type', '').startswith( + 'application/json' + ) + else http_response.text + ), ).dict() - logger.info(f"{request_id} | Connecting to gRPC upstream: {url}") + logger.info(f'{request_id} | Connecting to gRPC upstream: {url}') channel = grpc.aio.insecure_channel(url) try: await asyncio.wait_for(channel.channel_ready(), timeout=2.0) @@ -1488,48 +1857,60 @@ class GatewayService: reply_class_name = f'{method_name}Reply' try: - logger.info(f"{request_id} | Resolving message types: {request_class_name} and {reply_class_name} from pb2_module={getattr(pb2_module, '__name__', 'unknown')}") + logger.info( + f'{request_id} | Resolving message types: {request_class_name} and {reply_class_name} from pb2_module={getattr(pb2_module, "__name__", "unknown")}' + ) if pb2_module is None: - logger.error(f'{request_id} | pb2_module is None - cannot resolve message types') + logger.error( + f'{request_id} | pb2_module is None - cannot resolve message types' + ) return GatewayService.error_response( request_id, 'GTW012', 'Internal error: protobuf module not loaded', - status=500 + status=500, ) try: request_class = getattr(pb2_module, request_class_name) reply_class = getattr(pb2_module, reply_class_name) except AttributeError as attr_err: - logger.error(f'{request_id} | Message types not found in pb2_module: {str(attr_err)}') + logger.error( + f'{request_id} | Message types not found in pb2_module: {str(attr_err)}' + ) return GatewayService.error_response( request_id, 'GTW006', f'Message types {request_class_name}/{reply_class_name} not found in protobuf module', - status=500 + status=500, ) try: request_message = request_class() - logger.info(f'{request_id} | Successfully created request message of type {request_class_name}') + logger.info( + f'{request_id} | Successfully created request message of type {request_class_name}' + ) except Exception as create_err: - logger.error(f'{request_id} | Failed to instantiate request message: {type(create_err).__name__}: {str(create_err)}') + logger.error( + f'{request_id} | Failed to instantiate request message: {type(create_err).__name__}: {str(create_err)}' + ) return GatewayService.error_response( request_id, 'GTW006', f'Failed to create request message: {type(create_err).__name__}', - status=500 + status=500, ) except Exception as e: - logger.error(f'{request_id} | Unexpected error in message type resolution: {type(e).__name__}: {str(e)}') + logger.error( + f'{request_id} | Unexpected error in message type resolution: {type(e).__name__}: {str(e)}' + ) return GatewayService.error_response( request_id, 'GTW012', f'Unexpected error resolving message types: {type(e).__name__}', - status=500 + status=500, ) for key, value in body['message'].items(): try: @@ -1553,12 +1934,16 @@ class GatewayService: except Exception: base_ms, max_ms = 100, 1000 - stream_mode = str((body.get('stream') or body.get('streaming') or '')).lower() + stream_mode = str(body.get('stream') or body.get('streaming') or '').lower() idempotent_override = body.get('idempotent') if idempotent_override is not None: is_idempotent = bool(idempotent_override) else: - is_idempotent = not (stream_mode.startswith('client') or stream_mode.startswith('bidi') or stream_mode.startswith('bi')) + is_idempotent = not ( + stream_mode.startswith('client') + or stream_mode.startswith('bidi') + or stream_mode.startswith('bi') + ) retryable = { grpc.StatusCode.UNAVAILABLE, @@ -1572,19 +1957,26 @@ class GatewayService: final_code_name = 'OK' got_response = False try: - logger.info(f"{request_id} | gRPC entering attempts={attempts} stream_mode={stream_mode or 'unary'} method={service_name}.{method_name}") + logger.info( + f'{request_id} | gRPC entering attempts={attempts} stream_mode={stream_mode or "unary"} method={service_name}.{method_name}' + ) except Exception: pass for attempt in range(attempts): try: full_method = f'/{module_base}.{service_name}/{method_name}' try: - logger.info(f"{request_id} | gRPC attempt={attempt+1}/{attempts} calling {full_method}") + logger.info( + f'{request_id} | gRPC attempt={attempt + 1}/{attempts} calling {full_method}' + ) except Exception: pass req_ser = getattr(request_message, 'SerializeToString', None) if not callable(req_ser): - req_ser = (lambda _m: b'') + + def req_ser(_m): + return b'' + metadata_list = GatewayService._sanitize_grpc_metadata(headers or {}) if stream_mode.startswith('server'): call = channel.unary_stream( @@ -1604,7 +1996,15 @@ class GatewayService: items.append(d) if len(items) >= max_items: break - response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})() + response = type( + 'R', + (), + { + 'DESCRIPTOR': type('D', (), {'fields': []})(), + 'ok': True, + '_items': items, + }, + )() got_response = True elif stream_mode.startswith('client'): try: @@ -1615,6 +2015,7 @@ class GatewayService: ) except AttributeError: stream = None + async def _gen_client(): msgs = body.get('messages') or [] if not msgs: @@ -1634,6 +2035,7 @@ class GatewayService: yield msg except Exception: yield request_message + if stream is not None: try: response = await stream(_gen_client(), metadata=metadata_list) @@ -1660,6 +2062,7 @@ class GatewayService: ) except AttributeError: bidi = None + async def _gen_bidi(): msgs = body.get('messages') or [] if not msgs: @@ -1679,6 +2082,7 @@ class GatewayService: yield msg except Exception: yield request_message + items = [] max_items = int(body.get('max_items') or 50) if bidi is not None: @@ -1704,7 +2108,15 @@ class GatewayService: items.append(d) if len(items) >= max_items: break - response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})() + response = type( + 'R', + (), + { + 'DESCRIPTOR': type('D', (), {'fields': []})(), + 'ok': True, + '_items': items, + }, + )() got_response = True else: unary = channel.unary_unary( @@ -1719,7 +2131,9 @@ class GatewayService: got_response = True last_exc = None try: - logger.info(f"{request_id} | gRPC unary success; stream_mode={stream_mode or 'unary'}") + logger.info( + f'{request_id} | gRPC unary success; stream_mode={stream_mode or "unary"}' + ) except Exception: pass break @@ -1728,13 +2142,13 @@ class GatewayService: try: code = getattr(e2, 'code', lambda: None)() cname = str(getattr(code, 'name', '') or 'UNKNOWN') - logger.info(f"{request_id} | gRPC primary call raised: {cname}") + logger.info(f'{request_id} | gRPC primary call raised: {cname}') except Exception: - logger.info(f"{request_id} | gRPC primary call raised non-grpc exception") + logger.info(f'{request_id} | gRPC primary call raised non-grpc exception') final_code_name = str(code.name) if getattr(code, 'name', None) else 'ERROR' if attempt < attempts - 1 and is_idempotent and code in retryable: retries_made += 1 - delay = min(max_ms, base_ms * (2 ** attempt)) / 1000.0 + delay = min(max_ms, base_ms * (2**attempt)) / 1000.0 jitter_factor = 1.0 + (random.random() * jitter - (jitter / 2.0)) await asyncio.sleep(max(0.01, delay * jitter_factor)) continue @@ -1742,7 +2156,10 @@ class GatewayService: alt_method = f'/{service_name}/{method_name}' req_ser = getattr(request_message, 'SerializeToString', None) if not callable(req_ser): - req_ser = (lambda _m: b'') + + def req_ser(_m): + return b'' + if stream_mode.startswith('server'): call2 = channel.unary_stream( alt_method, @@ -1761,7 +2178,15 @@ class GatewayService: items.append(d) if len(items) >= max_items: break - response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})() + response = type( + 'R', + (), + { + 'DESCRIPTOR': type('D', (), {'fields': []})(), + 'ok': True, + '_items': items, + }, + )() got_response = True elif stream_mode.startswith('client'): try: @@ -1772,6 +2197,7 @@ class GatewayService: ) except AttributeError: stream2 = None + async def _gen_client_alt(): msgs = body.get('messages') or [] if not msgs: @@ -1791,9 +2217,12 @@ class GatewayService: yield msg except Exception: yield request_message + if stream2 is not None: try: - response = await stream2(_gen_client_alt(), metadata=metadata_list) + response = await stream2( + _gen_client_alt(), metadata=metadata_list + ) except TypeError: response = await stream2(_gen_client_alt()) got_response = True @@ -1817,6 +2246,7 @@ class GatewayService: ) except AttributeError: bidi2 = None + async def _gen_bidi_alt(): msgs = body.get('messages') or [] if not msgs: @@ -1836,6 +2266,7 @@ class GatewayService: yield msg except Exception: yield request_message + items = [] max_items = int(body.get('max_items') or 50) if bidi2 is not None: @@ -1861,7 +2292,15 @@ class GatewayService: items.append(d) if len(items) >= max_items: break - response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})() + response = type( + 'R', + (), + { + 'DESCRIPTOR': type('D', (), {'fields': []})(), + 'ok': True, + '_items': items, + }, + )() got_response = True else: unary2 = channel.unary_unary( @@ -1881,13 +2320,15 @@ class GatewayService: try: code3 = getattr(e3, 'code', lambda: None)() cname3 = str(getattr(code3, 'name', '') or 'UNKNOWN') - logger.info(f"{request_id} | gRPC alt call raised: {cname3}") + logger.info(f'{request_id} | gRPC alt call raised: {cname3}') except Exception: - logger.info(f"{request_id} | gRPC alt call raised non-grpc exception") - final_code_name = str(code3.name) if getattr(code3, 'name', None) else 'ERROR' + logger.info(f'{request_id} | gRPC alt call raised non-grpc exception') + final_code_name = ( + str(code3.name) if getattr(code3, 'name', None) else 'ERROR' + ) if attempt < attempts - 1 and is_idempotent and code3 in retryable: retries_made += 1 - delay = min(max_ms, base_ms * (2 ** attempt)) / 1000.0 + delay = min(max_ms, base_ms * (2**attempt)) / 1000.0 jitter_factor = 1.0 + (random.random() * jitter - (jitter / 2.0)) await asyncio.sleep(max(0.01, delay * jitter_factor)) continue @@ -1900,11 +2341,13 @@ class GatewayService: code_obj = getattr(last_exc, 'code', lambda: None)() if code_obj and hasattr(code_obj, 'name'): code_name = str(code_obj.name).upper() - logger.info(f"{request_id} | gRPC call failed with status: {code_name}") + logger.info(f'{request_id} | gRPC call failed with status: {code_name}') else: - logger.warning(f"{request_id} | gRPC exception has no valid status code") + logger.warning(f'{request_id} | gRPC exception has no valid status code') except Exception as code_extract_err: - logger.warning(f"{request_id} | Failed to extract gRPC status code: {str(code_extract_err)}") + logger.warning( + f'{request_id} | Failed to extract gRPC status code: {str(code_extract_err)}' + ) status_map = { 'OK': 200, @@ -1941,33 +2384,41 @@ class GatewayService: details = f'gRPC error: {code_name}' logger.error( - f"{request_id} | gRPC call failed after {retries_made} retries. " - f"Status: {code_name}, HTTP: {http_status}, Details: {details[:100]}" + f'{request_id} | gRPC call failed after {retries_made} retries. ' + f'Status: {code_name}, HTTP: {http_status}, Details: {details[:100]}' ) response_headers = { 'request_id': request_id, 'X-Retry-Count': str(retries_made), 'X-GRPC-Status': code_name, - 'X-GRPC-Code': str(code_obj.value[0]) if code_obj and hasattr(code_obj, 'value') else 'unknown' + 'X-GRPC-Code': str(code_obj.value[0]) + if code_obj and hasattr(code_obj, 'value') + else 'unknown', } return ResponseModel( status_code=http_status, response_headers=response_headers, error_code='GTW006', - error_message=str(details)[:255] + error_message=str(details)[:255], ).dict() if not got_response and last_exc is None: try: - logger.error(f"{request_id} | gRPC loop ended with no response and no exception; returning 500 UNKNOWN") + logger.error( + f'{request_id} | gRPC loop ended with no response and no exception; returning 500 UNKNOWN' + ) except Exception: pass return ResponseModel( status_code=500, - response_headers={'request_id': request_id, 'X-Retry-Count': str(retries_made), 'X-Retry-Final': 'UNKNOWN'}, + response_headers={ + 'request_id': request_id, + 'X-Retry-Count': str(retries_made), + 'X-Retry-Final': 'UNKNOWN', + }, error_code='GTW006', - error_message='gRPC call failed' + error_message='gRPC call failed', ).dict() response_dict = {} @@ -1981,7 +2432,11 @@ class GatewayService: else: response_dict[field.name] = value backend_end_time = time.time() * 1000 - response_headers = {'request_id': request_id, 'X-Retry-Count': str(retries_made), 'X-Retry-Final': final_code_name} + response_headers = { + 'request_id': request_id, + 'X-Retry-Count': str(retries_made), + 'X-Retry-Final': final_code_name, + } try: if current_time and start_time: response_headers['X-Gateway-Time'] = str(int(current_time - start_time)) @@ -1990,20 +2445,20 @@ class GatewayService: except Exception: pass try: - logger.info(f"{request_id} | gRPC return 200 with items={bool(response_dict.get('items'))}") + logger.info( + f'{request_id} | gRPC return 200 with items={bool(response_dict.get("items"))}' + ) except Exception: pass return ResponseModel( - status_code=200, - response_headers=response_headers, - response=response_dict + status_code=200, response_headers=response_headers, response=response_dict ).dict() except httpx.TimeoutException: return ResponseModel( status_code=504, response_headers={'request_id': request_id}, error_code='GTW010', - error_message='Gateway timeout' + error_message='Gateway timeout', ).dict() except Exception as e: code_name = 'UNKNOWN' @@ -2057,10 +2512,10 @@ class GatewayService: response_headers={ 'request_id': request_id, 'X-Error-Type': type(e).__name__, - 'X-GRPC-Status': code_name + 'X-GRPC-Status': code_name, }, error_code='GTW006', - error_message=details[:255] + error_message=details[:255], ).dict() finally: if current_time: @@ -2068,7 +2523,9 @@ class GatewayService: if backend_end_time and current_time: logger.info(f'{request_id} | Backend time {backend_end_time - current_time}ms') - async def _make_graphql_request(self, url: str, query: str, headers: Dict[str, str] = None) -> Dict: + async def _make_graphql_request( + self, url: str, query: str, headers: dict[str, str] = None + ) -> dict: try: if headers is None: headers = {} @@ -2079,13 +2536,22 @@ class GatewayService: if 'errors' in data: return data if r.status_code != 200: - return {'errors': [{'message': f'HTTP {r.status_code}: {data.get("message", "Unknown error")}', 'extensions': {'code': 'HTTP_ERROR'}}]} + return { + 'errors': [ + { + 'message': f'HTTP {r.status_code}: {data.get("message", "Unknown error")}', + 'extensions': {'code': 'HTTP_ERROR'}, + } + ] + } return data except Exception as e: logger.error(f'Error making GraphQL request: {str(e)}') return { - 'errors': [{ - 'message': f'Error making GraphQL request: {str(e)}', - 'extensions': {'code': 'REQUEST_ERROR'} - }] + 'errors': [ + { + 'message': f'Error making GraphQL request: {str(e)}', + 'extensions': {'code': 'REQUEST_ERROR'}, + } + ] } diff --git a/backend-services/services/group_service.py b/backend-services/services/group_service.py index f33c2e8..ce108f9 100644 --- a/backend-services/services/group_service.py +++ b/backend-services/services/group_service.py @@ -4,22 +4,22 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from pymongo.errors import DuplicateKeyError import logging +from pymongo.errors import DuplicateKeyError + +from models.create_group_model import CreateGroupModel from models.response_model import ResponseModel from models.update_group_model import UpdateGroupModel -from utils.database import group_collection -from utils.cache_manager_util import cache_manager -from utils.doorman_cache_util import doorman_cache -from models.create_group_model import CreateGroupModel -from utils.paging_util import validate_page_params from utils.constants import ErrorCodes, Messages +from utils.database import group_collection +from utils.doorman_cache_util import doorman_cache +from utils.paging_util import validate_page_params logger = logging.getLogger('doorman.gateway') -class GroupService: +class GroupService: @staticmethod async def create_group(data: CreateGroupModel, request_id): """ @@ -29,11 +29,9 @@ class GroupService: if doorman_cache.get_cache('group_cache', data.group_name): return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='GRP001', - error_message='Group already exists' + error_message='Group already exists', ).dict() group_dict = data.dict() try: @@ -41,26 +39,19 @@ class GroupService: if not insert_result.acknowledged: logger.error(request_id + ' | Group creation failed with code GRP002') return ResponseModel( - status_code=400, - error_code='GRP002', - error_message='Unable to insert group' + status_code=400, error_code='GRP002', error_message='Unable to insert group' ).dict() group_dict['_id'] = str(insert_result.inserted_id) doorman_cache.set_cache('group_cache', data.group_name, group_dict) logger.info(request_id + ' | Group creation successful') - return ResponseModel( - status_code=201, - message='Group created successfully' - ).dict() - except DuplicateKeyError as e: + return ResponseModel(status_code=201, message='Group created successfully').dict() + except DuplicateKeyError: logger.error(request_id + ' | Group creation failed with code GRP001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='GRP001', - error_message='Group already exists' + error_message='Group already exists', ).dict() @staticmethod @@ -72,53 +63,39 @@ class GroupService: if data.group_name and data.group_name != group_name: return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='GRP004', - error_message='Group name cannot be updated' + error_message='Group name cannot be updated', ).dict() group = doorman_cache.get_cache('group_cache', group_name) if not group: - group = group_collection.find_one({ - 'group_name': group_name - }) + group = group_collection.find_one({'group_name': group_name}) if not group: logger.error(request_id + ' | Group update failed with code GRP003') return ResponseModel( - status_code=400, - error_code='GRP003', - error_message='Group does not exist' + status_code=400, error_code='GRP003', error_message='Group does not exist' ).dict() else: doorman_cache.delete_cache('group_cache', group_name) not_null_data = {k: v for k, v in data.dict().items() if v is not None} if not_null_data: update_result = group_collection.update_one( - {'group_name': group_name}, - {'$set': not_null_data} + {'group_name': group_name}, {'$set': not_null_data} ) if not update_result.acknowledged or update_result.modified_count == 0: logger.error(request_id + ' | Group update failed with code GRP002') return ResponseModel( - status_code=400, - error_code='GRP005', - error_message='Unable to update group' + status_code=400, error_code='GRP005', error_message='Unable to update group' ).dict() logger.info(request_id + ' | Group updated successful') - return ResponseModel( - status_code=200, - message='Group updated successfully' - ).dict() + return ResponseModel(status_code=200, message='Group updated successfully').dict() else: logger.error(request_id + ' | Group update failed with code GRP006') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='GRP006', - error_message='No data to update' + error_message='No data to update', ).dict() @staticmethod @@ -129,35 +106,27 @@ class GroupService: logger.info(request_id + ' | Deleting group: ' + group_name) group = doorman_cache.get_cache('group_cache', group_name) if not group: - group = group_collection.find_one({ - 'group_name': group_name - }) + group = group_collection.find_one({'group_name': group_name}) if not group: logger.error(request_id + ' | Group deletion failed with code GRP003') return ResponseModel( - status_code=400, - error_code='GRP003', - error_message='Group does not exist' + status_code=400, error_code='GRP003', error_message='Group does not exist' ).dict() delete_result = group_collection.delete_one({'group_name': group_name}) if not delete_result.acknowledged: logger.error(request_id + ' | Group deletion failed with code GRP002') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='GRP007', - error_message='Unable to delete group' + error_message='Unable to delete group', ).dict() doorman_cache.delete_cache('group_cache', group_name) logger.info(request_id + ' | Group deletion successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='Group deleted successfully' + response_headers={'request_id': request_id}, + message='Group deleted successfully', ).dict() @staticmethod @@ -165,7 +134,9 @@ class GroupService: """ Check if a group exists. """ - if doorman_cache.get_cache('group_cache', data.get('group_name')) or group_collection.find_one({'group_name': data.get('group_name')}): + if doorman_cache.get_cache( + 'group_cache', data.get('group_name') + ) or group_collection.find_one({'group_name': data.get('group_name')}): return True return False @@ -174,25 +145,27 @@ class GroupService: """ Get all groups. """ - logger.info(request_id + ' | Getting groups: Page=' + str(page) + ' Page Size=' + str(page_size)) + logger.info( + request_id + ' | Getting groups: Page=' + str(page) + ' Page Size=' + str(page_size) + ) try: page, page_size = validate_page_params(page, page_size) except Exception as e: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING + ), ).dict() skip = (page - 1) * page_size cursor = group_collection.find().sort('group_name', 1).skip(skip).limit(page_size) groups = cursor.to_list(length=None) for group in groups: - if group.get('_id'): del group['_id'] + if group.get('_id'): + del group['_id'] logger.info(request_id + ' | Groups retrieval successful') - return ResponseModel( - status_code=200, - response={'groups': groups} - ).dict() + return ResponseModel(status_code=200, response={'groups': groups}).dict() @staticmethod async def get_group(group_name, request_id): @@ -206,15 +179,12 @@ class GroupService: if not group: logger.error(request_id + ' | Group retrieval failed with code GRP003') return ResponseModel( - status_code=404, - error_code='GRP003', - error_message='Group does not exist' + status_code=404, error_code='GRP003', error_message='Group does not exist' ).dict() - if group.get('_id'): del group['_id'] + if group.get('_id'): + del group['_id'] doorman_cache.set_cache('group_cache', group_name, group) - if group.get('_id'): del group['_id'] + if group.get('_id'): + del group['_id'] logger.info(request_id + ' | Group retrieval successful') - return ResponseModel( - status_code=200, - response=group - ).dict() + return ResponseModel(status_code=200, response=group).dict() diff --git a/backend-services/services/logging_service.py b/backend-services/services/logging_service.py index f1cea77..6df754a 100644 --- a/backend-services/services/logging_service.py +++ b/backend-services/services/logging_service.py @@ -4,17 +4,19 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ +import glob +import json import logging import os -import json -import glob -from datetime import datetime, timedelta -from typing import List, Dict, Any, Optional -from fastapi import HTTPException import re +from datetime import datetime +from typing import Any + +from fastapi import HTTPException logger = logging.getLogger('doorman.logging') + class LoggingService: def __init__(self): env_dir = os.getenv('LOGS_DIR') @@ -24,43 +26,40 @@ class LoggingService: base_dir = os.path.dirname(os.path.abspath(__file__)) backend_root = os.path.normpath(os.path.join(base_dir, '..')) candidate = os.path.join(backend_root, 'platform-logs') - self.log_directory = candidate if os.path.isdir(candidate) else os.path.join(backend_root, 'logs') + self.log_directory = ( + candidate if os.path.isdir(candidate) else os.path.join(backend_root, 'logs') + ) self.log_file_patterns = ['doorman.log*', 'doorman-trail.log*'] self.max_logs_per_request = 1000 async def get_logs( self, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - user: Optional[str] = None, - endpoint: Optional[str] = None, - request_id: Optional[str] = None, - method: Optional[str] = None, - ip_address: Optional[str] = None, - min_response_time: Optional[str] = None, - max_response_time: Optional[str] = None, - level: Optional[str] = None, + start_date: str | None = None, + end_date: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + user: str | None = None, + endpoint: str | None = None, + request_id: str | None = None, + method: str | None = None, + ip_address: str | None = None, + min_response_time: str | None = None, + max_response_time: str | None = None, + level: str | None = None, limit: int = 100, offset: int = 0, - request_id_param: str = None - ) -> Dict[str, Any]: + request_id_param: str = None, + ) -> dict[str, Any]: """ Retrieve and filter logs based on various criteria """ try: - log_files: list[str] = [] for pat in self.log_file_patterns: log_files.extend(glob.glob(os.path.join(self.log_directory, pat))) if not log_files: - return { - 'logs': [], - 'total': 0, - 'has_more': False - } + return {'logs': [], 'total': 0, 'has_more': False} logs = [] total_count = 0 @@ -72,23 +71,26 @@ class LoggingService: continue try: - with open(log_file, 'r', encoding='utf-8') as file: + with open(log_file, encoding='utf-8') as file: for line in file: log_entry = self._parse_log_line(line) - if log_entry and self._matches_filters(log_entry, { - 'start_date': start_date, - 'end_date': end_date, - 'start_time': start_time, - 'end_time': end_time, - 'user': user, - 'endpoint': endpoint, - 'request_id': request_id, - 'method': method, - 'ip_address': ip_address, - 'min_response_time': min_response_time, - 'max_response_time': max_response_time, - 'level': level - }): + if log_entry and self._matches_filters( + log_entry, + { + 'start_date': start_date, + 'end_date': end_date, + 'start_time': start_time, + 'end_time': end_time, + 'user': user, + 'endpoint': endpoint, + 'request_id': request_id, + 'method': method, + 'ip_address': ip_address, + 'min_response_time': min_response_time, + 'max_response_time': max_response_time, + 'level': level, + }, + ): total_count += 1 if len(logs) < limit and total_count > offset: logs.append(log_entry) @@ -103,17 +105,13 @@ class LoggingService: logger.warning(f'Error reading log file {log_file}: {str(e)}') continue - return { - 'logs': logs, - 'total': total_count, - 'has_more': total_count > offset + limit - } + return {'logs': logs, 'total': total_count, 'has_more': total_count > offset + limit} except Exception as e: logger.error(f'Error retrieving logs: {str(e)}', exc_info=True) raise HTTPException(status_code=500, detail='Failed to retrieve logs') - def get_available_log_files(self) -> List[str]: + def get_available_log_files(self) -> list[str]: """ Get list of available log files for debugging """ @@ -123,12 +121,11 @@ class LoggingService: log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True) return log_files - async def get_log_statistics(self, request_id: str = None) -> Dict[str, Any]: + async def get_log_statistics(self, request_id: str = None) -> dict[str, Any]: """ Get log statistics for dashboard """ try: - log_files: list[str] = [] for pat in self.log_file_patterns: log_files.extend(glob.glob(os.path.join(self.log_directory, pat))) @@ -143,7 +140,7 @@ class LoggingService: 'avg_response_time': 0, 'top_apis': [], 'top_users': [], - 'top_endpoints': [] + 'top_endpoints': [], } stats = { @@ -155,7 +152,7 @@ class LoggingService: 'response_times': [], 'apis': {}, 'users': {}, - 'endpoints': {} + 'endpoints': {}, } for log_file in log_files: @@ -163,7 +160,7 @@ class LoggingService: continue try: - with open(log_file, 'r', encoding='utf-8') as file: + with open(log_file, encoding='utf-8') as file: for line in file: log_entry = self._parse_log_line(line) if log_entry: @@ -181,28 +178,42 @@ class LoggingService: if log_entry.get('response_time'): try: - stats['response_times'].append(float(log_entry['response_time'])) + stats['response_times'].append( + float(log_entry['response_time']) + ) except (ValueError, TypeError): pass if log_entry.get('api'): - stats['apis'][log_entry['api']] = stats['apis'].get(log_entry['api'], 0) + 1 + stats['apis'][log_entry['api']] = ( + stats['apis'].get(log_entry['api'], 0) + 1 + ) if log_entry.get('user'): - stats['users'][log_entry['user']] = stats['users'].get(log_entry['user'], 0) + 1 + stats['users'][log_entry['user']] = ( + stats['users'].get(log_entry['user'], 0) + 1 + ) if log_entry.get('endpoint'): - stats['endpoints'][log_entry['endpoint']] = stats['endpoints'].get(log_entry['endpoint'], 0) + 1 + stats['endpoints'][log_entry['endpoint']] = ( + stats['endpoints'].get(log_entry['endpoint'], 0) + 1 + ) except Exception as e: logger.warning(f'Error reading log file {log_file} for statistics: {str(e)}') continue - avg_response_time = sum(stats['response_times']) / len(stats['response_times']) if stats['response_times'] else 0 + avg_response_time = ( + sum(stats['response_times']) / len(stats['response_times']) + if stats['response_times'] + else 0 + ) top_apis = sorted(stats['apis'].items(), key=lambda x: x[1], reverse=True)[:10] top_users = sorted(stats['users'].items(), key=lambda x: x[1], reverse=True)[:10] - top_endpoints = sorted(stats['endpoints'].items(), key=lambda x: x[1], reverse=True)[:10] + top_endpoints = sorted(stats['endpoints'].items(), key=lambda x: x[1], reverse=True)[ + :10 + ] return { 'total_logs': stats['total_logs'], @@ -213,7 +224,9 @@ class LoggingService: 'avg_response_time': round(avg_response_time, 2), 'top_apis': [{'name': api, 'count': count} for api, count in top_apis], 'top_users': [{'name': user, 'count': count} for user, count in top_users], - 'top_endpoints': [{'name': endpoint, 'count': count} for endpoint, count in top_endpoints] + 'top_endpoints': [ + {'name': endpoint, 'count': count} for endpoint, count in top_endpoints + ], } except Exception as e: @@ -223,21 +236,17 @@ class LoggingService: async def export_logs( self, format: str = 'json', - start_date: Optional[str] = None, - end_date: Optional[str] = None, - filters: Dict[str, Any] = None, - request_id: str = None - ) -> Dict[str, Any]: + start_date: str | None = None, + end_date: str | None = None, + filters: dict[str, Any] = None, + request_id: str = None, + ) -> dict[str, Any]: """ Export logs in various formats """ try: - logs_data = await self.get_logs( - start_date=start_date, - end_date=end_date, - limit=10000, - **filters if filters else {} + start_date=start_date, end_date=end_date, limit=10000, **filters if filters else {} ) logs = logs_data['logs'] @@ -246,24 +255,24 @@ class LoggingService: return { 'format': 'json', 'data': json.dumps(logs, indent=2, default=str), - 'filename': f"logs_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + 'filename': f'logs_export_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json', } elif format.lower() == 'csv': if not logs: return { 'format': 'csv', 'data': 'timestamp,level,message,source,user,api,endpoint,method,status_code,response_time,ip_address,protocol,request_id,group,role\n', - 'filename': f"logs_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + 'filename': f'logs_export_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv', } csv_data = 'timestamp,level,message,source,user,api,endpoint,method,status_code,response_time,ip_address,protocol,request_id,group,role\n' for log in logs: - csv_data += f"{log.get('timestamp', '')},{log.get('level', '')},{log.get('message', '').replace(',', ';')},{log.get('source', '')},{log.get('user', '')},{log.get('api', '')},{log.get('endpoint', '')},{log.get('method', '')},{log.get('status_code', '')},{log.get('response_time', '')},{log.get('ip_address', '')},{log.get('protocol', '')},{log.get('request_id', '')},{log.get('group', '')},{log.get('role', '')}\n" + csv_data += f'{log.get("timestamp", "")},{log.get("level", "")},{log.get("message", "").replace(",", ";")},{log.get("source", "")},{log.get("user", "")},{log.get("api", "")},{log.get("endpoint", "")},{log.get("method", "")},{log.get("status_code", "")},{log.get("response_time", "")},{log.get("ip_address", "")},{log.get("protocol", "")},{log.get("request_id", "")},{log.get("group", "")},{log.get("role", "")}\n' return { 'format': 'csv', 'data': csv_data, - 'filename': f"logs_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + 'filename': f'logs_export_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv', } else: raise HTTPException(status_code=400, detail='Unsupported export format') @@ -272,7 +281,7 @@ class LoggingService: logger.error(f'Error exporting logs: {str(e)}') raise HTTPException(status_code=500, detail='Failed to export logs') - def _parse_log_line(self, line: str) -> Optional[Dict[str, Any]]: + def _parse_log_line(self, line: str) -> dict[str, Any] | None: """ Parse a log line and extract structured data Format: timestamp - logger_name - level - request_id | message @@ -292,7 +301,9 @@ class LoggingService: level = rec.get('level', '') structured = self._extract_structured_data(message) return { - 'timestamp': timestamp if isinstance(timestamp, str) else timestamp.isoformat(), + 'timestamp': timestamp + if isinstance(timestamp, str) + else timestamp.isoformat(), 'level': level, 'message': message, 'source': name, @@ -330,7 +341,7 @@ class LoggingService: logger.debug(f'Failed to parse log line: {str(e)}', exc_info=True) return None - def _extract_structured_data(self, message: str) -> Dict[str, Any]: + def _extract_structured_data(self, message: str) -> dict[str, Any]: """ Extract structured data from log message """ @@ -375,12 +386,11 @@ class LoggingService: return data - def _matches_filters(self, log_entry: Dict[str, Any], filters: Dict[str, Any]) -> bool: + def _matches_filters(self, log_entry: dict[str, Any], filters: dict[str, Any]) -> bool: """ Check if log entry matches all applied filters """ try: - timestamp_str = log_entry.get('timestamp', '') if not timestamp_str: return False @@ -429,14 +439,20 @@ class LoggingService: log_value = str(log_entry.get(field, '')).lower() if not log_value: - logger.debug(f"Filter '{field}' = '{filter_value}' but log has no value for '{field}'") + logger.debug( + f"Filter '{field}' = '{filter_value}' but log has no value for '{field}'" + ) return False if filter_value not in log_value: - logger.debug(f"Filter '{field}' = '{filter_value}' not found in log value '{log_value}'") + logger.debug( + f"Filter '{field}' = '{filter_value}' not found in log value '{log_value}'" + ) return False else: - logger.debug(f"Filter '{field}' = '{filter_value}' matches log value '{log_value}'") + logger.debug( + f"Filter '{field}' = '{filter_value}' matches log value '{log_value}'" + ) if filters.get('status_code') and filters['status_code'].strip(): try: @@ -470,7 +486,7 @@ class LoggingService: applied_filters = [f'{k}={v}' for k, v in filters.items() if v and str(v).strip()] if applied_filters: - logger.debug(f"Log entry passed all filters: {', '.join(applied_filters)}") + logger.debug(f'Log entry passed all filters: {", ".join(applied_filters)}') return True except Exception as e: diff --git a/backend-services/services/rate_limit_rule_service.py b/backend-services/services/rate_limit_rule_service.py index c8fe8a0..c15a3f2 100644 --- a/backend-services/services/rate_limit_rule_service.py +++ b/backend-services/services/rate_limit_rule_service.py @@ -6,8 +6,9 @@ Handles rule CRUD, validation, and application. """ import logging -from typing import Optional, List, Dict, Any, TYPE_CHECKING from datetime import datetime +from typing import TYPE_CHECKING, Any + try: # Use a type-only import to avoid a hard runtime dependency during tests if TYPE_CHECKING: @@ -17,7 +18,7 @@ try: except Exception: # Defensive: never fail import due to typing AsyncIOMotorDatabase = Any # type: ignore -from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow +from models.rate_limit_models import RateLimitRule, RuleType logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) class RateLimitRuleService: """ Service for managing rate limit rules - + Features: - CRUD operations for rules - Rule validation @@ -33,248 +34,242 @@ class RateLimitRuleService: - Bulk operations - Rule testing """ - + def __init__(self, db: AsyncIOMotorDatabase): """ Initialize rate limit rule service - + Args: db: MongoDB database instance """ self.db = db self.rules_collection = db.rate_limit_rules - + # ======================================================================== # RULE CRUD OPERATIONS # ======================================================================== - + async def create_rule(self, rule: RateLimitRule) -> RateLimitRule: """ Create a new rate limit rule - + Args: rule: RateLimitRule object to create - + Returns: Created rule - + Raises: ValueError: If rule with same ID already exists """ # Check if rule already exists existing = await self.rules_collection.find_one({'rule_id': rule.rule_id}) if existing: - raise ValueError(f"Rule with ID {rule.rule_id} already exists") - + raise ValueError(f'Rule with ID {rule.rule_id} already exists') + # Set timestamps rule.created_at = datetime.now() rule.updated_at = datetime.now() - + # Insert into database await self.rules_collection.insert_one(rule.to_dict()) - - logger.info(f"Created rate limit rule: {rule.rule_id}") + + logger.info(f'Created rate limit rule: {rule.rule_id}') return rule - - async def get_rule(self, rule_id: str) -> Optional[RateLimitRule]: + + async def get_rule(self, rule_id: str) -> RateLimitRule | None: """ Get rule by ID - + Args: rule_id: Rule identifier - + Returns: RateLimitRule object or None if not found """ rule_data = await self.rules_collection.find_one({'rule_id': rule_id}) - + if rule_data: return RateLimitRule.from_dict(rule_data) - + return None - + async def list_rules( self, - rule_type: Optional[RuleType] = None, + rule_type: RuleType | None = None, enabled_only: bool = False, skip: int = 0, - limit: int = 100 - ) -> List[RateLimitRule]: + limit: int = 100, + ) -> list[RateLimitRule]: """ List rate limit rules - + Args: rule_type: Filter by rule type enabled_only: Only return enabled rules skip: Number of records to skip limit: Maximum number of records to return - + Returns: List of rules """ query = {} - + if rule_type: query['rule_type'] = rule_type.value - + if enabled_only: query['enabled'] = True - + # Sort by priority (highest first) cursor = self.rules_collection.find(query).sort('priority', -1).skip(skip).limit(limit) rules = [] - + async for rule_data in cursor: rules.append(RateLimitRule.from_dict(rule_data)) - + return rules - - async def update_rule(self, rule_id: str, updates: Dict[str, Any]) -> Optional[RateLimitRule]: + + async def update_rule(self, rule_id: str, updates: dict[str, Any]) -> RateLimitRule | None: """ Update rule - + Args: rule_id: Rule identifier updates: Dictionary of fields to update - + Returns: Updated rule or None if not found """ # Add updated timestamp updates['updated_at'] = datetime.now().isoformat() - + result = await self.rules_collection.find_one_and_update( - {'rule_id': rule_id}, - {'$set': updates}, - return_document=True + {'rule_id': rule_id}, {'$set': updates}, return_document=True ) - + if result: - logger.info(f"Updated rate limit rule: {rule_id}") + logger.info(f'Updated rate limit rule: {rule_id}') return RateLimitRule.from_dict(result) - + return None - + async def delete_rule(self, rule_id: str) -> bool: """ Delete rule - + Args: rule_id: Rule identifier - + Returns: True if deleted, False if not found """ result = await self.rules_collection.delete_one({'rule_id': rule_id}) - + if result.deleted_count > 0: - logger.info(f"Deleted rate limit rule: {rule_id}") + logger.info(f'Deleted rate limit rule: {rule_id}') return True - + return False - - async def enable_rule(self, rule_id: str) -> Optional[RateLimitRule]: + + async def enable_rule(self, rule_id: str) -> RateLimitRule | None: """ Enable a rule - + Args: rule_id: Rule identifier - + Returns: Updated rule or None if not found """ return await self.update_rule(rule_id, {'enabled': True}) - - async def disable_rule(self, rule_id: str) -> Optional[RateLimitRule]: + + async def disable_rule(self, rule_id: str) -> RateLimitRule | None: """ Disable a rule - + Args: rule_id: Rule identifier - + Returns: Updated rule or None if not found """ return await self.update_rule(rule_id, {'enabled': False}) - + # ======================================================================== # RULE QUERIES # ======================================================================== - + async def get_applicable_rules( self, - user_id: Optional[str] = None, - api_name: Optional[str] = None, - endpoint_uri: Optional[str] = None, - ip_address: Optional[str] = None - ) -> List[RateLimitRule]: + user_id: str | None = None, + api_name: str | None = None, + endpoint_uri: str | None = None, + ip_address: str | None = None, + ) -> list[RateLimitRule]: """ Get all applicable rules for a request - + Args: user_id: User identifier api_name: API name endpoint_uri: Endpoint URI ip_address: IP address - + Returns: List of applicable rules sorted by priority """ query = {'enabled': True} - + # Build OR query for applicable rules or_conditions = [] - + # Global rules always apply or_conditions.append({'rule_type': RuleType.GLOBAL.value}) - + # Per-user rules if user_id: - or_conditions.append({ - 'rule_type': RuleType.PER_USER.value, - 'target_identifier': user_id - }) - + or_conditions.append( + {'rule_type': RuleType.PER_USER.value, 'target_identifier': user_id} + ) + # Per-API rules if api_name: - or_conditions.append({ - 'rule_type': RuleType.PER_API.value, - 'target_identifier': api_name - }) - + or_conditions.append( + {'rule_type': RuleType.PER_API.value, 'target_identifier': api_name} + ) + # Per-endpoint rules if endpoint_uri: - or_conditions.append({ - 'rule_type': RuleType.PER_ENDPOINT.value, - 'target_identifier': endpoint_uri - }) - + or_conditions.append( + {'rule_type': RuleType.PER_ENDPOINT.value, 'target_identifier': endpoint_uri} + ) + # Per-IP rules if ip_address: - or_conditions.append({ - 'rule_type': RuleType.PER_IP.value, - 'target_identifier': ip_address - }) - + or_conditions.append( + {'rule_type': RuleType.PER_IP.value, 'target_identifier': ip_address} + ) + if or_conditions: query['$or'] = or_conditions - + # Get rules sorted by priority cursor = self.rules_collection.find(query).sort('priority', -1) rules = [] - + async for rule_data in cursor: rules.append(RateLimitRule.from_dict(rule_data)) - + return rules - - async def search_rules(self, search_term: str) -> List[RateLimitRule]: + + async def search_rules(self, search_term: str) -> list[RateLimitRule]: """ Search rules by ID, description, or target identifier - + Args: search_term: Search term - + Returns: List of matching rules """ @@ -282,142 +277,140 @@ class RateLimitRuleService: '$or': [ {'rule_id': {'$regex': search_term, '$options': 'i'}}, {'description': {'$regex': search_term, '$options': 'i'}}, - {'target_identifier': {'$regex': search_term, '$options': 'i'}} + {'target_identifier': {'$regex': search_term, '$options': 'i'}}, ] } - + cursor = self.rules_collection.find(query).sort('priority', -1) rules = [] - + async for rule_data in cursor: rules.append(RateLimitRule.from_dict(rule_data)) - + return rules - + # ======================================================================== # BULK OPERATIONS # ======================================================================== - - async def bulk_create_rules(self, rules: List[RateLimitRule]) -> int: + + async def bulk_create_rules(self, rules: list[RateLimitRule]) -> int: """ Create multiple rules at once - + Args: rules: List of rules to create - + Returns: Number of rules created """ if not rules: return 0 - + # Set timestamps now = datetime.now() for rule in rules: rule.created_at = now rule.updated_at = now - + # Insert all rules rule_dicts = [rule.to_dict() for rule in rules] result = await self.rules_collection.insert_many(rule_dicts) - + count = len(result.inserted_ids) - logger.info(f"Bulk created {count} rate limit rules") + logger.info(f'Bulk created {count} rate limit rules') return count - - async def bulk_delete_rules(self, rule_ids: List[str]) -> int: + + async def bulk_delete_rules(self, rule_ids: list[str]) -> int: """ Delete multiple rules at once - + Args: rule_ids: List of rule IDs to delete - + Returns: Number of rules deleted """ if not rule_ids: return 0 - - result = await self.rules_collection.delete_many({ - 'rule_id': {'$in': rule_ids} - }) - + + result = await self.rules_collection.delete_many({'rule_id': {'$in': rule_ids}}) + count = result.deleted_count - logger.info(f"Bulk deleted {count} rate limit rules") + logger.info(f'Bulk deleted {count} rate limit rules') return count - - async def bulk_enable_rules(self, rule_ids: List[str]) -> int: + + async def bulk_enable_rules(self, rule_ids: list[str]) -> int: """ Enable multiple rules at once - + Args: rule_ids: List of rule IDs to enable - + Returns: Number of rules enabled """ if not rule_ids: return 0 - + result = await self.rules_collection.update_many( {'rule_id': {'$in': rule_ids}}, - {'$set': {'enabled': True, 'updated_at': datetime.now().isoformat()}} + {'$set': {'enabled': True, 'updated_at': datetime.now().isoformat()}}, ) - + count = result.modified_count - logger.info(f"Bulk enabled {count} rate limit rules") + logger.info(f'Bulk enabled {count} rate limit rules') return count - - async def bulk_disable_rules(self, rule_ids: List[str]) -> int: + + async def bulk_disable_rules(self, rule_ids: list[str]) -> int: """ Disable multiple rules at once - + Args: rule_ids: List of rule IDs to disable - + Returns: Number of rules disabled """ if not rule_ids: return 0 - + result = await self.rules_collection.update_many( {'rule_id': {'$in': rule_ids}}, - {'$set': {'enabled': False, 'updated_at': datetime.now().isoformat()}} + {'$set': {'enabled': False, 'updated_at': datetime.now().isoformat()}}, ) - + count = result.modified_count - logger.info(f"Bulk disabled {count} rate limit rules") + logger.info(f'Bulk disabled {count} rate limit rules') return count - + # ======================================================================== # RULE DUPLICATION # ======================================================================== - + async def duplicate_rule(self, rule_id: str, new_rule_id: str) -> RateLimitRule: """ Duplicate an existing rule - + Args: rule_id: Source rule ID new_rule_id: New rule ID - + Returns: Duplicated rule - + Raises: ValueError: If source rule not found or new ID already exists """ # Get source rule source_rule = await self.get_rule(rule_id) if not source_rule: - raise ValueError(f"Source rule {rule_id} not found") - + raise ValueError(f'Source rule {rule_id} not found') + # Check if new ID already exists existing = await self.rules_collection.find_one({'rule_id': new_rule_id}) if existing: - raise ValueError(f"Rule with ID {new_rule_id} already exists") - + raise ValueError(f'Rule with ID {new_rule_id} already exists') + # Create new rule with same properties new_rule = RateLimitRule( rule_id=new_rule_id, @@ -428,90 +421,93 @@ class RateLimitRuleService: burst_allowance=source_rule.burst_allowance, priority=source_rule.priority, enabled=source_rule.enabled, - description=f"Copy of {source_rule.rule_id}" + description=f'Copy of {source_rule.rule_id}', ) - + return await self.create_rule(new_rule) - + # ======================================================================== # RULE STATISTICS # ======================================================================== - - async def get_rule_statistics(self) -> Dict[str, Any]: + + async def get_rule_statistics(self) -> dict[str, Any]: """ Get statistics about rate limit rules - + Returns: Dictionary with rule statistics """ total_rules = await self.rules_collection.count_documents({}) enabled_rules = await self.rules_collection.count_documents({'enabled': True}) disabled_rules = total_rules - enabled_rules - + # Count by type type_counts = {} for rule_type in RuleType: - count = await self.rules_collection.count_documents({ - 'rule_type': rule_type.value - }) + count = await self.rules_collection.count_documents({'rule_type': rule_type.value}) type_counts[rule_type.value] = count - + return { 'total_rules': total_rules, 'enabled_rules': enabled_rules, 'disabled_rules': disabled_rules, - 'rules_by_type': type_counts + 'rules_by_type': type_counts, } - + # ======================================================================== # RULE VALIDATION # ======================================================================== - - def validate_rule(self, rule: RateLimitRule) -> List[str]: + + def validate_rule(self, rule: RateLimitRule) -> list[str]: """ Validate a rule - + Args: rule: Rule to validate - + Returns: List of validation errors (empty if valid) """ errors = [] - + # Check limit is positive if rule.limit <= 0: - errors.append("Limit must be greater than 0") - + errors.append('Limit must be greater than 0') + # Check burst allowance is non-negative if rule.burst_allowance < 0: - errors.append("Burst allowance cannot be negative") - + errors.append('Burst allowance cannot be negative') + # Check target identifier for specific rule types - if rule.rule_type in [RuleType.PER_USER, RuleType.PER_API, RuleType.PER_ENDPOINT, RuleType.PER_IP]: + if rule.rule_type in [ + RuleType.PER_USER, + RuleType.PER_API, + RuleType.PER_ENDPOINT, + RuleType.PER_IP, + ]: if not rule.target_identifier: - errors.append(f"Target identifier required for {rule.rule_type.value} rules") - + errors.append(f'Target identifier required for {rule.rule_type.value} rules') + return errors # Global rule service instance -_rule_service: Optional[RateLimitRuleService] = None +_rule_service: RateLimitRuleService | None = None def get_rate_limit_rule_service(db: AsyncIOMotorDatabase) -> RateLimitRuleService: """ Get or create global rule service instance - + Args: db: MongoDB database instance - + Returns: RateLimitRuleService instance """ global _rule_service - + if _rule_service is None: _rule_service = RateLimitRuleService(db) - + return _rule_service diff --git a/backend-services/services/role_service.py b/backend-services/services/role_service.py index 124d895..2ab4908 100644 --- a/backend-services/services/role_service.py +++ b/backend-services/services/role_service.py @@ -4,23 +4,23 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from pymongo.errors import DuplicateKeyError import logging +from pymongo.errors import DuplicateKeyError + +from models.create_role_model import CreateRoleModel from models.response_model import ResponseModel from models.update_role_model import UpdateRoleModel -from utils.database_async import role_collection -from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list -from utils.cache_manager_util import cache_manager -from utils.doorman_cache_util import doorman_cache -from models.create_role_model import CreateRoleModel -from utils.paging_util import validate_page_params +from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one from utils.constants import ErrorCodes, Messages +from utils.database_async import role_collection +from utils.doorman_cache_util import doorman_cache +from utils.paging_util import validate_page_params logger = logging.getLogger('doorman.gateway') -class RoleService: +class RoleService: @staticmethod async def create_role(data: CreateRoleModel, request_id): """ @@ -31,11 +31,9 @@ class RoleService: logger.error(request_id + ' | Role creation failed with code ROLE001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='ROLE001', - error_message='Role already exists' + error_message='Role already exists', ).dict() role_dict = data.dict() try: @@ -43,26 +41,19 @@ class RoleService: if not insert_result.acknowledged: logger.error(request_id + ' | Role creation failed with code ROLE002') return ResponseModel( - status_code=400, - error_code='ROLE002', - error_message='Unable to insert role' + status_code=400, error_code='ROLE002', error_message='Unable to insert role' ).dict() role_dict['_id'] = str(insert_result.inserted_id) doorman_cache.set_cache('role_cache', data.role_name, role_dict) logger.info(request_id + ' | Role creation successful') - return ResponseModel( - status_code=201, - message='Role created successfully' - ).dict() - except DuplicateKeyError as e: + return ResponseModel(status_code=201, message='Role created successfully').dict() + except DuplicateKeyError: logger.error(request_id + ' | Role creation failed with code ROLE001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='ROLE001', - error_message='Role already exists' + error_message='Role already exists', ).dict() @staticmethod @@ -75,34 +66,29 @@ class RoleService: logger.error(request_id + ' | Role update failed with code ROLE005') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='ROLE005', - error_message='Role name cannot be changed' + error_message='Role name cannot be changed', ).dict() role = doorman_cache.get_cache('role_cache', role_name) if not role: - role = await db_find_one(role_collection, { - 'role_name': role_name - }) + role = await db_find_one(role_collection, {'role_name': role_name}) if not role: logger.error(request_id + ' | Role update failed with code ROLE004') return ResponseModel( - status_code=400, - error_code='ROLE004', - error_message='Role does not exist' + status_code=400, error_code='ROLE004', error_message='Role does not exist' ).dict() else: doorman_cache.delete_cache('role_cache', role_name) not_null_data = {k: v for k, v in data.dict().items() if v is not None} if not_null_data: try: - update_result = await db_update_one(role_collection, {'role_name': role_name}, {'$set': not_null_data}) + update_result = await db_update_one( + role_collection, {'role_name': role_name}, {'$set': not_null_data} + ) if update_result.modified_count > 0: doorman_cache.delete_cache('role_cache', role_name) if not update_result.acknowledged or update_result.modified_count == 0: - current = await db_find_one(role_collection, {'role_name': role_name}) or {} is_applied = all(current.get(k) == v for k, v in not_null_data.items()) if not is_applied: @@ -110,31 +96,30 @@ class RoleService: return ResponseModel( status_code=400, error_code='ROLE006', - error_message='Unable to update role' + error_message='Unable to update role', ).dict() except Exception as e: doorman_cache.delete_cache('role_cache', role_name) - logger.error(request_id + ' | Role update failed with exception: ' + str(e), exc_info=True) + logger.error( + request_id + ' | Role update failed with exception: ' + str(e), exc_info=True + ) raise updated_role = await db_find_one(role_collection, {'role_name': role_name}) or {} - if updated_role.get('_id'): del updated_role['_id'] + if updated_role.get('_id'): + del updated_role['_id'] doorman_cache.set_cache('role_cache', role_name, updated_role) logger.info(request_id + ' | Role update successful') return ResponseModel( - status_code=200, - response=updated_role, - message='Role updated successfully' + status_code=200, response=updated_role, message='Role updated successfully' ).dict() else: logger.error(request_id + ' | Role update failed with code ROLE007') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='ROLE007', - error_message='No data to update' + error_message='No data to update', ).dict() @staticmethod @@ -149,9 +134,7 @@ class RoleService: if not role: logger.error(request_id + ' | Role deletion failed with code ROLE004') return ResponseModel( - status_code=400, - error_code='ROLE004', - error_message='Role does not exist' + status_code=400, error_code='ROLE004', error_message='Role does not exist' ).dict() else: doorman_cache.delete_cache('role_cache', role_name) @@ -160,19 +143,15 @@ class RoleService: logger.error(request_id + ' | Role deletion failed with code ROLE008') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='ROLE008', - error_message='Unable to delete role' + error_message='Unable to delete role', ).dict() logger.info(request_id + ' | Role Deletion Successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='Role deleted successfully' + response_headers={'request_id': request_id}, + message='Role deleted successfully', ).dict() @staticmethod @@ -180,7 +159,9 @@ class RoleService: """ Check if a role exists. """ - if doorman_cache.get_cache('role_cache', data.get('role_name')) or await db_find_one(role_collection, {'role_name': data.get('role_name')}): + if doorman_cache.get_cache('role_cache', data.get('role_name')) or await db_find_one( + role_collection, {'role_name': data.get('role_name')} + ): return True return False @@ -189,26 +170,28 @@ class RoleService: """ Get all roles. """ - logger.info(request_id + ' | Getting roles: Page=' + str(page) + ' Page Size=' + str(page_size)) + logger.info( + request_id + ' | Getting roles: Page=' + str(page) + ' Page Size=' + str(page_size) + ) try: page, page_size = validate_page_params(page, page_size) except Exception as e: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING + ), ).dict() skip = (page - 1) * page_size roles_all = await db_find_list(role_collection, {}) roles_all.sort(key=lambda r: r.get('role_name')) - roles = roles_all[skip: skip + page_size] + roles = roles_all[skip : skip + page_size] for role in roles: - if role.get('_id'): del role['_id'] + if role.get('_id'): + del role['_id'] logger.info(request_id + ' | Roles retrieval successful') - return ResponseModel( - status_code=200, - response={'roles': roles} - ).dict() + return ResponseModel(status_code=200, response={'roles': roles}).dict() @staticmethod async def get_role(role_name, request_id): @@ -222,15 +205,12 @@ class RoleService: if not role: logger.error(request_id + ' | Role retrieval failed with code ROLE004') return ResponseModel( - status_code=404, - error_code='ROLE004', - error_message='Role does not exist' + status_code=404, error_code='ROLE004', error_message='Role does not exist' ).dict() - if role.get('_id'): del role['_id'] + if role.get('_id'): + del role['_id'] doorman_cache.set_cache('role_cache', role_name, role) - if role.get('_id'): del role['_id'] + if role.get('_id'): + del role['_id'] logger.info(request_id + ' | Role retrieval successful') - return ResponseModel( - status_code=200, - response=role - ).dict() + return ResponseModel(status_code=200, response=role).dict() diff --git a/backend-services/services/routing_service.py b/backend-services/services/routing_service.py index 1619154..ffa571f 100644 --- a/backend-services/services/routing_service.py +++ b/backend-services/services/routing_service.py @@ -4,22 +4,23 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from pymongo.errors import DuplicateKeyError -import uuid import logging +import uuid + +from pymongo.errors import DuplicateKeyError -from models.response_model import ResponseModel from models.create_routing_model import CreateRoutingModel +from models.response_model import ResponseModel from models.update_routing_model import UpdateRoutingModel +from utils.constants import ErrorCodes, Messages from utils.database import routing_collection from utils.doorman_cache_util import doorman_cache from utils.paging_util import validate_page_params -from utils.constants import ErrorCodes, Messages logger = logging.getLogger('doorman.gateway') -class RoutingService: +class RoutingService: @staticmethod async def create_routing(data: CreateRoutingModel, request_id): """ @@ -31,11 +32,9 @@ class RoutingService: logger.error(request_id + ' | Routing creation failed with code RTG001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='RTG001', - error_message='Routing already exists' + error_message='Routing already exists', ).dict() routing_dict = data.dict() try: @@ -43,26 +42,21 @@ class RoutingService: if not insert_result.acknowledged: logger.error(request_id + ' | Routing creation failed with code RTG002') return ResponseModel( - status_code=400, - error_code='RTG002', - error_message='Unable to insert routing' + status_code=400, error_code='RTG002', error_message='Unable to insert routing' ).dict() routing_dict['_id'] = str(insert_result.inserted_id) doorman_cache.set_cache('client_routing_cache', data.client_key, routing_dict) logger.info(request_id + ' | Routing creation successful') return ResponseModel( - status_code=201, - message='Routing created successfully with key: ' + data.client_key, + status_code=201, message='Routing created successfully with key: ' + data.client_key ).dict() - except DuplicateKeyError as e: + except DuplicateKeyError: logger.error(request_id + ' | Routing creation failed with code RTG001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='RTG001', - error_message='Routing already exists' + error_message='Routing already exists', ).dict() @staticmethod @@ -75,50 +69,39 @@ class RoutingService: logger.error(request_id + ' | Role update failed with code ROLE005') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='RTG005', - error_message='Routing key cannot be changed' + error_message='Routing key cannot be changed', ).dict() routing = doorman_cache.get_cache('client_routing_cache', client_key) if not routing: - routing = routing_collection.find_one({ - 'client_key': client_key - }) + routing = routing_collection.find_one({'client_key': client_key}) if not routing: logger.error(request_id + ' | Routing update failed with code RTG004') return ResponseModel( - status_code=400, - error_code='RTG004', - error_message='Routing does not exist' + status_code=400, error_code='RTG004', error_message='Routing does not exist' ).dict() else: doorman_cache.delete_cache('client_routing_cache', client_key) not_null_data = {k: v for k, v in data.dict().items() if v is not None} if not_null_data: - update_result = routing_collection.update_one({'client_key': client_key}, {'$set': not_null_data}) + update_result = routing_collection.update_one( + {'client_key': client_key}, {'$set': not_null_data} + ) if not update_result.acknowledged or update_result.modified_count == 0: logger.error(request_id + ' | Routing update failed with code RTG006') return ResponseModel( - status_code=400, - error_code='RTG006', - error_message='Unable to update routing' + status_code=400, error_code='RTG006', error_message='Unable to update routing' ).dict() logger.info(request_id + ' | Routing update successful') - return ResponseModel( - status_code=200, - message='Routing updated successfully' - ).dict() + return ResponseModel(status_code=200, message='Routing updated successfully').dict() else: logger.error(request_id + ' | Routing update failed with code RTG007') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='RTG007', - error_message='No data to update' + error_message='No data to update', ).dict() @staticmethod @@ -129,15 +112,11 @@ class RoutingService: logger.info(request_id + ' | Deleting: ' + client_key) routing = doorman_cache.get_cache('client_routing_cache', client_key) if not routing: - routing = routing_collection.find_one({ - 'client_key': client_key - }) + routing = routing_collection.find_one({'client_key': client_key}) if not routing: logger.error(request_id + ' | Routing deletion failed with code RTG004') return ResponseModel( - status_code=400, - error_code='RTG004', - error_message='Routing does not exist' + status_code=400, error_code='RTG004', error_message='Routing does not exist' ).dict() else: doorman_cache.delete_cache('client_routing_cache', client_key) @@ -146,19 +125,15 @@ class RoutingService: logger.error(request_id + ' | Routing deletion failed with code RTG008') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='RTG008', - error_message='Unable to delete routing' + error_message='Unable to delete routing', ).dict() logger.info(request_id + ' | Routing deletion successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='Routing deleted successfully' + response_headers={'request_id': request_id}, + message='Routing deleted successfully', ).dict() @staticmethod @@ -169,44 +144,40 @@ class RoutingService: logger.info(request_id + ' | Getting: ' + client_key) routing = doorman_cache.get_cache('client_routing_cache', client_key) if not routing: - routing = routing_collection.find_one({ - 'client_key': client_key - }) + routing = routing_collection.find_one({'client_key': client_key}) if not routing: logger.error(request_id + ' | Routing retrieval failed with code RTG004') return ResponseModel( - status_code=400, - error_code='RTG004', - error_message='Routing does not exist' + status_code=400, error_code='RTG004', error_message='Routing does not exist' ).dict() logger.info(request_id + ' | Routing retrieval successful') - if routing.get('_id'): del routing['_id'] - return ResponseModel( - status_code=200, - response=routing - ).dict() + if routing.get('_id'): + del routing['_id'] + return ResponseModel(status_code=200, response=routing).dict() @staticmethod async def get_routings(page=1, page_size=10, request_id=None): """ Get all routings. """ - logger.info(request_id + ' | Getting routings: Page=' + str(page) + ' Page Size=' + str(page_size)) + logger.info( + request_id + ' | Getting routings: Page=' + str(page) + ' Page Size=' + str(page_size) + ) try: page, page_size = validate_page_params(page, page_size) except Exception as e: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING + ), ).dict() skip = (page - 1) * page_size cursor = routing_collection.find().sort('client_key', 1).skip(skip).limit(page_size) routings = cursor.to_list(length=None) for route in routings: - if route.get('_id'): del route['_id'] + if route.get('_id'): + del route['_id'] logger.info(request_id + ' | Routing retrieval successful') - return ResponseModel( - status_code=200, - response={'routings': routings} - ).dict() + return ResponseModel(status_code=200, response={'routings': routings}).dict() diff --git a/backend-services/services/subscription_service.py b/backend-services/services/subscription_service.py index 59cf735..cbdcbd5 100644 --- a/backend-services/services/subscription_service.py +++ b/backend-services/services/subscription_service.py @@ -7,15 +7,14 @@ See https://github.com/pypeople-dev/doorman for more information import logging from models.response_model import ResponseModel -from utils.database import subscriptions_collection, api_collection -from utils.cache_manager_util import cache_manager -from utils.doorman_cache_util import doorman_cache from models.subscribe_model import SubscribeModel +from utils.database import api_collection, subscriptions_collection +from utils.doorman_cache_util import doorman_cache logger = logging.getLogger('doorman.gateway') -class SubscriptionService: +class SubscriptionService: @staticmethod async def api_exists(api_name, api_version): """ @@ -40,12 +39,15 @@ class SubscriptionService: if not api.get('api_id'): # Ensure api_id exists for consistent caching import uuid as _uuid + api['api_id'] = api.get('api_id') or str(_uuid.uuid4()) except Exception: pass doorman_cache.set_cache('api_cache', path_key, api) try: - doorman_cache.set_cache('api_id_cache', f'/{api_name}/{api_version}', api.get('api_id')) + doorman_cache.set_cache( + 'api_id_cache', f'/{api_name}/{api_version}', api.get('api_id') + ) except Exception: pass if api and '_id' in api: @@ -60,16 +62,16 @@ class SubscriptionService: logger.info(f'{request_id} | Getting subscriptions for: {username}') subscriptions = doorman_cache.get_cache('user_subscription_cache', username) - if not subscriptions or not isinstance(subscriptions, dict) or not subscriptions.get('apis'): - + if ( + not subscriptions + or not isinstance(subscriptions, dict) + or not subscriptions.get('apis') + ): subscriptions = subscriptions_collection.find_one({'username': username}) if not subscriptions: logger.info(f'{request_id} | No subscriptions found; returning empty list') - return ResponseModel( - status_code=200, - response={'apis': []} - ).dict() + return ResponseModel(status_code=200, response={'apis': []}).dict() if subscriptions.get('_id'): del subscriptions['_id'] @@ -78,8 +80,7 @@ class SubscriptionService: apis = subscriptions.get('apis', []) if isinstance(subscriptions, dict) else [] logger.info(f'{request_id} | Subscriptions retrieved successfully') return ResponseModel( - status_code=200, - response={'apis': apis, 'subscriptions': {'apis': apis}} + status_code=200, response={'apis': apis, 'subscriptions': {'apis': apis}} ).dict() @staticmethod @@ -87,17 +88,17 @@ class SubscriptionService: """ Subscribe to an API. """ - logger.info(f'{request_id} | Subscribing {data.username} to API: {data.api_name}/{data.api_version}') + logger.info( + f'{request_id} | Subscribing {data.username} to API: {data.api_name}/{data.api_version}' + ) api = await SubscriptionService.api_exists(data.api_name, data.api_version) if not api: logger.error(f'{request_id} | Subscription failed with code SUB003') return ResponseModel( status_code=404, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='SUB003', - error_message='API does not exist for the requested name and version' + error_message='API does not exist for the requested name and version', ).dict() doorman_cache.delete_cache('user_subscription_cache', data.username) @@ -107,23 +108,24 @@ class SubscriptionService: if user_subscriptions is None: user_subscriptions = { 'username': data.username, - 'apis': [f'{data.api_name}/{data.api_version}'] + 'apis': [f'{data.api_name}/{data.api_version}'], } subscriptions_collection.insert_one(user_subscriptions) - elif 'apis' in user_subscriptions and f'{data.api_name}/{data.api_version}' in user_subscriptions['apis']: + elif ( + 'apis' in user_subscriptions + and f'{data.api_name}/{data.api_version}' in user_subscriptions['apis'] + ): logger.error(f'{request_id} | Subscription failed with code SUB004') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='SUB004', - error_message='User is already subscribed to the API' + error_message='User is already subscribed to the API', ).dict() else: subscriptions_collection.update_one( {'username': data.username}, - {'$push': {'apis': f'{data.api_name}/{data.api_version}'}} + {'$push': {'apis': f'{data.api_name}/{data.api_version}'}}, ) user_subscriptions = subscriptions_collection.find_one({'username': data.username}) @@ -133,8 +135,7 @@ class SubscriptionService: doorman_cache.set_cache('user_subscription_cache', data.username, user_subscriptions) logger.info(f'{request_id} | Subscription successful') return ResponseModel( - status_code=200, - response={'message': 'Successfully subscribed to the API'} + status_code=200, response={'message': 'Successfully subscribed to the API'} ).dict() @staticmethod @@ -142,36 +143,36 @@ class SubscriptionService: """ Unsubscribe from an API. """ - logger.info(f'{request_id} | Unsubscribing {data.username} from API: {data.api_name}/{data.api_version}') + logger.info( + f'{request_id} | Unsubscribing {data.username} from API: {data.api_name}/{data.api_version}' + ) api = await SubscriptionService.api_exists(data.api_name, data.api_version) if not api: return ResponseModel( status_code=404, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='SUB005', - error_message='API does not exist for the requested name and version' + error_message='API does not exist for the requested name and version', ).dict() doorman_cache.delete_cache('user_subscription_cache', data.username) user_subscriptions = subscriptions_collection.find_one({'username': data.username}) if user_subscriptions and '_id' in user_subscriptions: del user_subscriptions['_id'] - if not user_subscriptions or f'{data.api_name}/{data.api_version}' not in user_subscriptions.get('apis', []): + if ( + not user_subscriptions + or f'{data.api_name}/{data.api_version}' not in user_subscriptions.get('apis', []) + ): logger.error(f'{request_id} | Unsubscription failed with code SUB006') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='SUB006', - error_message='User is not subscribed to the API' + error_message='User is not subscribed to the API', ).dict() user_subscriptions['apis'].remove(f'{data.api_name}/{data.api_version}') subscriptions_collection.update_one( - {'username': data.username}, - {'$set': {'apis': user_subscriptions.get('apis', [])}} + {'username': data.username}, {'$set': {'apis': user_subscriptions.get('apis', [])}} ) user_subscriptions = subscriptions_collection.find_one({'username': data.username}) @@ -181,6 +182,5 @@ class SubscriptionService: doorman_cache.set_cache('user_subscription_cache', data.username, user_subscriptions) logger.info(f'{request_id} | Unsubscription successful') return ResponseModel( - status_code=200, - response={'message': 'Successfully unsubscribed from the API'} + status_code=200, response={'message': 'Successfully unsubscribed from the API'} ).dict() diff --git a/backend-services/services/tier_service.py b/backend-services/services/tier_service.py index 608a9cc..387e2db 100644 --- a/backend-services/services/tier_service.py +++ b/backend-services/services/tier_service.py @@ -6,15 +6,10 @@ Handles tier CRUD, user assignments, upgrades, downgrades, and transitions. """ import logging -from typing import Optional, List, Dict, Any from datetime import datetime, timedelta +from typing import Any -from models.rate_limit_models import ( - Tier, - TierLimits, - TierName, - UserTierAssignment -) +from models.rate_limit_models import Tier, TierLimits, TierName, UserTierAssignment logger = logging.getLogger(__name__) @@ -22,7 +17,7 @@ logger = logging.getLogger(__name__) class TierService: """ Service for managing tiers and user assignments - + Features: - CRUD operations for tiers - User-to-tier assignments @@ -31,204 +26,203 @@ class TierService: - Temporary tier assignments - Tier comparison and selection """ - + def __init__(self, db): """ Initialize tier service - + Args: db: MongoDB database instance (sync) or InMemoryDB """ self.db = db self.tiers_collection = db.tiers self.assignments_collection = db.user_tier_assignments - + # ======================================================================== # TIER CRUD OPERATIONS # ======================================================================== - + async def create_tier(self, tier: Tier) -> Tier: """ Create a new tier - + Args: tier: Tier object to create - + Returns: Created tier - + Raises: ValueError: If tier with same ID already exists """ # Check if tier already exists existing = await self.tiers_collection.find_one({'tier_id': tier.tier_id}) if existing: - raise ValueError(f"Tier with ID {tier.tier_id} already exists") - + raise ValueError(f'Tier with ID {tier.tier_id} already exists') + # Set timestamps tier.created_at = datetime.now() tier.updated_at = datetime.now() - + # Insert into database await self.tiers_collection.insert_one(tier.to_dict()) - - logger.info(f"Created tier: {tier.tier_id}") + + logger.info(f'Created tier: {tier.tier_id}') return tier - - async def get_tier(self, tier_id: str) -> Optional[Tier]: + + async def get_tier(self, tier_id: str) -> Tier | None: """ Get tier by ID - + Args: tier_id: Tier identifier - + Returns: Tier object or None if not found """ tier_data = await self.tiers_collection.find_one({'tier_id': tier_id}) - + if tier_data: return Tier.from_dict(tier_data) - + return None - - async def get_tier_by_name(self, name: TierName) -> Optional[Tier]: + + async def get_tier_by_name(self, name: TierName) -> Tier | None: """ Get tier by name - + Args: name: Tier name enum - + Returns: Tier object or None if not found """ tier_data = await self.tiers_collection.find_one({'name': name.value}) - + if tier_data: return Tier.from_dict(tier_data) - + return None - + async def list_tiers( self, enabled_only: bool = False, - search_term: Optional[str] = None, + search_term: str | None = None, skip: int = 0, - limit: int = 100 - ) -> List[Tier]: + limit: int = 100, + ) -> list[Tier]: """ List all tiers - + Args: enabled_only: Only return enabled tiers skip: Number of records to skip limit: Maximum number of records to return - + Returns: List of tiers """ query = {} if enabled_only: query['enabled'] = True - + cursor = self.tiers_collection.find(query).skip(skip).limit(limit) - tiers: List[Tier] = [] - + tiers: list[Tier] = [] + async for tier_data in cursor: tiers.append(Tier.from_dict(tier_data)) - + if search_term: term = search_term.lower() tiers = [ - tier for tier in tiers + tier + for tier in tiers if term in (tier.name or '').lower() or term in (tier.display_name or '').lower() or term in (tier.description or '').lower() ] - + return tiers - - async def update_tier(self, tier_id: str, updates: Dict[str, Any]) -> Optional[Tier]: + + async def update_tier(self, tier_id: str, updates: dict[str, Any]) -> Tier | None: """ Update tier - + Args: tier_id: Tier identifier updates: Dictionary of fields to update - + Returns: Updated tier or None if not found """ # Add updated timestamp updates['updated_at'] = datetime.now().isoformat() - + result = await self.tiers_collection.find_one_and_update( - {'tier_id': tier_id}, - {'$set': updates}, - return_document=True + {'tier_id': tier_id}, {'$set': updates}, return_document=True ) - + if result: - logger.info(f"Updated tier: {tier_id}") + logger.info(f'Updated tier: {tier_id}') return Tier.from_dict(result) - + return None - + async def delete_tier(self, tier_id: str) -> bool: """ Delete tier - + Args: tier_id: Tier identifier - + Returns: True if deleted, False if not found """ # Check if any users are assigned to this tier user_count = await self.assignments_collection.count_documents({'tier_id': tier_id}) - + if user_count > 0: - raise ValueError(f"Cannot delete tier {tier_id}: {user_count} users are assigned to it") - + raise ValueError(f'Cannot delete tier {tier_id}: {user_count} users are assigned to it') + result = await self.tiers_collection.delete_one({'tier_id': tier_id}) - + if result.deleted_count > 0: - logger.info(f"Deleted tier: {tier_id}") + logger.info(f'Deleted tier: {tier_id}') return True - + return False - - async def get_default_tier(self) -> Optional[Tier]: + + async def get_default_tier(self) -> Tier | None: """ Get the default tier - + Returns: Default tier or None if not set """ tier_data = await self.tiers_collection.find_one({'is_default': True}) - + if tier_data: return Tier.from_dict(tier_data) - + return None - + # ======================================================================== # USER TIER ASSIGNMENTS # ======================================================================== - + async def assign_user_to_tier( self, user_id: str, tier_id: str, - assigned_by: Optional[str] = None, - effective_from: Optional[datetime] = None, - effective_until: Optional[datetime] = None, - override_limits: Optional[TierLimits] = None, - notes: Optional[str] = None + assigned_by: str | None = None, + effective_from: datetime | None = None, + effective_until: datetime | None = None, + override_limits: TierLimits | None = None, + notes: str | None = None, ) -> UserTierAssignment: """ Assign user to a tier - + Args: user_id: User identifier tier_id: Tier identifier @@ -237,21 +231,21 @@ class TierService: effective_until: When assignment expires override_limits: Custom limits for this user notes: Assignment notes - + Returns: UserTierAssignment object - + Raises: ValueError: If tier doesn't exist """ # Verify tier exists tier = await self.get_tier(tier_id) if not tier: - raise ValueError(f"Tier {tier_id} not found") - + raise ValueError(f'Tier {tier_id} not found') + # Check if user already has an assignment existing = await self.assignments_collection.find_one({'user_id': user_id}) - + assignment = UserTierAssignment( user_id=user_id, tier_id=tier_id, @@ -260,257 +254,249 @@ class TierService: effective_until=effective_until, assigned_at=datetime.now(), assigned_by=assigned_by, - notes=notes + notes=notes, ) - + if existing: # Update existing assignment await self.assignments_collection.replace_one( - {'user_id': user_id}, - assignment.to_dict() + {'user_id': user_id}, assignment.to_dict() ) - logger.info(f"Updated tier assignment for user {user_id} to {tier_id}") + logger.info(f'Updated tier assignment for user {user_id} to {tier_id}') else: # Create new assignment await self.assignments_collection.insert_one(assignment.to_dict()) - logger.info(f"Assigned user {user_id} to tier {tier_id}") - + logger.info(f'Assigned user {user_id} to tier {tier_id}') + return assignment - - async def get_user_assignment(self, user_id: str) -> Optional[UserTierAssignment]: + + async def get_user_assignment(self, user_id: str) -> UserTierAssignment | None: """ Get user's tier assignment - + Args: user_id: User identifier - + Returns: UserTierAssignment or None if not assigned """ assignment_data = await self.assignments_collection.find_one({'user_id': user_id}) - + if assignment_data: return UserTierAssignment(**assignment_data) - + return None - - async def get_user_tier(self, user_id: str) -> Optional[Tier]: + + async def get_user_tier(self, user_id: str) -> Tier | None: """ Get the tier for a user - + Considers effective dates and returns appropriate tier. - + Args: user_id: User identifier - + Returns: Tier object or default tier if no assignment """ assignment = await self.get_user_assignment(user_id) - + if assignment: # Check if assignment is currently effective now = datetime.now() - + if assignment.effective_from and now < assignment.effective_from: # Assignment not yet effective return await self.get_default_tier() - + if assignment.effective_until and now > assignment.effective_until: # Assignment expired return await self.get_default_tier() - + # Get the assigned tier return await self.get_tier(assignment.tier_id) - + # No assignment, return default tier return await self.get_default_tier() - - async def get_user_limits(self, user_id: str) -> Optional[TierLimits]: + + async def get_user_limits(self, user_id: str) -> TierLimits | None: """ Get effective limits for a user - + Considers tier limits and user-specific overrides. - + Args: user_id: User identifier - + Returns: TierLimits object or None """ assignment = await self.get_user_assignment(user_id) - + if assignment and assignment.override_limits: # User has custom limits return assignment.override_limits - + # Get tier limits tier = await self.get_user_tier(user_id) if tier: return tier.limits - + return None - + async def remove_user_assignment(self, user_id: str) -> bool: """ Remove user's tier assignment - + Args: user_id: User identifier - + Returns: True if removed, False if no assignment """ result = await self.assignments_collection.delete_one({'user_id': user_id}) - + if result.deleted_count > 0: - logger.info(f"Removed tier assignment for user {user_id}") + logger.info(f'Removed tier assignment for user {user_id}') return True - + return False - + async def list_users_in_tier( - self, - tier_id: str, - skip: int = 0, - limit: int = 100 - ) -> List[UserTierAssignment]: + self, tier_id: str, skip: int = 0, limit: int = 100 + ) -> list[UserTierAssignment]: """ List all users assigned to a tier - + Args: tier_id: Tier identifier skip: Number of records to skip limit: Maximum number of records to return - + Returns: List of user assignments """ - cursor = self.assignments_collection.find( - {'tier_id': tier_id} - ).skip(skip).limit(limit) - + cursor = self.assignments_collection.find({'tier_id': tier_id}).skip(skip).limit(limit) + assignments = [] async for assignment_data in cursor: assignments.append(UserTierAssignment(**assignment_data)) - + return assignments - + # ======================================================================== # TIER UPGRADES & DOWNGRADES # ======================================================================== - + async def upgrade_user_tier( self, user_id: str, new_tier_id: str, immediate: bool = True, - scheduled_date: Optional[datetime] = None, - assigned_by: Optional[str] = None + scheduled_date: datetime | None = None, + assigned_by: str | None = None, ) -> UserTierAssignment: """ Upgrade user to a higher tier - + Args: user_id: User identifier new_tier_id: New tier identifier immediate: Apply immediately or schedule scheduled_date: When to apply (if not immediate) assigned_by: Who initiated the upgrade - + Returns: Updated UserTierAssignment """ # Get current and new tiers current_tier = await self.get_user_tier(user_id) new_tier = await self.get_tier(new_tier_id) - + if not new_tier: - raise ValueError(f"Tier {new_tier_id} not found") - + raise ValueError(f'Tier {new_tier_id} not found') + # Determine effective date effective_from = datetime.now() if immediate else scheduled_date - + # Create assignment assignment = await self.assign_user_to_tier( user_id=user_id, tier_id=new_tier_id, assigned_by=assigned_by, effective_from=effective_from, - notes=f"Upgraded from {current_tier.tier_id if current_tier else 'default'}" + notes=f'Upgraded from {current_tier.tier_id if current_tier else "default"}', ) - - logger.info(f"Upgraded user {user_id} to tier {new_tier_id}") + + logger.info(f'Upgraded user {user_id} to tier {new_tier_id}') return assignment - + async def downgrade_user_tier( self, user_id: str, new_tier_id: str, grace_period_days: int = 0, - assigned_by: Optional[str] = None + assigned_by: str | None = None, ) -> UserTierAssignment: """ Downgrade user to a lower tier - + Args: user_id: User identifier new_tier_id: New tier identifier grace_period_days: Days before downgrade takes effect assigned_by: Who initiated the downgrade - + Returns: Updated UserTierAssignment """ # Get current and new tiers current_tier = await self.get_user_tier(user_id) new_tier = await self.get_tier(new_tier_id) - + if not new_tier: - raise ValueError(f"Tier {new_tier_id} not found") - + raise ValueError(f'Tier {new_tier_id} not found') + # Calculate effective date with grace period effective_from = datetime.now() + timedelta(days=grace_period_days) - + # Create assignment assignment = await self.assign_user_to_tier( user_id=user_id, tier_id=new_tier_id, assigned_by=assigned_by, effective_from=effective_from, - notes=f"Downgraded from {current_tier.tier_id if current_tier else 'default'} with {grace_period_days} day grace period" + notes=f'Downgraded from {current_tier.tier_id if current_tier else "default"} with {grace_period_days} day grace period', + ) + + logger.info( + f'Scheduled downgrade for user {user_id} to tier {new_tier_id} on {effective_from}' ) - - logger.info(f"Scheduled downgrade for user {user_id} to tier {new_tier_id} on {effective_from}") return assignment - + async def temporary_tier_upgrade( - self, - user_id: str, - temp_tier_id: str, - duration_days: int, - assigned_by: Optional[str] = None + self, user_id: str, temp_tier_id: str, duration_days: int, assigned_by: str | None = None ) -> UserTierAssignment: """ Temporarily upgrade user to a higher tier - + Args: user_id: User identifier temp_tier_id: Temporary tier identifier duration_days: How many days the upgrade lasts assigned_by: Who initiated the upgrade - + Returns: UserTierAssignment with expiration """ temp_tier = await self.get_tier(temp_tier_id) if not temp_tier: - raise ValueError(f"Tier {temp_tier_id} not found") - + raise ValueError(f'Tier {temp_tier_id} not found') + # Set expiration effective_from = datetime.now() effective_until = effective_from + timedelta(days=duration_days) - + # Create temporary assignment assignment = await self.assign_user_to_tier( user_id=user_id, @@ -518,112 +504,112 @@ class TierService: assigned_by=assigned_by, effective_from=effective_from, effective_until=effective_until, - notes=f"Temporary upgrade for {duration_days} days" + notes=f'Temporary upgrade for {duration_days} days', + ) + + logger.info( + f'Temporary upgrade for user {user_id} to tier {temp_tier_id} until {effective_until}' ) - - logger.info(f"Temporary upgrade for user {user_id} to tier {temp_tier_id} until {effective_until}") return assignment - + # ======================================================================== # TIER COMPARISON & ANALYTICS # ======================================================================== - - async def compare_tiers(self, tier_ids: List[str]) -> List[Dict[str, Any]]: + + async def compare_tiers(self, tier_ids: list[str]) -> list[dict[str, Any]]: """ Compare multiple tiers side-by-side - + Args: tier_ids: List of tier identifiers to compare - + Returns: List of tier comparison data """ comparison = [] - + for tier_id in tier_ids: tier = await self.get_tier(tier_id) if tier: - comparison.append({ - 'tier_id': tier.tier_id, - 'name': tier.name.value, - 'display_name': tier.display_name, - 'limits': tier.limits.to_dict(), - 'price_monthly': tier.price_monthly, - 'price_yearly': tier.price_yearly, - 'features': tier.features - }) - + comparison.append( + { + 'tier_id': tier.tier_id, + 'name': tier.name.value, + 'display_name': tier.display_name, + 'limits': tier.limits.to_dict(), + 'price_monthly': tier.price_monthly, + 'price_yearly': tier.price_yearly, + 'features': tier.features, + } + ) + return comparison - - async def get_tier_statistics(self, tier_id: str) -> Dict[str, Any]: + + async def get_tier_statistics(self, tier_id: str) -> dict[str, Any]: """ Get statistics for a tier - + Args: tier_id: Tier identifier - + Returns: Dictionary with tier statistics """ # Count users in tier user_count = await self.assignments_collection.count_documents({'tier_id': tier_id}) - + # Count active assignments (within effective dates) now = datetime.now() - active_count = await self.assignments_collection.count_documents({ - 'tier_id': tier_id, - '$or': [ - {'effective_from': None}, - {'effective_from': {'$lte': now}} - ], - '$or': [ - {'effective_until': None}, - {'effective_until': {'$gte': now}} - ] - }) - + active_count = await self.assignments_collection.count_documents( + { + 'tier_id': tier_id, + '$or': [{'effective_from': None}, {'effective_from': {'$lte': now}}], + '$or': [{'effective_until': None}, {'effective_until': {'$gte': now}}], + } + ) + return { 'tier_id': tier_id, 'total_users': user_count, 'active_users': active_count, - 'inactive_users': user_count - active_count + 'inactive_users': user_count - active_count, } - - async def get_all_tier_statistics(self) -> List[Dict[str, Any]]: + + async def get_all_tier_statistics(self) -> list[dict[str, Any]]: """ Get statistics for all tiers - + Returns: List of tier statistics """ tiers = await self.list_tiers() statistics = [] - + for tier in tiers: stats = await self.get_tier_statistics(tier.tier_id) stats['tier_name'] = tier.display_name statistics.append(stats) - + return statistics # Global tier service instance -_tier_service: Optional[TierService] = None +_tier_service: TierService | None = None def get_tier_service(db) -> TierService: """ Get or create global tier service instance - + Args: db: MongoDB database instance (sync) or InMemoryDB - + Returns: TierService instance """ global _tier_service - + if _tier_service is None: _tier_service = TierService(db) - + return _tier_service diff --git a/backend-services/services/user_service.py b/backend-services/services/user_service.py index 1a5dabe..b3af7dd 100644 --- a/backend-services/services/user_service.py +++ b/backend-services/services/user_service.py @@ -4,40 +4,40 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List -from fastapi import HTTPException import logging -import asyncio +import time +from fastapi import HTTPException + +from models.create_user_model import CreateUserModel from models.response_model import ResponseModel from utils import password_util -from utils.database_async import user_collection, subscriptions_collection, api_collection -from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list -from utils.doorman_cache_util import doorman_cache -from models.create_user_model import CreateUserModel -from utils.paging_util import validate_page_params -from utils.constants import ErrorCodes, Messages -from utils.role_util import platform_role_required_bool +from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one from utils.bandwidth_util import get_current_usage -import time +from utils.constants import ErrorCodes, Messages +from utils.database_async import api_collection, subscriptions_collection, user_collection +from utils.doorman_cache_util import doorman_cache +from utils.paging_util import validate_page_params +from utils.role_util import platform_role_required_bool logger = logging.getLogger('doorman.gateway') -class UserService: +class UserService: @staticmethod - async def get_user_by_email_with_password_helper(email): + async def get_user_by_email_with_password_helper(email: str) -> dict: """ Retrieve a user by email. """ user = await db_find_one(user_collection, {'email': email}) - if user.get('_id'): del user['_id'] + if user.get('_id'): + del user['_id'] if not user: raise HTTPException(status_code=404, detail='User not found') return user @staticmethod - async def get_user_by_username_helper(username): + async def get_user_by_username_helper(username: str) -> dict: """ Retrieve a user by username. """ @@ -47,17 +47,19 @@ class UserService: user = await db_find_one(user_collection, {'username': username}) if not user: raise HTTPException(status_code=404, detail='User not found') - if user.get('_id'): del user['_id'] - if user.get('password'): del user['password'] + if user.get('_id'): + del user['_id'] + if user.get('password'): + del user['password'] doorman_cache.set_cache('user_cache', username, user) if not user: raise HTTPException(status_code=404, detail='User not found') return user - except Exception as e: + except Exception: raise HTTPException(status_code=404, detail='User not found') @staticmethod - async def get_user_by_username(username, request_id): + async def get_user_by_username(username: str, request_id: str) -> dict: """ Retrieve a user by username. """ @@ -68,22 +70,20 @@ class UserService: if not user: logger.error(f'{request_id} | User retrieval failed with code USR002') return ResponseModel( - status_code=404, - error_code='USR002', - error_message='User not found' + status_code=404, error_code='USR002', error_message='User not found' ).dict() - if user.get('_id'): del user['_id'] - if user.get('password'): del user['password'] + if user.get('_id'): + del user['_id'] + if user.get('password'): + del user['password'] doorman_cache.set_cache('user_cache', username, user) if not user: logger.error(f'{request_id} | User retrieval failed with code USR002') return ResponseModel( status_code=404, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR002', - error_message='User not found' + error_message='User not found', ).dict() try: limit = user.get('bandwidth_limit_bytes') @@ -108,13 +108,10 @@ class UserService: except Exception: pass logger.info(f'{request_id} | User retrieval successful') - return ResponseModel( - status_code=200, - response=user - ).dict() + return ResponseModel(status_code=200, response=user).dict() @staticmethod - async def get_user_by_email(active_username, email, request_id): + async def get_user_by_email(active_username: str, email: str, request_id: str) -> dict: """ Retrieve a user by email. """ @@ -128,81 +125,72 @@ class UserService: logger.error(f'{request_id} | User retrieval failed with code USR002') return ResponseModel( status_code=404, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR002', - error_message='User not found' + error_message='User not found', ).dict() logger.info(f'{request_id} | User retrieval successful') - if not active_username == user.get('username') and not await platform_role_required_bool(active_username, 'manage_users'): + if not active_username == user.get('username') and not await platform_role_required_bool( + active_username, 'manage_users' + ): logger.error(f'{request_id} | User retrieval failed with code USR008') return ResponseModel( - status_code=403, - error_code='USR008', - error_message='Unable to retrieve information for user', - ).dict() - return ResponseModel( - status_code=200, - response=user - ).dict() + status_code=403, + error_code='USR008', + error_message='Unable to retrieve information for user', + ).dict() + return ResponseModel(status_code=200, response=user).dict() @staticmethod - async def create_user(data: CreateUserModel, request_id): + async def create_user(data: CreateUserModel, request_id: str) -> dict: """ Create a new user. """ logger.info(f'{request_id} | Creating user: {data.username}') try: if data.custom_attributes is not None and len(data.custom_attributes.keys()) > 10: - logger.error(f"{request_id} | User creation failed with code USR016: Too many custom attributes") + logger.error( + f'{request_id} | User creation failed with code USR016: Too many custom attributes' + ) return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR016', - error_message='Maximum 10 custom attributes allowed. Please replace an existing one.' + error_message='Maximum 10 custom attributes allowed. Please replace an existing one.', ).dict() except Exception: - logger.error(f"{request_id} | User creation failed with code USR016: Invalid custom attributes payload") + logger.error( + f'{request_id} | User creation failed with code USR016: Invalid custom attributes payload' + ) return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR016', - error_message='Maximum 10 custom attributes allowed. Please replace an existing one.' + error_message='Maximum 10 custom attributes allowed. Please replace an existing one.', ).dict() if await db_find_one(user_collection, {'username': data.username}): logger.error(f'{request_id} | User creation failed with code USR001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR001', - error_message='Username already exists' + error_message='Username already exists', ).dict() if await db_find_one(user_collection, {'email': data.email}): logger.error(f'{request_id} | User creation failed with code USR001') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR001', - error_message='Email already exists' + error_message='Email already exists', ).dict() if not password_util.is_secure_password(data.password): logger.error(f'{request_id} | User creation failed with code USR005') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR005', - error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character' + error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character', ).dict() data.password = password_util.hash_password(data.password) data_dict = data.dict() @@ -215,14 +203,12 @@ class UserService: logger.info(f'{request_id} | User creation successful') return ResponseModel( status_code=201, - response_headers={ - 'request_id': request_id - }, - message='User created successfully' + response_headers={'request_id': request_id}, + message='User created successfully', ).dict() @staticmethod - async def check_password_return_user(email, password): + async def check_password_return_user(email: str, password: str) -> dict: """ Verify password and return user if valid. """ @@ -230,7 +216,6 @@ class UserService: try: user = await UserService.get_user_by_email_with_password_helper(email) except Exception: - maybe_user = await db_find_one(user_collection, {'username': email}) if maybe_user: user = maybe_user @@ -239,11 +224,11 @@ class UserService: if not password_util.verify_password(password, user.get('password')): raise HTTPException(status_code=400, detail='Invalid email or password') return user - except Exception as e: + except Exception: raise HTTPException(status_code=400, detail='Invalid email or password') @staticmethod - async def update_user(username, update_data, request_id): + async def update_user(username: str, update_data: dict, request_id: str) -> dict: """ Update user information. """ @@ -254,58 +239,63 @@ class UserService: if not user: logger.error(f'{request_id} | User update failed with code USR002') return ResponseModel( - status_code=404, - error_code='USR002', - error_message='User not found' + status_code=404, error_code='USR002', error_message='User not found' ).dict() else: doorman_cache.delete_cache('user_cache', username) non_null_update_data = {k: v for k, v in update_data.dict().items() if v is not None} if 'custom_attributes' in non_null_update_data: try: - if non_null_update_data['custom_attributes'] is not None and len(non_null_update_data['custom_attributes'].keys()) > 10: - logger.error(f"{request_id} | User update failed with code USR016: Too many custom attributes") + if ( + non_null_update_data['custom_attributes'] is not None + and len(non_null_update_data['custom_attributes'].keys()) > 10 + ): + logger.error( + f'{request_id} | User update failed with code USR016: Too many custom attributes' + ) return ResponseModel( status_code=400, error_code='USR016', - error_message='Maximum 10 custom attributes allowed. Please replace an existing one.' + error_message='Maximum 10 custom attributes allowed. Please replace an existing one.', ).dict() except Exception: - logger.error(f"{request_id} | User update failed with code USR016: Invalid custom attributes payload") + logger.error( + f'{request_id} | User update failed with code USR016: Invalid custom attributes payload' + ) return ResponseModel( status_code=400, error_code='USR016', - error_message='Maximum 10 custom attributes allowed. Please replace an existing one.' + error_message='Maximum 10 custom attributes allowed. Please replace an existing one.', ).dict() if non_null_update_data: try: - update_result = await db_update_one(user_collection, {'username': username}, {'$set': non_null_update_data}) + update_result = await db_update_one( + user_collection, {'username': username}, {'$set': non_null_update_data} + ) if update_result.modified_count > 0: doorman_cache.delete_cache('user_cache', username) if not update_result.acknowledged or update_result.modified_count == 0: logger.error(f'{request_id} | User update failed with code USR003') return ResponseModel( - status_code=400, - error_code='USR004', - error_message='Unable to update user' + status_code=400, error_code='USR004', error_message='Unable to update user' ).dict() except Exception as e: doorman_cache.delete_cache('user_cache', username) - logger.error(f'{request_id} | User update failed with exception: {str(e)}', exc_info=True) + logger.error( + f'{request_id} | User update failed with exception: {str(e)}', exc_info=True + ) raise if non_null_update_data.get('role'): await UserService.purge_apis_after_role_change(username, request_id) logger.info(f'{request_id} | User update successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='User updated successfully' + response_headers={'request_id': request_id}, + message='User updated successfully', ).dict() @staticmethod - async def delete_user(username, request_id): + async def delete_user(username: str, request_id: str) -> dict: """ Delete a user. """ @@ -316,34 +306,28 @@ class UserService: if not user: logger.error(f'{request_id} | User deletion failed with code USR002') return ResponseModel( - status_code=404, - error_code='USR002', - error_message='User not found' + status_code=404, error_code='USR002', error_message='User not found' ).dict() delete_result = await db_delete_one(user_collection, {'username': username}) if not delete_result.acknowledged or delete_result.deleted_count == 0: logger.error(f'{request_id} | User deletion failed with code USR003') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR003', - error_message='Unable to delete user' + error_message='Unable to delete user', ).dict() doorman_cache.delete_cache('user_cache', username) doorman_cache.delete_cache('user_subscription_cache', username) logger.info(f'{request_id} | User deletion successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='User deleted successfully' + response_headers={'request_id': request_id}, + message='User deleted successfully', ).dict() @staticmethod - async def update_password(username, update_data, request_id): + async def update_password(username: str, update_data: dict, request_id: str) -> dict: """ Update user information. """ @@ -352,31 +336,32 @@ class UserService: logger.error(f'{request_id} | User password update failed with code USR005') return ResponseModel( status_code=400, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR005', - error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character' + error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character', ).dict() hashed_password = password_util.hash_password(update_data.new_password) try: - update_result = await db_update_one(user_collection, {'username': username}, {'$set': {'password': hashed_password}}) + update_result = await db_update_one( + user_collection, {'username': username}, {'$set': {'password': hashed_password}} + ) if update_result.modified_count > 0: doorman_cache.delete_cache('user_cache', username) except Exception as e: doorman_cache.delete_cache('user_cache', username) - logger.error(f'{request_id} | User password update failed with exception: {str(e)}', exc_info=True) + logger.error( + f'{request_id} | User password update failed with exception: {str(e)}', + exc_info=True, + ) raise user = await db_find_one(user_collection, {'username': username}) if not user: logger.error(f'{request_id} | User password update failed with code USR002') return ResponseModel( status_code=404, - response_headers={ - 'request_id': request_id - }, + response_headers={'request_id': request_id}, error_code='USR002', - error_message='User not found' + error_message='User not found', ).dict() if '_id' in user: del user['_id'] @@ -386,42 +371,52 @@ class UserService: logger.info(f'{request_id} | User password update successful') return ResponseModel( status_code=200, - response_headers={ - 'request_id': request_id - }, - message='User updated successfully' + response_headers={'request_id': request_id}, + message='User updated successfully', ).dict() @staticmethod - async def purge_apis_after_role_change(username, request_id): + async def purge_apis_after_role_change(username: str, request_id: str) -> None: """ Remove subscriptions after role change. """ logger.info(f'{request_id} | Purging APIs for user: {username}') - user_subscriptions = doorman_cache.get_cache('user_subscription_cache', username) or await db_find_one(subscriptions_collection, {'username': username}) + user_subscriptions = doorman_cache.get_cache( + 'user_subscription_cache', username + ) or await db_find_one(subscriptions_collection, {'username': username}) if user_subscriptions: for subscription in user_subscriptions.get('apis'): api_name, api_version = subscription.split('/') - user = doorman_cache.get_cache('user_cache', username) or await db_find_one(user_collection, {'username': username}) - api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}') or await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version}) + user = doorman_cache.get_cache('user_cache', username) or await db_find_one( + user_collection, {'username': username} + ) + api = doorman_cache.get_cache( + 'api_cache', f'{api_name}/{api_version}' + ) or await db_find_one( + api_collection, {'api_name': api_name, 'api_version': api_version} + ) if api and api.get('role') and user.get('role') not in api.get('role'): user_subscriptions['apis'].remove(subscription) try: - update_result = await db_update_one(subscriptions_collection, + update_result = await db_update_one( + subscriptions_collection, {'username': username}, - {'$set': {'apis': user_subscriptions.get('apis', [])}} + {'$set': {'apis': user_subscriptions.get('apis', [])}}, ) if update_result.modified_count > 0: doorman_cache.delete_cache('user_subscription_cache', username) doorman_cache.set_cache('user_subscription_cache', username, user_subscriptions) except Exception as e: doorman_cache.delete_cache('user_subscription_cache', username) - logger.error(f'{request_id} | Subscription update failed with exception: {str(e)}', exc_info=True) + logger.error( + f'{request_id} | Subscription update failed with exception: {str(e)}', + exc_info=True, + ) raise logger.info(f'{request_id} | Purge successful') @staticmethod - async def get_all_users(page, page_size, request_id): + async def get_all_users(page: int, page_size: int, request_id: str) -> dict: """ Get all users. """ @@ -432,20 +427,21 @@ class UserService: return ResponseModel( status_code=400, error_code=ErrorCodes.PAGE_SIZE, - error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING) + error_message=( + Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING + ), ).dict() skip = (page - 1) * page_size users_all = await db_find_list(user_collection, {}) users_all.sort(key=lambda u: u.get('username')) - users = users_all[skip: skip + page_size] + users = users_all[skip : skip + page_size] for user in users: - if user.get('_id'): del user['_id'] - if user.get('password'): del user['password'] + if user.get('_id'): + del user['_id'] + if user.get('password'): + del user['password'] for key, value in user.items(): if isinstance(value, bytes): user[key] = value.decode('utf-8') logger.info(f'{request_id} | User retrieval successful') - return ResponseModel( - status_code=200, - response={'users': users} - ).dict() + return ResponseModel(status_code=200, response={'users': users}).dict() diff --git a/backend-services/services/vault_service.py b/backend-services/services/vault_service.py index 9abdc11..cc81430 100644 --- a/backend-services/services/vault_service.py +++ b/backend-services/services/vault_service.py @@ -4,25 +4,21 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import List, Optional -from fastapi import HTTPException import logging from datetime import datetime try: from datetime import UTC except Exception: - from datetime import timezone as _timezone - UTC = _timezone.utc + UTC = UTC -from models.response_model import ResponseModel from models.create_vault_entry_model import CreateVaultEntryModel +from models.response_model import ResponseModel from models.update_vault_entry_model import UpdateVaultEntryModel -from models.vault_entry_model_response import VaultEntryModelResponse -from utils.database_async import vault_entries_collection, user_collection -from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list -from utils.vault_encryption_util import encrypt_vault_value, is_vault_configured +from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one from utils.constants import ErrorCodes, Messages +from utils.database_async import user_collection, vault_entries_collection +from utils.vault_encryption_util import encrypt_vault_value, is_vault_configured logger = logging.getLogger('doorman.gateway') @@ -31,61 +27,62 @@ class VaultService: """Service for managing encrypted vault entries.""" @staticmethod - async def create_vault_entry(username: str, entry_data: CreateVaultEntryModel, request_id: str) -> dict: + async def create_vault_entry( + username: str, entry_data: CreateVaultEntryModel, request_id: str + ) -> dict: """ Create a new vault entry for a user. - + Args: username: Username of the vault entry owner entry_data: Vault entry creation data request_id: Request ID for logging - + Returns: ResponseModel dict with success or error """ - logger.info(f'{request_id} | Creating vault entry: {entry_data.key_name} for user: {username}') - + logger.info( + f'{request_id} | Creating vault entry: {entry_data.key_name} for user: {username}' + ) + # Check if VAULT_KEY is configured if not is_vault_configured(): logger.error(f'{request_id} | VAULT_KEY not configured') return ResponseModel( status_code=500, error_code='VAULT001', - error_message='Vault encryption is not configured. Set VAULT_KEY in environment variables.' + error_message='Vault encryption is not configured. Set VAULT_KEY in environment variables.', ).dict() - + # Get user to retrieve email user = await db_find_one(user_collection, {'username': username}) if not user: logger.error(f'{request_id} | User not found: {username}') return ResponseModel( - status_code=404, - error_code='VAULT002', - error_message='User not found' + status_code=404, error_code='VAULT002', error_message='User not found' ).dict() - + email = user.get('email') if not email: logger.error(f'{request_id} | User email not found: {username}') return ResponseModel( status_code=400, error_code='VAULT003', - error_message='User email is required for vault encryption' + error_message='User email is required for vault encryption', ).dict() - + # Check if entry already exists existing = await db_find_one( - vault_entries_collection, - {'username': username, 'key_name': entry_data.key_name} + vault_entries_collection, {'username': username, 'key_name': entry_data.key_name} ) if existing: logger.error(f'{request_id} | Vault entry already exists: {entry_data.key_name}') return ResponseModel( status_code=409, error_code='VAULT004', - error_message=f'Vault entry with key_name "{entry_data.key_name}" already exists' + error_message=f'Vault entry with key_name "{entry_data.key_name}" already exists', ).dict() - + # Encrypt the value try: encrypted_value = encrypt_vault_value(entry_data.value, email, username) @@ -94,9 +91,9 @@ class VaultService: return ResponseModel( status_code=500, error_code='VAULT005', - error_message='Failed to encrypt vault value' + error_message='Failed to encrypt vault value', ).dict() - + # Create vault entry now = datetime.now(UTC).isoformat() vault_entry = { @@ -105,89 +102,83 @@ class VaultService: 'encrypted_value': encrypted_value, 'description': entry_data.description, 'created_at': now, - 'updated_at': now + 'updated_at': now, } - + try: result = await db_insert_one(vault_entries_collection, vault_entry) if result.acknowledged: - logger.info(f'{request_id} | Vault entry created successfully: {entry_data.key_name}') + logger.info( + f'{request_id} | Vault entry created successfully: {entry_data.key_name}' + ) return ResponseModel( status_code=201, message='Vault entry created successfully', - data={'key_name': entry_data.key_name} + data={'key_name': entry_data.key_name}, ).dict() else: logger.error(f'{request_id} | Failed to create vault entry') return ResponseModel( status_code=500, error_code='VAULT006', - error_message='Failed to create vault entry' + error_message='Failed to create vault entry', ).dict() except Exception as e: logger.error(f'{request_id} | Error creating vault entry: {str(e)}') return ResponseModel( - status_code=500, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED + status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED ).dict() @staticmethod async def get_vault_entry(username: str, key_name: str, request_id: str) -> dict: """ Get a vault entry by key name. Value is NOT returned. - + Args: username: Username of the vault entry owner key_name: Name of the vault key request_id: Request ID for logging - + Returns: ResponseModel dict with vault entry (without value) or error """ logger.info(f'{request_id} | Getting vault entry: {key_name} for user: {username}') - + entry = await db_find_one( - vault_entries_collection, - {'username': username, 'key_name': key_name} + vault_entries_collection, {'username': username, 'key_name': key_name} ) - + if not entry: logger.error(f'{request_id} | Vault entry not found: {key_name}') return ResponseModel( - status_code=404, - error_code='VAULT007', - error_message='Vault entry not found' + status_code=404, error_code='VAULT007', error_message='Vault entry not found' ).dict() - + # Remove sensitive data if entry.get('_id'): del entry['_id'] if entry.get('encrypted_value'): del entry['encrypted_value'] - - return ResponseModel( - status_code=200, - data=entry - ).dict() + + return ResponseModel(status_code=200, data=entry).dict() @staticmethod async def list_vault_entries(username: str, request_id: str) -> dict: """ List all vault entries for a user. Values are NOT returned. - + Args: username: Username of the vault entry owner request_id: Request ID for logging - + Returns: ResponseModel dict with list of vault entries (without values) """ logger.info(f'{request_id} | Listing vault entries for user: {username}') - + try: entries = await db_find_list(vault_entries_collection, {'username': username}) - + # Remove sensitive data from all entries clean_entries = [] for entry in entries: @@ -196,137 +187,120 @@ class VaultService: if entry.get('encrypted_value'): del entry['encrypted_value'] clean_entries.append(entry) - + return ResponseModel( - status_code=200, - data={'entries': clean_entries, 'count': len(clean_entries)} + status_code=200, data={'entries': clean_entries, 'count': len(clean_entries)} ).dict() except Exception as e: logger.error(f'{request_id} | Error listing vault entries: {str(e)}') return ResponseModel( - status_code=500, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED + status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED ).dict() @staticmethod - async def update_vault_entry(username: str, key_name: str, update_data: UpdateVaultEntryModel, request_id: str) -> dict: + async def update_vault_entry( + username: str, key_name: str, update_data: UpdateVaultEntryModel, request_id: str + ) -> dict: """ Update a vault entry. Only description can be updated, not the value. - + Args: username: Username of the vault entry owner key_name: Name of the vault key update_data: Update data (description only) request_id: Request ID for logging - + Returns: ResponseModel dict with success or error """ logger.info(f'{request_id} | Updating vault entry: {key_name} for user: {username}') - + # Check if entry exists entry = await db_find_one( - vault_entries_collection, - {'username': username, 'key_name': key_name} + vault_entries_collection, {'username': username, 'key_name': key_name} ) - + if not entry: logger.error(f'{request_id} | Vault entry not found: {key_name}') return ResponseModel( - status_code=404, - error_code='VAULT007', - error_message='Vault entry not found' + status_code=404, error_code='VAULT007', error_message='Vault entry not found' ).dict() - + # Update only description now = datetime.now(UTC).isoformat() - update_fields = { - 'updated_at': now - } - + update_fields = {'updated_at': now} + if update_data.description is not None: update_fields['description'] = update_data.description - + try: result = await db_update_one( vault_entries_collection, {'username': username, 'key_name': key_name}, - {'$set': update_fields} + {'$set': update_fields}, ) - + if result.modified_count > 0: logger.info(f'{request_id} | Vault entry updated successfully: {key_name}') return ResponseModel( - status_code=200, - message='Vault entry updated successfully' + status_code=200, message='Vault entry updated successfully' ).dict() else: logger.warning(f'{request_id} | No changes made to vault entry: {key_name}') return ResponseModel( - status_code=200, - message='No changes made to vault entry' + status_code=200, message='No changes made to vault entry' ).dict() except Exception as e: logger.error(f'{request_id} | Error updating vault entry: {str(e)}') return ResponseModel( - status_code=500, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED + status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED ).dict() @staticmethod async def delete_vault_entry(username: str, key_name: str, request_id: str) -> dict: """ Delete a vault entry. - + Args: username: Username of the vault entry owner key_name: Name of the vault key request_id: Request ID for logging - + Returns: ResponseModel dict with success or error """ logger.info(f'{request_id} | Deleting vault entry: {key_name} for user: {username}') - + # Check if entry exists entry = await db_find_one( - vault_entries_collection, - {'username': username, 'key_name': key_name} + vault_entries_collection, {'username': username, 'key_name': key_name} ) - + if not entry: logger.error(f'{request_id} | Vault entry not found: {key_name}') return ResponseModel( - status_code=404, - error_code='VAULT007', - error_message='Vault entry not found' + status_code=404, error_code='VAULT007', error_message='Vault entry not found' ).dict() - + try: result = await db_delete_one( - vault_entries_collection, - {'username': username, 'key_name': key_name} + vault_entries_collection, {'username': username, 'key_name': key_name} ) - + if result.deleted_count > 0: logger.info(f'{request_id} | Vault entry deleted successfully: {key_name}') return ResponseModel( - status_code=200, - message='Vault entry deleted successfully' + status_code=200, message='Vault entry deleted successfully' ).dict() else: logger.error(f'{request_id} | Failed to delete vault entry: {key_name}') return ResponseModel( status_code=500, error_code='VAULT008', - error_message='Failed to delete vault entry' + error_message='Failed to delete vault entry', ).dict() except Exception as e: logger.error(f'{request_id} | Error deleting vault entry: {str(e)}') return ResponseModel( - status_code=500, - error_code=ErrorCodes.UNEXPECTED, - error_message=Messages.UNEXPECTED + status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED ).dict() diff --git a/backend-services/tests/conftest.py b/backend-services/tests/conftest.py index 2747bdf..77475ad 100644 --- a/backend-services/tests/conftest.py +++ b/backend-services/tests/conftest.py @@ -22,6 +22,7 @@ os.environ.setdefault('DOORMAN_TEST_MODE', 'true') try: import sys as _sys + if _sys.version_info >= (3, 13): os.environ.setdefault('DISABLE_PLATFORM_CHUNKED_WRAP', 'true') os.environ.setdefault('DISABLE_PLATFORM_CORS_ASGI', 'true') @@ -34,19 +35,23 @@ _PROJECT_ROOT = os.path.abspath(os.path.join(_HERE, os.pardir)) if _PROJECT_ROOT not in sys.path: sys.path.insert(0, _PROJECT_ROOT) +import asyncio +import datetime as _dt + +import pytest import pytest_asyncio from httpx import AsyncClient -import pytest -import asyncio -from typing import Optional -import datetime as _dt try: from utils.database import database as _db - _INITIAL_DB_SNAPSHOT: Optional[dict] = _db.db.dump_data() if getattr(_db, 'memory_only', True) else None + + _INITIAL_DB_SNAPSHOT: dict | None = ( + _db.db.dump_data() if getattr(_db, 'memory_only', True) else None + ) except Exception: _INITIAL_DB_SNAPSHOT = None + @pytest_asyncio.fixture(autouse=True) async def ensure_memory_dump_defaults(monkeypatch, tmp_path): """Ensure sane defaults for memory dump/restore tests. @@ -58,11 +63,15 @@ async def ensure_memory_dump_defaults(monkeypatch, tmp_path): """ try: monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') - monkeypatch.setenv('MEM_ENCRYPTION_KEY', os.environ.get('MEM_ENCRYPTION_KEY') or 'test-encryption-key-32-characters-min') + monkeypatch.setenv( + 'MEM_ENCRYPTION_KEY', + os.environ.get('MEM_ENCRYPTION_KEY') or 'test-encryption-key-32-characters-min', + ) dump_base = tmp_path / 'mem' / 'memory_dump.bin' monkeypatch.setenv('MEM_DUMP_PATH', str(dump_base)) try: import utils.memory_dump_util as md + md.DEFAULT_DUMP_PATH = str(dump_base) except Exception: pass @@ -70,20 +79,22 @@ async def ensure_memory_dump_defaults(monkeypatch, tmp_path): pass yield + @pytest.fixture(autouse=True) def _log_test_start_end(request): try: ts = _dt.datetime.now().strftime('%H:%M:%S.%f')[:-3] - print(f"=== [{ts}] START {request.node.nodeid}", flush=True) + print(f'=== [{ts}] START {request.node.nodeid}', flush=True) except Exception: pass yield try: ts = _dt.datetime.now().strftime('%H:%M:%S.%f')[:-3] - print(f"=== [{ts}] END {request.node.nodeid}", flush=True) + print(f'=== [{ts}] END {request.node.nodeid}', flush=True) except Exception: pass + @pytest.fixture(autouse=True, scope='session') def _log_env_toggles(): try: @@ -92,29 +103,43 @@ def _log_env_toggles(): 'DISABLE_PLATFORM_CORS_ASGI': os.getenv('DISABLE_PLATFORM_CORS_ASGI'), 'DISABLE_BODY_SIZE_LIMIT': os.getenv('DISABLE_BODY_SIZE_LIMIT'), 'DOORMAN_TEST_MODE': os.getenv('DOORMAN_TEST_MODE'), - 'PYTHON_VERSION': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + 'PYTHON_VERSION': f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}', } - print(f"=== ENV TOGGLES: {toggles}", flush=True) + print(f'=== ENV TOGGLES: {toggles}', flush=True) except Exception: pass yield + @pytest_asyncio.fixture async def authed_client(): + # Ensure chaos is disabled so login/cache work reliably across tests + try: + from utils import chaos_util as _cu + _cu.enable('redis', False) + _cu.enable('mongo', False) + except Exception: + pass from doorman import doorman + client = AsyncClient(app=doorman, base_url='http://testserver') r = await client.post( '/platform/authorization', - json={'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD')}, + json={ + 'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), + 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD'), + }, ) assert r.status_code == 200, r.text try: has_cookie = any(c.name == 'access_token_cookie' for c in client.cookies.jar) if not has_cookie: - body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + body = ( + r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + ) token = body.get('access_token') if token: client.cookies.set( @@ -126,43 +151,52 @@ async def authed_client(): except Exception: pass try: - await client.put('/platform/user/admin', json={ - 'bandwidth_limit_bytes': 0, - 'bandwidth_limit_window': 'day', - 'rate_limit_duration': 1000000, - 'rate_limit_duration_type': 'second', - 'throttle_duration': 1000000, - 'throttle_duration_type': 'second', - 'throttle_queue_limit': 1000000, - 'throttle_wait_duration': 0, - 'throttle_wait_duration_type': 'second' - }) + await client.put( + '/platform/user/admin', + json={ + 'bandwidth_limit_bytes': 0, + 'bandwidth_limit_window': 'day', + 'rate_limit_duration': 1000000, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 1000000, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 1000000, + 'throttle_wait_duration': 0, + 'throttle_wait_duration_type': 'second', + }, + ) except Exception: pass return client + @pytest.fixture def client(): from doorman import doorman + return AsyncClient(app=doorman, base_url='http://testserver') + @pytest.fixture def event_loop(): loop = asyncio.new_event_loop() yield loop loop.close() + @pytest_asyncio.fixture(autouse=True) async def reset_http_client(): """Reset the pooled httpx client between tests to prevent connection pool exhaustion.""" try: from services.gateway_service import GatewayService + await GatewayService.aclose_http_client() except Exception: pass try: from utils.limit_throttle_util import reset_counters + reset_counters() except Exception: pass @@ -170,10 +204,12 @@ async def reset_http_client(): yield try: from services.gateway_service import GatewayService + await GatewayService.aclose_http_client() except Exception: pass + @pytest_asyncio.fixture(autouse=True, scope='module') async def reset_in_memory_db_state(): """Restore in-memory DB and caches before each test to ensure isolation. @@ -184,23 +220,29 @@ async def reset_in_memory_db_state(): try: if _INITIAL_DB_SNAPSHOT is not None: from utils.database import database as _db + _db.db.load_data(_INITIAL_DB_SNAPSHOT) try: - from utils.database import user_collection from utils import password_util as _pw + from utils.database import user_collection + pwd = os.environ.get('DOORMAN_ADMIN_PASSWORD') or 'test-only-password-12chars' - user_collection.update_one({'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}}) + user_collection.update_one( + {'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}} + ) except Exception: pass except Exception: pass try: from utils.doorman_cache_util import doorman_cache + doorman_cache.clear_all() except Exception: pass yield + async def create_api(client: AsyncClient, api_name: str, api_version: str): payload = { 'api_name': api_name, @@ -216,7 +258,10 @@ async def create_api(client: AsyncClient, api_name: str, api_version: str): assert r.status_code in (200, 201), r.text return r -async def create_endpoint(client: AsyncClient, api_name: str, api_version: str, method: str, uri: str): + +async def create_endpoint( + client: AsyncClient, api_name: str, api_version: str, method: str, uri: str +): payload = { 'api_name': api_name, 'api_version': api_version, @@ -228,10 +273,10 @@ async def create_endpoint(client: AsyncClient, api_name: str, api_version: str, assert r.status_code in (200, 201), r.text return r -async def subscribe_self(client: AsyncClient, api_name: str, api_version: str): +async def subscribe_self(client: AsyncClient, api_name: str, api_version: str): r_me = await client.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' r = await client.post( '/platform/subscription/subscribe', json={'username': username, 'api_name': api_name, 'api_version': api_version}, diff --git a/backend-services/tests/sitecustomize.py b/backend-services/tests/sitecustomize.py index 42d4cf2..133e21c 100644 --- a/backend-services/tests/sitecustomize.py +++ b/backend-services/tests/sitecustomize.py @@ -2,6 +2,7 @@ import logging import re import sys + class _RedactFilter(logging.Filter): PATTERNS = [ re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'), @@ -17,13 +18,21 @@ class _RedactFilter(logging.Filter): msg = str(record.getMessage()) red = msg for pat in self.PATTERNS: - red = pat.sub(lambda m: (m.group(1) + '[REDACTED]' + (m.group(3) if m.lastindex and m.lastindex >= 3 else '')), red) + red = pat.sub( + lambda m: ( + m.group(1) + + '[REDACTED]' + + (m.group(3) if m.lastindex and m.lastindex >= 3 else '') + ), + red, + ) if red != msg: record.msg = red except Exception: pass return True + def _ensure_logger(name: str): logger = logging.getLogger(name) for h in logger.handlers: @@ -34,9 +43,9 @@ def _ensure_logger(name: str): h.addFilter(_RedactFilter()) logger.addHandler(h) + try: _ensure_logger('doorman.gateway') _ensure_logger('doorman.logging') except Exception: pass - diff --git a/backend-services/tests/test_access_control.py b/backend-services/tests/test_access_control.py index 816c98f..47baa1e 100644 --- a/backend-services/tests/test_access_control.py +++ b/backend-services/tests/test_access_control.py @@ -1,22 +1,24 @@ -import os import pytest import pytest_asyncio from httpx import AsyncClient + @pytest_asyncio.fixture -async def login_client() : +async def login_client(): async def _login(username: str, password: str, email: str = None) -> AsyncClient: from doorman import doorman + client = AsyncClient(app=doorman, base_url='http://testserver') cred = {'email': email or f'{username}@example.com', 'password': password} r = await client.post('/platform/authorization', json=cred) assert r.status_code == 200, r.text return client + return _login + @pytest.mark.asyncio async def test_roles_and_permissions_enforced(authed_client, login_client): - cr = await authed_client.post( '/platform/role', json={ @@ -78,9 +80,9 @@ async def test_roles_and_permissions_enforced(authed_client, login_client): await viewer_client.aclose() + @pytest.mark.asyncio async def test_group_and_subscription_enforcement(login_client, authed_client, monkeypatch): - c = await authed_client.post( '/platform/api', json={ @@ -140,16 +142,23 @@ async def test_group_and_subscription_enforcement(login_client, authed_client, m assert s2.status_code in (401, 403) import services.gateway_service as gs + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None): self.status_code = status_code self._json_body = json_body or {'ok': True} self.headers = {'Content-Type': 'application/json'} + def json(self): return self._json_body + class _FakeAsyncClient: - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -167,13 +176,25 @@ async def test_group_and_subscription_enforcement(login_client, authed_client, m return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) - async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200) - async def post(self, url, **kwargs): return _FakeHTTPResponse(200) - async def put(self, url, **kwargs): return _FakeHTTPResponse(200) - async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) + + async def get(self, url, params=None, headers=None, **kwargs): + return _FakeHTTPResponse(200) + + async def post(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def put(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse(200) + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) import routes.gateway_routes as gr - async def _no_limit(req): return None + + async def _no_limit(req): + return None + monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit) viewer1_client = await login_client('viewer1', 'StrongViewerPwd!1234') diff --git a/backend-services/tests/test_admin_bootstrap_parity.py b/backend-services/tests/test_admin_bootstrap_parity.py index d49a04f..5355949 100644 --- a/backend-services/tests/test_admin_bootstrap_parity.py +++ b/backend-services/tests/test_admin_bootstrap_parity.py @@ -1,6 +1,8 @@ import os + import pytest + @pytest.mark.asyncio async def test_admin_seed_fields_memory_mode(monkeypatch): monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') @@ -8,9 +10,16 @@ async def test_admin_seed_fields_memory_mode(monkeypatch): monkeypatch.setenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') from utils import database as dbmod + dbmod.database.initialize_collections() - from utils.database import user_collection, role_collection, group_collection, _build_admin_seed_doc + from utils.database import ( + _build_admin_seed_doc, + group_collection, + role_collection, + user_collection, + ) + admin = user_collection.find_one({'username': 'admin'}) assert admin is not None, 'Admin user should be seeded' @@ -20,22 +29,37 @@ async def test_admin_seed_fields_memory_mode(monkeypatch): assert '_id' in doc_keys from utils import password_util - assert password_util.verify_password(os.environ['DOORMAN_ADMIN_PASSWORD'], admin.get('password')) + + assert password_util.verify_password( + os.environ['DOORMAN_ADMIN_PASSWORD'], admin.get('password') + ) assert set(admin.get('groups') or []) >= {'ALL', 'admin'} role = role_collection.find_one({'role_name': 'admin'}) assert role is not None for cap in ( - 'manage_users','manage_apis','manage_endpoints','manage_groups','manage_roles', - 'manage_routings','manage_gateway','manage_subscriptions','manage_credits','manage_auth','manage_security','view_logs' + 'manage_users', + 'manage_apis', + 'manage_endpoints', + 'manage_groups', + 'manage_roles', + 'manage_routings', + 'manage_gateway', + 'manage_subscriptions', + 'manage_credits', + 'manage_auth', + 'manage_security', + 'view_logs', ): assert role.get(cap) is True, f'Missing admin capability: {cap}' grp_admin = group_collection.find_one({'group_name': 'admin'}) grp_all = group_collection.find_one({'group_name': 'ALL'}) assert grp_admin is not None and grp_all is not None + def test_admin_seed_helper_is_canonical(): from utils.database import _build_admin_seed_doc + doc = _build_admin_seed_doc('a@b.c', 'hash') assert doc['username'] == 'admin' assert doc['role'] == 'admin' @@ -49,4 +73,3 @@ def test_admin_seed_helper_is_canonical(): assert doc['throttle_wait_duration_type'] == 'second' assert doc['throttle_queue_limit'] == 1 assert set(doc['groups']) == {'ALL', 'admin'} - diff --git a/backend-services/tests/test_admin_positive_paths.py b/backend-services/tests/test_admin_positive_paths.py index aa6ae1a..56f9016 100644 --- a/backend-services/tests/test_admin_positive_paths.py +++ b/backend-services/tests/test_admin_positive_paths.py @@ -1,9 +1,10 @@ import uuid + import pytest + @pytest.mark.asyncio async def test_admin_can_view_admin_role_and_user(authed_client): - r_role = await authed_client.get('/platform/role/admin') assert r_role.status_code == 200, r_role.text role = r_role.json() @@ -12,7 +13,7 @@ async def test_admin_can_view_admin_role_and_user(authed_client): r_roles = await authed_client.get('/platform/role/all?page=1&page_size=50') assert r_roles.status_code == 200 roles = r_roles.json().get('roles') or r_roles.json().get('response', {}).get('roles') or [] - names = { (r.get('role_name') or '').lower() for r in roles } + names = {(r.get('role_name') or '').lower() for r in roles} assert 'admin' in names # Super admin user (username='admin') should be hidden from ALL users (ghost user) @@ -23,13 +24,17 @@ async def test_admin_can_view_admin_role_and_user(authed_client): r_users = await authed_client.get('/platform/user/all?page=1&page_size=100') assert r_users.status_code == 200 users = r_users.json() - user_list = users if isinstance(users, list) else (users.get('users') or users.get('response', {}).get('users') or []) + user_list = ( + users + if isinstance(users, list) + else (users.get('users') or users.get('response', {}).get('users') or []) + ) usernames = {u.get('username') for u in user_list} assert 'admin' not in usernames, 'Super admin should not appear in user list' + @pytest.mark.asyncio async def test_admin_can_update_admin_role_description(authed_client): - desc = f'Administrator role ({uuid.uuid4().hex[:6]})' up = await authed_client.put('/platform/role/admin', json={'role_description': desc}) assert up.status_code in (200, 201), up.text @@ -39,6 +44,7 @@ async def test_admin_can_update_admin_role_description(authed_client): body = r.json().get('response') or r.json() assert body.get('role_description') == desc + @pytest.mark.asyncio async def test_admin_can_create_and_delete_admin_user(authed_client): uname = f'adm_{uuid.uuid4().hex[:8]}' @@ -64,4 +70,3 @@ async def test_admin_can_create_and_delete_admin_user(authed_client): r2 = await authed_client.get(f'/platform/user/{uname}') assert r2.status_code in (404, 500) - diff --git a/backend-services/tests/test_api_active_and_patch.py b/backend-services/tests/test_api_active_and_patch.py index 7b19c5f..e068141 100644 --- a/backend-services/tests/test_api_active_and_patch.py +++ b/backend-services/tests/test_api_active_and_patch.py @@ -1,20 +1,42 @@ import pytest + @pytest.mark.asyncio async def test_api_disabled_rest_blocks(monkeypatch, authed_client): async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _create_endpoint(c, n, v, m, u): - payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'} + payload = { + 'api_name': n, + 'api_version': v, + 'endpoint_method': m, + 'endpoint_uri': u, + 'endpoint_description': f'{m} {u}', + } rr = await c.post('/platform/endpoint', json=payload) assert rr.status_code in (200, 201) + async def _subscribe_self(c, n, v): r_me = await c.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') - rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v}) + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' + rr = await c.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': n, 'api_version': v}, + ) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'disabled', 'v1') r_upd = await authed_client.put('/platform/api/disabled/v1', json={'active': False}) @@ -24,99 +46,192 @@ async def test_api_disabled_rest_blocks(monkeypatch, authed_client): r = await authed_client.get('/api/rest/disabled/v1/status') assert r.status_code in (403, 500) + @pytest.mark.asyncio async def test_api_disabled_graphql_blocks(authed_client): async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _subscribe_self(c, n, v): r_me = await c.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') - rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v}) + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' + rr = await c.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': n, 'api_version': v}, + ) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'gqlx', 'v1') ru = await authed_client.put('/platform/api/gqlx/v1', json={'active': False}) assert ru.status_code == 200 await _subscribe_self(authed_client, 'gqlx', 'v1') - r = await authed_client.post('/api/graphql/gqlx', headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'}, json={'query': '{__typename}'}) + r = await authed_client.post( + '/api/graphql/gqlx', + headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'}, + json={'query': '{__typename}'}, + ) assert r.status_code == 403 + @pytest.mark.asyncio async def test_api_disabled_grpc_blocks(authed_client): async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _subscribe_self(c, n, v): r_me = await c.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') - rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v}) + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' + rr = await c.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': n, 'api_version': v}, + ) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'grpcx', 'v1') ru = await authed_client.put('/platform/api/grpcx/v1', json={'active': False}) assert ru.status_code == 200 await _subscribe_self(authed_client, 'grpcx', 'v1') - r = await authed_client.post('/api/grpc/grpcx', headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'}, json={'method': 'X', 'message': {}}) + r = await authed_client.post( + '/api/grpc/grpcx', + headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'}, + json={'method': 'X', 'message': {}}, + ) assert r.status_code in (400, 403, 404) + @pytest.mark.asyncio async def test_api_disabled_soap_blocks(authed_client): async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _create_endpoint(c, n, v, m, u): - payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'} + payload = { + 'api_name': n, + 'api_version': v, + 'endpoint_method': m, + 'endpoint_uri': u, + 'endpoint_description': f'{m} {u}', + } rr = await c.post('/platform/endpoint', json=payload) assert rr.status_code in (200, 201) + async def _subscribe_self(c, n, v): r_me = await c.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') - rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v}) + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' + rr = await c.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': n, 'api_version': v}, + ) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'soapx', 'v1') await _create_endpoint(authed_client, 'soapx', 'v1', 'POST', '/op') ru = await authed_client.put('/platform/api/soapx/v1', json={'active': False}) assert ru.status_code == 200 await _subscribe_self(authed_client, 'soapx', 'v1') - r = await authed_client.post('/api/soap/soapx/v1/op', headers={'Content-Type': 'text/xml'}, content='') + r = await authed_client.post( + '/api/soap/soapx/v1/op', headers={'Content-Type': 'text/xml'}, content='' + ) assert r.status_code in (403, 400, 404, 500) + @pytest.mark.asyncio async def test_gateway_patch_support(monkeypatch, authed_client): - async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _create_endpoint(c, n, v, m, u): - payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'} + payload = { + 'api_name': n, + 'api_version': v, + 'endpoint_method': m, + 'endpoint_uri': u, + 'endpoint_description': f'{m} {u}', + } rr = await c.post('/platform/endpoint', json=payload) assert rr.status_code in (200, 201) + async def _subscribe_self(c, n, v): r_me = await c.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') - rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v}) + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' + rr = await c.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': n, 'api_version': v}, + ) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'patchy', 'v1') await _create_endpoint(authed_client, 'patchy', 'v1', 'PATCH', '/item') await _subscribe_self(authed_client, 'patchy', 'v1') import services.gateway_service as gs + class _Resp: def __init__(self): self.status_code = 200 self.headers = {'Content-Type': 'application/json'} + def json(self): return {'ok': True} + @property def text(self): return '{"ok":true}' + class _Client: - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False - async def patch(self, url, json=None, params=None, headers=None, **kw): return _Resp() + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def patch(self, url, json=None, params=None, headers=None, **kw): + return _Resp() + monkeypatch.setattr(gs, 'httpx', type('X', (), {'AsyncClient': lambda timeout=None: _Client()})) r = await authed_client.patch('/api/rest/patchy/v1/item', json={'x': 1}) assert r.status_code in (200, 500) diff --git a/backend-services/tests/test_api_and_endpoint_crud.py b/backend-services/tests/test_api_and_endpoint_crud.py index f0d45f0..8fa6cf6 100644 --- a/backend-services/tests/test_api_and_endpoint_crud.py +++ b/backend-services/tests/test_api_and_endpoint_crud.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_api_crud_flow(authed_client): - payload = { 'api_name': 'customer', 'api_version': 'v1', @@ -28,8 +28,7 @@ async def test_api_crud_flow(authed_client): assert any(a.get('api_name') == 'customer' and a.get('api_version') == 'v1' for a in apis) upd = await authed_client.put( - '/platform/api/customer/v1', - json={'api_description': 'Customer API Updated'}, + '/platform/api/customer/v1', json={'api_description': 'Customer API Updated'} ) assert upd.status_code == 200 diff --git a/backend-services/tests/test_api_cors_control.py b/backend-services/tests/test_api_cors_control.py index ffc143f..2dc5f42 100644 --- a/backend-services/tests/test_api_cors_control.py +++ b/backend-services/tests/test_api_cors_control.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_rest_preflight_per_api_cors_allows_blocks(authed_client): name, ver = 'corsrest', 'v1' @@ -48,6 +49,7 @@ async def test_rest_preflight_per_api_cors_allows_blocks(authed_client): assert blocked.headers.get('access-control-allow-origin') in (None, '') + @pytest.mark.asyncio async def test_graphql_preflight_per_api_cors_allows(authed_client): name, ver = 'corsgql', 'v1' @@ -82,6 +84,7 @@ async def test_graphql_preflight_per_api_cors_allows(authed_client): assert 'Content-Type' in (pre.headers.get('access-control-allow-headers') or '') assert pre.headers.get('access-control-allow-credentials') == 'true' + @pytest.mark.asyncio async def test_soap_preflight_per_api_cors_allows(authed_client): name, ver = 'corssoap', 'v1' @@ -115,4 +118,3 @@ async def test_soap_preflight_per_api_cors_allows(authed_client): assert 'Content-Type' in (pre.headers.get('access-control-allow-headers') or '') assert pre.headers.get('access-control-allow-credentials') in (None, 'false') - diff --git a/backend-services/tests/test_api_cors_headers_matrix.py b/backend-services/tests/test_api_cors_headers_matrix.py index cc806b7..13b5d57 100644 --- a/backend-services/tests/test_api_cors_headers_matrix.py +++ b/backend-services/tests/test_api_cors_headers_matrix.py @@ -1,6 +1,9 @@ import pytest -async def _setup_api_and_endpoint(client, name, ver, api_overrides=None, method='GET', uri='/status'): + +async def _setup_api_and_endpoint( + client, name, ver, api_overrides=None, method='GET', uri='/status' +): payload = { 'api_name': name, 'api_version': ver, @@ -15,126 +18,186 @@ async def _setup_api_and_endpoint(client, name, ver, api_overrides=None, method= payload.update(api_overrides) r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201), r.text - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': method, - 'endpoint_uri': uri, - 'endpoint_description': f'{method} {uri}', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': method, + 'endpoint_uri': uri, + 'endpoint_description': f'{method} {uri}', + }, + ) assert r2.status_code in (200, 201), r2.text + @pytest.mark.asyncio async def test_api_cors_allow_origins_exact_match_allowed(authed_client): name, ver = 'corsm1', 'v1' - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['Content-Type'], - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'Content-Type'} + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'Content-Type', + }, ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example' assert r.headers.get('Vary') == 'Origin' + @pytest.mark.asyncio async def test_api_cors_allow_origins_wildcard_allowed(authed_client): name, ver = 'corsm2', 'v1' - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['*'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['Content-Type'], - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['*'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://any.example', 'Access-Control-Request-Method': 'GET'} + headers={'Origin': 'http://any.example', 'Access-Control-Request-Method': 'GET'}, ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Origin') == 'http://any.example' + @pytest.mark.asyncio async def test_api_cors_allow_methods_contains_options_appended(authed_client): name, ver = 'corsm3', 'v1' - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['Content-Type'], - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'} + headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'}, ) assert r.status_code == 204 - methods = [m.strip().upper() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()] + methods = [ + m.strip().upper() + for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') + if m.strip() + ] assert 'OPTIONS' in methods + @pytest.mark.asyncio async def test_api_cors_allow_headers_asterisk_allows_any(authed_client): name, ver = 'corsm4', 'v1' - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['*'], - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['*'], + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-Random-Header'} + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-Random-Header', + }, ) assert r.status_code == 204 ach = r.headers.get('Access-Control-Allow-Headers') or '' assert '*' in ach + @pytest.mark.asyncio async def test_api_cors_allow_headers_specific_disallows_others(authed_client): name, ver = 'corsm5', 'v1' - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['Content-Type'], - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-Other'} + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-Other', + }, ) assert r.status_code == 204 ach = r.headers.get('Access-Control-Allow-Headers') or '' assert 'Content-Type' in ach and 'X-Other' not in ach + @pytest.mark.asyncio async def test_api_cors_allow_credentials_true_sets_header(authed_client): name, ver = 'corsm6', 'v1' - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['Content-Type'], - 'api_cors_allow_credentials': True, - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_cors_allow_credentials': True, + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'} + headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'}, ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Credentials') == 'true' + @pytest.mark.asyncio async def test_api_cors_expose_headers_propagated(authed_client): name, ver = 'corsm7', 'v1' expose = ['X-Resp-Id', 'X-Trace-Id'] - await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={ - 'api_cors_allow_origins': ['http://ok.example'], - 'api_cors_allow_methods': ['GET'], - 'api_cors_allow_headers': ['Content-Type'], - 'api_cors_expose_headers': expose, - }) + await _setup_api_and_endpoint( + authed_client, + name, + ver, + api_overrides={ + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_cors_expose_headers': expose, + }, + ) r = await authed_client.options( f'/api/rest/{name}/{ver}/status', - headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'} + headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'}, ) assert r.status_code == 204 aceh = r.headers.get('Access-Control-Expose-Headers') or '' for h in expose: assert h in aceh - diff --git a/backend-services/tests/test_api_crud_failures.py b/backend-services/tests/test_api_crud_failures.py index ea82d16..b57a044 100644 --- a/backend-services/tests/test_api_crud_failures.py +++ b/backend-services/tests/test_api_crud_failures.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_update_delete_nonexistent_api(authed_client): u = await authed_client.put('/platform/api/doesnot/v9', json={'api_description': 'x'}) @@ -7,4 +8,3 @@ async def test_update_delete_nonexistent_api(authed_client): d = await authed_client.delete('/platform/api/doesnot/v9') assert d.status_code in (400, 404) - diff --git a/backend-services/tests/test_api_get_by_name_version.py b/backend-services/tests/test_api_get_by_name_version.py new file mode 100644 index 0000000..55f2512 --- /dev/null +++ b/backend-services/tests/test_api_get_by_name_version.py @@ -0,0 +1,27 @@ +import pytest + + +@pytest.mark.asyncio +async def test_api_get_by_name_version_returns_200(authed_client): + name, ver = 'apiget', 'v1' + r = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'demo', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream.invalid'], + 'api_type': 'REST', + 'active': True, + }, + ) + assert r.status_code in (200, 201), r.text + + g = await authed_client.get(f'/platform/api/{name}/{ver}') + assert g.status_code == 200, g.text + body = g.json().get('response', g.json()) + assert body.get('api_name') == name + assert body.get('api_version') == ver + assert '_id' not in body diff --git a/backend-services/tests/test_async_collection.py b/backend-services/tests/test_async_collection.py index 1267573..aa73405 100644 --- a/backend-services/tests/test_async_collection.py +++ b/backend-services/tests/test_async_collection.py @@ -1,14 +1,17 @@ #!/usr/bin/env python3 """Test script to verify async collection behavior""" + import asyncio import os + os.environ['MEM_OR_EXTERNAL'] = 'MEM' from utils.database_async import async_database + async def test_async_collection(): - print("Testing async collection...") - + print('Testing async collection...') + # Insert a test tier test_tier = { 'tier_id': 'test', @@ -17,31 +20,32 @@ async def test_async_collection(): 'limits': {}, 'features': [], 'is_default': False, - 'enabled': True + 'enabled': True, } - + await async_database.db.tiers.insert_one(test_tier) - print("Inserted test tier") - + print('Inserted test tier') + # Try to list tiers cursor = async_database.db.tiers.find({}) - print(f"Cursor type: {type(cursor)}") - print(f"Cursor has skip: {hasattr(cursor, 'skip')}") - print(f"Cursor has limit: {hasattr(cursor, 'limit')}") - + print(f'Cursor type: {type(cursor)}') + print(f'Cursor has skip: {hasattr(cursor, "skip")}') + print(f'Cursor has limit: {hasattr(cursor, "limit")}') + # Test skip/limit cursor = cursor.skip(0).limit(10) - print("Applied skip/limit") - + print('Applied skip/limit') + # Iterate tiers = [] async for tier_data in cursor: - print(f"Found tier: {tier_data.get('tier_id')}") + print(f'Found tier: {tier_data.get("tier_id")}') tiers.append(tier_data) - - print(f"Total tiers found: {len(tiers)}") + + print(f'Total tiers found: {len(tiers)}') return tiers + if __name__ == '__main__': result = asyncio.run(test_async_collection()) - print(f"Result: {result}") + print(f'Result: {result}') diff --git a/backend-services/tests/test_async_db_wrappers.py b/backend-services/tests/test_async_db_wrappers.py new file mode 100644 index 0000000..c85da11 --- /dev/null +++ b/backend-services/tests/test_async_db_wrappers.py @@ -0,0 +1,94 @@ +import pytest + +from utils.async_db import db_delete_one, db_find_one, db_insert_one, db_update_one +from utils.database_async import async_database + + +@pytest.mark.asyncio +async def test_async_wrappers_with_inmemory_async_collections(): + coll = async_database.db.tiers # AsyncInMemoryCollection + + # Insert + await db_insert_one(coll, {'tier_id': 't1', 'name': 'Tier 1'}) + doc = await db_find_one(coll, {'tier_id': 't1'}) + assert doc and doc.get('name') == 'Tier 1' + + # Update + await db_update_one(coll, {'tier_id': 't1'}, {'$set': {'name': 'Tier One'}}) + doc2 = await db_find_one(coll, {'tier_id': 't1'}) + assert doc2 and doc2.get('name') == 'Tier One' + + # Delete + await db_delete_one(coll, {'tier_id': 't1'}) + assert await db_find_one(coll, {'tier_id': 't1'}) is None + + +class _SyncColl: + def __init__(self): + self._docs = [] + + def find_one(self, q): + for d in self._docs: + match = all(d.get(k) == v for k, v in q.items()) + if match: + return dict(d) + return None + + def insert_one(self, doc): + self._docs.append(dict(doc)) + + class R: + acknowledged = True + inserted_id = 'x' + + return R() + + def update_one(self, q, upd): + for i, d in enumerate(self._docs): + if all(d.get(k) == v for k, v in q.items()): + setv = upd.get('$set', {}) + nd = dict(d) + nd.update(setv) + self._docs[i] = nd + + class R: + acknowledged = True + modified_count = 1 + + return R() + + class R2: + acknowledged = True + modified_count = 0 + + return R2() + + def delete_one(self, q): + for i, d in enumerate(self._docs): + if all(d.get(k) == v for k, v in q.items()): + del self._docs[i] + + class R: + acknowledged = True + deleted_count = 1 + + return R() + + class R2: + acknowledged = True + deleted_count = 0 + + return R2() + + +@pytest.mark.asyncio +async def test_async_wrappers_fallback_to_thread_for_sync_collections(): + coll = _SyncColl() + await db_insert_one(coll, {'k': 1, 'v': 'a'}) + d = await db_find_one(coll, {'k': 1}) + assert d and d['v'] == 'a' + await db_update_one(coll, {'k': 1}, {'$set': {'v': 'b'}}) + d2 = await db_find_one(coll, {'k': 1}) + assert d2 and d2['v'] == 'b' + await db_delete_one(coll, {'k': 1}) + assert await db_find_one(coll, {'k': 1}) is None diff --git a/backend-services/tests/test_async_endpoints.py b/backend-services/tests/test_async_endpoints.py index 2c6abef..551d649 100644 --- a/backend-services/tests/test_async_endpoints.py +++ b/backend-services/tests/test_async_endpoints.py @@ -5,55 +5,47 @@ The contents of this file are property of Doorman Dev, LLC Review the Apache License 2.0 for valid authorization of use """ -from fastapi import APIRouter, HTTPException -from typing import Dict, Any import asyncio import time +from typing import Any -from utils.database_async import ( - user_collection as async_user_collection, - api_collection as async_api_collection, - async_database -) +from fastapi import APIRouter, HTTPException + +from utils.database import api_collection as sync_api_collection +from utils.database import user_collection as sync_user_collection +from utils.database_async import api_collection as async_api_collection +from utils.database_async import async_database +from utils.database_async import user_collection as async_user_collection from utils.doorman_cache_async import async_doorman_cache - -from utils.database import ( - user_collection as sync_user_collection, - api_collection as sync_api_collection -) from utils.doorman_cache_util import doorman_cache -router = APIRouter(prefix="/test/async", tags=["Async Testing"]) +router = APIRouter(prefix='/test/async', tags=['Async Testing']) -@router.get("/health") -async def async_health_check() -> Dict[str, Any]: + +@router.get('/health') +async def async_health_check() -> dict[str, Any]: """Test async database and cache health.""" try: if async_database.is_memory_only(): - db_status = "memory_only" + db_status = 'memory_only' else: await async_user_collection.find_one({'username': 'admin'}) - db_status = "connected" + db_status = 'connected' cache_operational = await async_doorman_cache.is_operational() cache_info = await async_doorman_cache.get_cache_info() return { - "status": "healthy", - "database": { - "status": db_status, - "mode": async_database.get_mode_info() - }, - "cache": { - "operational": cache_operational, - "info": cache_info - } + 'status': 'healthy', + 'database': {'status': db_status, 'mode': async_database.get_mode_info()}, + 'cache': {'operational': cache_operational, 'info': cache_info}, } except Exception as e: - raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}") + raise HTTPException(status_code=500, detail=f'Health check failed: {str(e)}') -@router.get("/performance/sync") -async def test_sync_performance() -> Dict[str, Any]: + +@router.get('/performance/sync') +async def test_sync_performance() -> dict[str, Any]: """Test SYNC (blocking) database operations - SLOW under load.""" start_time = time.time() @@ -68,17 +60,18 @@ async def test_sync_performance() -> Dict[str, Any]: elapsed = time.time() - start_time return { - "method": "sync (blocking)", - "elapsed_ms": round(elapsed * 1000, 2), - "user_found": user is not None, - "apis_count": len(apis), - "warning": "This endpoint blocks the event loop and causes poor performance under load" + 'method': 'sync (blocking)', + 'elapsed_ms': round(elapsed * 1000, 2), + 'user_found': user is not None, + 'apis_count': len(apis), + 'warning': 'This endpoint blocks the event loop and causes poor performance under load', } except Exception as e: - raise HTTPException(status_code=500, detail=f"Sync test failed: {str(e)}") + raise HTTPException(status_code=500, detail=f'Sync test failed: {str(e)}') -@router.get("/performance/async") -async def test_async_performance() -> Dict[str, Any]: + +@router.get('/performance/async') +async def test_async_performance() -> dict[str, Any]: """Test ASYNC (non-blocking) database operations - FAST under load.""" start_time = time.time() @@ -98,17 +91,18 @@ async def test_async_performance() -> Dict[str, Any]: elapsed = time.time() - start_time return { - "method": "async (non-blocking)", - "elapsed_ms": round(elapsed * 1000, 2), - "user_found": user is not None, - "apis_count": len(apis), - "note": "This endpoint does NOT block the event loop and performs well under load" + 'method': 'async (non-blocking)', + 'elapsed_ms': round(elapsed * 1000, 2), + 'user_found': user is not None, + 'apis_count': len(apis), + 'note': 'This endpoint does NOT block the event loop and performs well under load', } except Exception as e: - raise HTTPException(status_code=500, detail=f"Async test failed: {str(e)}") + raise HTTPException(status_code=500, detail=f'Async test failed: {str(e)}') -@router.get("/performance/parallel") -async def test_parallel_performance() -> Dict[str, Any]: + +@router.get('/performance/parallel') +async def test_parallel_performance() -> dict[str, Any]: """Test PARALLEL async operations - Maximum performance.""" start_time = time.time() @@ -116,19 +110,13 @@ async def test_parallel_performance() -> Dict[str, Any]: user_task = async_user_collection.find_one({'username': 'admin'}) if async_database.is_memory_only(): - apis_task = asyncio.to_thread( - lambda: list(async_api_collection.find({}).limit(10)) - ) + apis_task = asyncio.to_thread(lambda: list(async_api_collection.find({}).limit(10))) else: apis_task = async_api_collection.find({}).limit(10).to_list(length=10) cache_task = async_doorman_cache.get_cache('user_cache', 'admin') - user, apis, cached_user = await asyncio.gather( - user_task, - apis_task, - cache_task - ) + user, apis, cached_user = await asyncio.gather(user_task, apis_task, cache_task) if not cached_user and user: await async_doorman_cache.set_cache('user_cache', 'admin', user) @@ -136,25 +124,22 @@ async def test_parallel_performance() -> Dict[str, Any]: elapsed = time.time() - start_time return { - "method": "async parallel (non-blocking + concurrent)", - "elapsed_ms": round(elapsed * 1000, 2), - "user_found": user is not None, - "apis_count": len(apis) if apis else 0, - "note": "Operations executed in parallel for maximum performance" + 'method': 'async parallel (non-blocking + concurrent)', + 'elapsed_ms': round(elapsed * 1000, 2), + 'user_found': user is not None, + 'apis_count': len(apis) if apis else 0, + 'note': 'Operations executed in parallel for maximum performance', } except Exception as e: - raise HTTPException(status_code=500, detail=f"Parallel test failed: {str(e)}") + raise HTTPException(status_code=500, detail=f'Parallel test failed: {str(e)}') -@router.get("/cache/test") -async def test_cache_operations() -> Dict[str, Any]: + +@router.get('/cache/test') +async def test_cache_operations() -> dict[str, Any]: """Test async cache operations.""" try: - test_key = "test_user_123" - test_value = { - "username": "test_user_123", - "email": "test@example.com", - "role": "user" - } + test_key = 'test_user_123' + test_value = {'username': 'test_user_123', 'email': 'test@example.com', 'role': 'user'} await async_doorman_cache.set_cache('user_cache', test_key, test_value) @@ -165,16 +150,17 @@ async def test_cache_operations() -> Dict[str, Any]: after_delete = await async_doorman_cache.get_cache('user_cache', test_key) return { - "set": "success", - "get": "success" if retrieved == test_value else "failed", - "delete": "success" if after_delete is None else "failed", - "cache_info": await async_doorman_cache.get_cache_info() + 'set': 'success', + 'get': 'success' if retrieved == test_value else 'failed', + 'delete': 'success' if after_delete is None else 'failed', + 'cache_info': await async_doorman_cache.get_cache_info(), } except Exception as e: - raise HTTPException(status_code=500, detail=f"Cache test failed: {str(e)}") + raise HTTPException(status_code=500, detail=f'Cache test failed: {str(e)}') -@router.get("/load-test-compare") -async def load_test_comparison() -> Dict[str, Any]: + +@router.get('/load-test-compare') +async def load_test_comparison() -> dict[str, Any]: """ Compare sync vs async performance under simulated load. @@ -183,33 +169,30 @@ async def load_test_comparison() -> Dict[str, Any]: try: sync_start = time.time() sync_results = [] - for i in range(10): + for _i in range(10): user = sync_user_collection.find_one({'username': 'admin'}) sync_results.append(user is not None) sync_elapsed = time.time() - sync_start async_start = time.time() - async_tasks = [ - async_user_collection.find_one({'username': 'admin'}) - for i in range(10) - ] - async_results = await asyncio.gather(*async_tasks) + async_tasks = [async_user_collection.find_one({'username': 'admin'}) for i in range(10)] + await asyncio.gather(*async_tasks) async_elapsed = time.time() - async_start speedup = sync_elapsed / async_elapsed if async_elapsed > 0 else 0 return { - "test": "10 concurrent user lookups", - "sync": { - "elapsed_ms": round(sync_elapsed * 1000, 2), - "queries_per_second": round(10 / sync_elapsed, 2) + 'test': '10 concurrent user lookups', + 'sync': { + 'elapsed_ms': round(sync_elapsed * 1000, 2), + 'queries_per_second': round(10 / sync_elapsed, 2), }, - "async": { - "elapsed_ms": round(async_elapsed * 1000, 2), - "queries_per_second": round(10 / async_elapsed, 2) + 'async': { + 'elapsed_ms': round(async_elapsed * 1000, 2), + 'queries_per_second': round(10 / async_elapsed, 2), }, - "speedup": f"{round(speedup, 2)}x faster", - "note": "Async shows significant improvement with concurrent operations" + 'speedup': f'{round(speedup, 2)}x faster', + 'note': 'Async shows significant improvement with concurrent operations', } except Exception as e: - raise HTTPException(status_code=500, detail=f"Load test failed: {str(e)}") + raise HTTPException(status_code=500, detail=f'Load test failed: {str(e)}') diff --git a/backend-services/tests/test_audit_and_export_negatives.py b/backend-services/tests/test_audit_and_export_negatives.py index b8f29c1..69ba26c 100644 --- a/backend-services/tests/test_audit_and_export_negatives.py +++ b/backend-services/tests/test_audit_and_export_negatives.py @@ -1,25 +1,35 @@ import pytest + class _AuditSpy: def __init__(self): self.calls = [] + def info(self, msg): self.calls.append(msg) + @pytest.mark.asyncio async def test_audit_api_create_update_delete(monkeypatch, authed_client): - import utils.audit_util as au + orig = au._logger spy = _AuditSpy() au._logger = spy try: - - r = await authed_client.post('/platform/api', json={ - 'api_name': 'aud1', 'api_version': 'v1', 'api_description': 'd', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0, - }) + r = await authed_client.post( + '/platform/api', + json={ + 'api_name': 'aud1', + 'api_version': 'v1', + 'api_description': 'd', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) assert r.status_code in (200, 201) r = await authed_client.put('/platform/api/aud1/v1', json={'api_description': 'd2'}) @@ -34,27 +44,44 @@ async def test_audit_api_create_update_delete(monkeypatch, authed_client): finally: au._logger = orig + @pytest.mark.asyncio async def test_audit_user_credits_and_subscriptions(monkeypatch, authed_client): import utils.audit_util as au + orig = au._logger spy = _AuditSpy() au._logger = spy try: - - r = await authed_client.post('/platform/credit/admin', json={'username': 'admin', 'users_credits': {}}) + r = await authed_client.post( + '/platform/credit/admin', json={'username': 'admin', 'users_credits': {}} + ) assert r.status_code in (200, 201) - await authed_client.post('/platform/api', json={ - 'api_name': 'aud2', 'api_version': 'v1', 'api_description': 'd', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0, - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': 'aud2', + 'api_version': 'v1', + 'api_description': 'd', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) r_me = await authed_client.get('/platform/user/me') - username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin') - r = await authed_client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'}) + username = r_me.json().get('username') if r_me.status_code == 200 else 'admin' + r = await authed_client.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'}, + ) assert r.status_code in (200, 201) - r = await authed_client.post('/platform/subscription/unsubscribe', json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'}) + r = await authed_client.post( + '/platform/subscription/unsubscribe', + json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'}, + ) assert r.status_code in (200, 201, 400) assert any('user_credits.save' in c for c in spy.calls) @@ -63,14 +90,20 @@ async def test_audit_user_credits_and_subscriptions(monkeypatch, authed_client): finally: au._logger = orig + @pytest.mark.asyncio async def test_export_not_found_cases(authed_client): - r = await authed_client.get('/platform/config/export/apis', params={'api_name': 'nope', 'api_version': 'v9'}) + r = await authed_client.get( + '/platform/config/export/apis', params={'api_name': 'nope', 'api_version': 'v9'} + ) assert r.status_code == 404 r = await authed_client.get('/platform/config/export/roles', params={'role_name': 'nope-role'}) assert r.status_code == 404 - r = await authed_client.get('/platform/config/export/groups', params={'group_name': 'nope-group'}) + r = await authed_client.get( + '/platform/config/export/groups', params={'group_name': 'nope-group'} + ) assert r.status_code == 404 - r = await authed_client.get('/platform/config/export/routings', params={'client_key': 'nope-key'}) + r = await authed_client.get( + '/platform/config/export/routings', params={'client_key': 'nope-key'} + ) assert r.status_code == 404 - diff --git a/backend-services/tests/test_auth.py b/backend-services/tests/test_auth.py index 35a1dcd..04df3ae 100644 --- a/backend-services/tests/test_auth.py +++ b/backend-services/tests/test_auth.py @@ -1,6 +1,8 @@ import os + import pytest + @pytest.mark.asyncio async def test_authorization_login_and_status(client): resp = await client.post( @@ -19,9 +21,9 @@ async def test_authorization_login_and_status(client): body = status.json() assert body.get('message') == 'Token is valid' + @pytest.mark.asyncio async def test_auth_refresh_and_invalidate(authed_client): - r = await authed_client.post('/platform/authorization/refresh') assert r.status_code == 200 @@ -31,10 +33,10 @@ async def test_auth_refresh_and_invalidate(authed_client): status = await authed_client.get('/platform/authorization/status') assert status.status_code in (401, 500) + @pytest.mark.asyncio async def test_authorization_invalid_login(client): resp = await client.post( - '/platform/authorization', - json={'email': 'unknown@example.com', 'password': 'bad'}, + '/platform/authorization', json={'email': 'unknown@example.com', 'password': 'bad'} ) assert resp.status_code in (400, 401) diff --git a/backend-services/tests/test_auth_admin.py b/backend-services/tests/test_auth_admin.py index 16c4f81..ef443ac 100644 --- a/backend-services/tests/test_auth_admin.py +++ b/backend-services/tests/test_auth_admin.py @@ -1,20 +1,18 @@ import os + import pytest + @pytest.mark.asyncio async def test_auth_admin_endpoints(authed_client): - - r = await authed_client.put( - '/platform/role/admin', - json={'manage_auth': True}, - ) + r = await authed_client.put('/platform/role/admin', json={'manage_auth': True}) assert r.status_code == 200, r.text relog = await authed_client.post( '/platform/authorization', json={ 'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), - 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD') + 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD'), }, ) assert relog.status_code == 200, relog.text @@ -27,7 +25,7 @@ async def test_auth_admin_endpoints(authed_client): 'password': 'VerySecurePassword!123', 'role': 'admin', 'groups': ['ALL'], - 'active': True + 'active': True, }, ) assert create.status_code in (200, 201), create.text diff --git a/backend-services/tests/test_auth_csrf_https.py b/backend-services/tests/test_auth_csrf_https.py index f7c42e5..7663c7e 100644 --- a/backend-services/tests/test_auth_csrf_https.py +++ b/backend-services/tests/test_auth_csrf_https.py @@ -1,17 +1,20 @@ import pytest + @pytest.mark.asyncio async def test_auth_rejects_missing_csrf_when_https_only(monkeypatch, authed_client): monkeypatch.setenv('HTTPS_ONLY', 'true') r = await authed_client.get('/platform/user/me') assert r.status_code == 401 + @pytest.mark.asyncio async def test_auth_rejects_mismatched_csrf_when_https_only(monkeypatch, authed_client): monkeypatch.setenv('HTTPS_ONLY', 'true') r = await authed_client.get('/platform/user/me', headers={'X-CSRF-Token': 'not-the-cookie'}) assert r.status_code == 401 + @pytest.mark.asyncio async def test_auth_accepts_matching_csrf(monkeypatch, authed_client): monkeypatch.setenv('HTTPS_ONLY', 'true') @@ -24,9 +27,9 @@ async def test_auth_accepts_matching_csrf(monkeypatch, authed_client): r = await authed_client.get('/platform/user/me', headers={'X-CSRF-Token': csrf}) assert r.status_code == 200 + @pytest.mark.asyncio async def test_auth_http_mode_skips_csrf_validation(monkeypatch, authed_client): monkeypatch.setenv('HTTPS_ONLY', 'false') r = await authed_client.get('/platform/user/me') assert r.status_code == 200 - diff --git a/backend-services/tests/test_auth_guard.py b/backend-services/tests/test_auth_guard.py index 1651331..0ffa82c 100644 --- a/backend-services/tests/test_auth_guard.py +++ b/backend-services/tests/test_auth_guard.py @@ -1,7 +1,7 @@ import pytest + @pytest.mark.asyncio async def test_unauthorized_access_rejected(client): - me = await client.get('/platform/user/me') assert me.status_code in (401, 500) diff --git a/backend-services/tests/test_auth_refresh_unauthenticated.py b/backend-services/tests/test_auth_refresh_unauthenticated.py index 1024476..b5c5801 100644 --- a/backend-services/tests/test_auth_refresh_unauthenticated.py +++ b/backend-services/tests/test_auth_refresh_unauthenticated.py @@ -1,7 +1,7 @@ import pytest + @pytest.mark.asyncio async def test_refresh_requires_auth(client): r = await client.post('/platform/authorization/refresh') assert r.status_code == 401 - diff --git a/backend-services/tests/test_bandwidth_and_monitor.py b/backend-services/tests/test_bandwidth_and_monitor.py index 0fd1475..d2654d7 100644 --- a/backend-services/tests/test_bandwidth_and_monitor.py +++ b/backend-services/tests/test_bandwidth_and_monitor.py @@ -1,16 +1,22 @@ import json + import pytest + @pytest.mark.asyncio async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_client): try: - upd = await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 80, 'bandwidth_limit_window': 'second'}) + upd = await authed_client.put( + '/platform/user/admin', + json={'bandwidth_limit_bytes': 80, 'bandwidth_limit_window': 'second'}, + ) assert upd.status_code in (200, 204) except AssertionError: await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': None}) raise from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'bwapi', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/p') @@ -24,16 +30,20 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))} self.text = body.decode('utf-8') self.content = body + def json(self): return json.loads(self.text) class _FakeAsyncClient: def __init__(self, timeout=None, limits=None, http2=False): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -51,12 +61,16 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405) + async def get(self, url, **kwargs): return _FakeHTTPResponse(200) + async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): return _FakeHTTPResponse(200) + async def put(self, url, **kwargs): return _FakeHTTPResponse(200) + async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) @@ -82,9 +96,11 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 0}) + @pytest.mark.asyncio async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'bwmon', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/echo') @@ -100,13 +116,20 @@ async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client): self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))} self.text = body.decode('utf-8') self.content = body + def json(self): return json.loads(self.text) class _FakeAsyncClient: - def __init__(self, timeout=None, limits=None, http2=False): pass - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + def __init__(self, timeout=None, limits=None, http2=False): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -124,10 +147,18 @@ async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405) - async def get(self, url, **kwargs): return _FakeHTTPResponse(200) - async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): return _FakeHTTPResponse(200) - async def put(self, url, **kwargs): return _FakeHTTPResponse(200) - async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) + + async def get(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): + return _FakeHTTPResponse(200) + + async def put(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse(200) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) diff --git a/backend-services/tests/test_bandwidth_limit_windows.py b/backend-services/tests/test_bandwidth_limit_windows.py index 43a8dbf..7d07407 100644 --- a/backend-services/tests/test_bandwidth_limit_windows.py +++ b/backend-services/tests/test_bandwidth_limit_windows.py @@ -1,26 +1,36 @@ -import pytest import time +import pytest from tests.test_gateway_routing_limits import _FakeAsyncClient + async def _setup_basic_rest(client, name='bw', ver='v1', method='GET', uri='/p'): from conftest import create_api, create_endpoint, subscribe_self + await create_api(client, name, ver) await create_endpoint(client, name, ver, method, uri) await subscribe_self(client, name, ver) return name, ver, uri + @pytest.mark.asyncio async def test_bandwidth_limit_blocks_when_exceeded(monkeypatch, authed_client): name, ver, uri = await _setup_basic_rest(authed_client, name='bw1', method='GET', uri='/g') from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': { - 'bandwidth_limit_bytes': 1, - 'bandwidth_limit_window': 'second', - 'bandwidth_limit_enabled': True, - }}) + + user_collection.update_one( + {'username': 'admin'}, + { + '$set': { + 'bandwidth_limit_bytes': 1, + 'bandwidth_limit_window': 'second', + 'bandwidth_limit_enabled': True, + } + }, + ) await authed_client.delete('/api/caches') from utils.doorman_cache_util import doorman_cache + try: for k in list(doorman_cache.cache.keys('bandwidth_usage:admin*')): doorman_cache.cache.delete(k) @@ -28,6 +38,7 @@ async def test_bandwidth_limit_blocks_when_exceeded(monkeypatch, authed_client): pass import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) ok = await authed_client.get(f'/api/rest/{name}/{ver}{uri}') @@ -35,17 +46,25 @@ async def test_bandwidth_limit_blocks_when_exceeded(monkeypatch, authed_client): blocked = await authed_client.get(f'/api/rest/{name}/{ver}{uri}') assert blocked.status_code == 429 + @pytest.mark.asyncio async def test_bandwidth_limit_resets_after_window(monkeypatch, authed_client): name, ver, uri = await _setup_basic_rest(authed_client, name='bw2', method='GET', uri='/g') from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': { - 'bandwidth_limit_bytes': 1, - 'bandwidth_limit_window': 'second', - 'bandwidth_limit_enabled': True, - }}) + + user_collection.update_one( + {'username': 'admin'}, + { + '$set': { + 'bandwidth_limit_bytes': 1, + 'bandwidth_limit_window': 'second', + 'bandwidth_limit_enabled': True, + } + }, + ) await authed_client.delete('/api/caches') from utils.doorman_cache_util import doorman_cache + try: for k in list(doorman_cache.cache.keys('bandwidth_usage:admin*')): doorman_cache.cache.delete(k) @@ -53,6 +72,7 @@ async def test_bandwidth_limit_resets_after_window(monkeypatch, authed_client): pass import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r1 = await authed_client.get(f'/api/rest/{name}/{ver}{uri}') @@ -63,17 +83,25 @@ async def test_bandwidth_limit_resets_after_window(monkeypatch, authed_client): r3 = await authed_client.get(f'/api/rest/{name}/{ver}{uri}') assert r3.status_code == 200 + @pytest.mark.asyncio async def test_bandwidth_limit_counts_request_and_response_bytes(monkeypatch, authed_client): name, ver, uri = await _setup_basic_rest(authed_client, name='bw3', method='POST', uri='/p') from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': { - 'bandwidth_limit_bytes': 1_000_000, - 'bandwidth_limit_window': 'second', - 'bandwidth_limit_enabled': True, - }}) + + user_collection.update_one( + {'username': 'admin'}, + { + '$set': { + 'bandwidth_limit_bytes': 1_000_000, + 'bandwidth_limit_window': 'second', + 'bandwidth_limit_enabled': True, + } + }, + ) await authed_client.delete('/api/caches') from utils.doorman_cache_util import doorman_cache + try: for k in list(doorman_cache.cache.keys('bandwidth_usage:admin*')): doorman_cache.cache.delete(k) @@ -81,9 +109,11 @@ async def test_bandwidth_limit_counts_request_and_response_bytes(monkeypatch, au pass import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) from utils.bandwidth_util import get_current_usage + before = get_current_usage('admin', 'second') payload = 'x' * 1234 r = await authed_client.post(f'/api/rest/{name}/{ver}{uri}', json={'data': payload}) diff --git a/backend-services/tests/test_cache_bytes_serialization.py b/backend-services/tests/test_cache_bytes_serialization.py new file mode 100644 index 0000000..7da9e78 --- /dev/null +++ b/backend-services/tests/test_cache_bytes_serialization.py @@ -0,0 +1,18 @@ +import bcrypt + +from utils.doorman_cache_util import doorman_cache + + +def test_cache_serializes_bytes_password_to_json_string(): + doc = { + 'username': 'u1', + 'password': bcrypt.hashpw(b'super-secret', bcrypt.gensalt()), # bytes + 'role': 'user', + } + + doorman_cache.set_cache('user_cache', 'u1', doc) + out = doorman_cache.get_cache('user_cache', 'u1') + + assert isinstance(out, dict) + assert isinstance(out.get('password'), str) + assert out['password'].startswith('$2b$') or out['password'].startswith('$2a$') diff --git a/backend-services/tests/test_chunked_encoding_body_limit.py b/backend-services/tests/test_chunked_encoding_body_limit.py index 4defa3f..d5284ad 100644 --- a/backend-services/tests/test_chunked_encoding_body_limit.py +++ b/backend-services/tests/test_chunked_encoding_body_limit.py @@ -7,20 +7,23 @@ vulnerability where attackers could stream unlimited data without a Content-Length header. """ -import pytest -from fastapi.testclient import TestClient import os import sys +import pytest +from fastapi.testclient import TestClient + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from doorman import doorman + @pytest.fixture def client(): """Test client fixture.""" return TestClient(doorman) + class TestChunkedEncodingBodyLimit: """Test suite for chunked encoding body size limit enforcement.""" @@ -31,10 +34,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/platform/authorization', data=small_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code != 413 @@ -49,10 +49,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/platform/authorization', data=large_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -71,10 +68,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/api/rest/test/v1/endpoint', data=large_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -93,10 +87,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/api/soap/test/v1/service', data=medium_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'text/xml' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'text/xml'}, ) assert response.status_code != 413 @@ -115,9 +106,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/platform/authorization', data=large_payload, - headers={ - 'Content-Type': 'application/json' - } + headers={'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -139,8 +128,8 @@ class TestChunkedEncodingBodyLimit: headers={ 'Transfer-Encoding': 'chunked', 'Content-Length': '100', - 'Content-Type': 'application/json' - } + 'Content-Type': 'application/json', + }, ) assert response.status_code == 413 @@ -151,10 +140,7 @@ class TestChunkedEncodingBodyLimit: def test_get_request_with_chunked_ignored(self, client): """Test that GET requests with Transfer-Encoding: chunked are not limited.""" response = client.get( - '/platform/authorization/status', - headers={ - 'Transfer-Encoding': 'chunked' - } + '/platform/authorization/status', headers={'Transfer-Encoding': 'chunked'} ) assert response.status_code != 413 @@ -169,10 +155,7 @@ class TestChunkedEncodingBodyLimit: response = client.put( '/platform/user/testuser', data=large_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -190,10 +173,7 @@ class TestChunkedEncodingBodyLimit: response = client.patch( '/platform/user/testuser', data=large_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -211,10 +191,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/api/graphql/test', data=large_query.encode(), - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -241,10 +218,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( route, data=large_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413, f'Route {route} not protected' @@ -262,10 +236,7 @@ class TestChunkedEncodingBodyLimit: response = client.post( '/platform/authorization', data=large_payload, - headers={ - 'Transfer-Encoding': 'chunked', - 'Content-Type': 'application/json' - } + headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'}, ) assert response.status_code == 413 @@ -273,5 +244,6 @@ class TestChunkedEncodingBodyLimit: finally: os.environ['MAX_BODY_SIZE_BYTES'] = '1048576' + if __name__ == '__main__': pytest.main([__file__, '-v']) diff --git a/backend-services/tests/test_compression_content_types.py b/backend-services/tests/test_compression_content_types.py index f4092aa..3317838 100644 --- a/backend-services/tests/test_compression_content_types.py +++ b/backend-services/tests/test_compression_content_types.py @@ -9,13 +9,14 @@ Verifies compression works correctly with: - Various HTTP methods """ -import pytest import gzip -import json import io +import json import os import time +import pytest + @pytest.mark.asyncio async def test_compression_with_rest_gateway_json(client): @@ -26,7 +27,7 @@ async def test_compression_with_rest_gateway_json(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -34,8 +35,7 @@ async def test_compression_with_rest_gateway_json(client): # Test JSON response (typical REST response) r = await client.get( - '/platform/api', - headers={'Accept-Encoding': 'gzip', 'Accept': 'application/json'} + '/platform/api', headers={'Accept-Encoding': 'gzip', 'Accept': 'application/json'} ) assert r.status_code == 200 @@ -54,10 +54,10 @@ async def test_compression_with_rest_gateway_json(client): compressed_size = len(compressed_buffer.getvalue()) ratio = (1 - (compressed_size / uncompressed_size)) * 100 - print(f"\nREST JSON Compression:") - print(f" Original: {uncompressed_size} bytes") - print(f" Compressed: {compressed_size} bytes") - print(f" Ratio: {ratio:.1f}% reduction") + print('\nREST JSON Compression:') + print(f' Original: {uncompressed_size} bytes') + print(f' Compressed: {compressed_size} bytes') + print(f' Ratio: {ratio:.1f}% reduction') # JSON typically compresses well assert ratio > 20 @@ -69,7 +69,8 @@ async def test_compression_with_xml_content(client): # XML is very verbose and should compress extremely well # Create a mock XML response - xml_content = """ + xml_content = ( + """ @@ -83,7 +84,9 @@ async def test_compression_with_xml_content(client): - """ * 3 # Repeat to make it larger + """ + * 3 + ) # Repeat to make it larger uncompressed_size = len(xml_content.encode('utf-8')) @@ -95,13 +98,13 @@ async def test_compression_with_xml_content(client): ratio = (1 - (compressed_size / uncompressed_size)) * 100 - print(f"\nXML/SOAP Compression:") - print(f" Original: {uncompressed_size} bytes") - print(f" Compressed: {compressed_size} bytes") - print(f" Ratio: {ratio:.1f}% reduction") + print('\nXML/SOAP Compression:') + print(f' Original: {uncompressed_size} bytes') + print(f' Compressed: {compressed_size} bytes') + print(f' Ratio: {ratio:.1f}% reduction') # XML should compress very well (lots of repetitive tags) - assert ratio > 60, f"XML should compress >60%, got {ratio:.1f}%" + assert ratio > 60, f'XML should compress >60%, got {ratio:.1f}%' @pytest.mark.asyncio @@ -111,7 +114,7 @@ async def test_compression_with_graphql_style_response(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -132,10 +135,10 @@ async def test_compression_with_graphql_style_response(client): compressed_size = len(compressed_buffer.getvalue()) ratio = (1 - (compressed_size / uncompressed_size)) * 100 - print(f"\nGraphQL-style Response Compression:") - print(f" Original: {uncompressed_size} bytes") - print(f" Compressed: {compressed_size} bytes") - print(f" Ratio: {ratio:.1f}% reduction") + print('\nGraphQL-style Response Compression:') + print(f' Original: {uncompressed_size} bytes') + print(f' Compressed: {compressed_size} bytes') + print(f' Ratio: {ratio:.1f}% reduction') @pytest.mark.asyncio @@ -150,24 +153,20 @@ async def test_compression_with_post_requests(client): 'api_allowed_groups': ['ALL'], 'api_servers': ['http://example.com'], 'api_type': 'REST', - 'active': True + 'active': True, } # Authenticate first login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) assert r_auth.status_code == 200 # POST with compression - r = await client.post( - '/platform/api', - json=api_payload, - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.post('/platform/api', json=api_payload, headers={'Accept-Encoding': 'gzip'}) assert r.status_code in (200, 201) # Should get valid response @@ -176,7 +175,7 @@ async def test_compression_with_post_requests(client): # Cleanup try: - await client.delete(f"/platform/api/{api_payload['api_name']}/v1") + await client.delete(f'/platform/api/{api_payload["api_name"]}/v1') except Exception: pass @@ -187,7 +186,7 @@ async def test_compression_with_put_requests(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -203,21 +202,17 @@ async def test_compression_with_put_requests(client): 'api_allowed_groups': ['ALL'], 'api_servers': ['http://example.com'], 'api_type': 'REST', - 'active': True + 'active': True, } r = await client.post('/platform/api', json=api_payload) assert r.status_code in (200, 201) # Update with PUT - update_payload = { - 'api_description': 'Updated description with compression test' - } + update_payload = {'api_description': 'Updated description with compression test'} r = await client.put( - f'/platform/api/{api_name}/v1', - json=update_payload, - headers={'Accept-Encoding': 'gzip'} + f'/platform/api/{api_name}/v1', json=update_payload, headers={'Accept-Encoding': 'gzip'} ) assert r.status_code == 200 @@ -238,7 +233,7 @@ async def test_compression_with_delete_requests(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -254,17 +249,14 @@ async def test_compression_with_delete_requests(client): 'api_allowed_groups': ['ALL'], 'api_servers': ['http://example.com'], 'api_type': 'REST', - 'active': True + 'active': True, } r = await client.post('/platform/api', json=api_payload) assert r.status_code in (200, 201) # Delete with compression - r = await client.delete( - f'/platform/api/{api_name}/v1', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.delete(f'/platform/api/{api_name}/v1', headers={'Accept-Encoding': 'gzip'}) assert r.status_code in (200, 204) @@ -278,10 +270,7 @@ async def test_compression_with_error_responses(client): pass # Try to access protected endpoint without auth - r = await client.get( - '/platform/api', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/platform/api', headers={'Accept-Encoding': 'gzip'}) # Should be unauthorized assert r.status_code in (401, 403) @@ -296,17 +285,14 @@ async def test_compression_with_list_responses(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) assert r_auth.status_code == 200 # Get list of users - r = await client.get( - '/platform/user', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/platform/user', headers={'Accept-Encoding': 'gzip'}) assert r.status_code == 200 data = r.json() @@ -322,7 +308,7 @@ async def test_compression_consistent_across_content_types(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -351,8 +337,8 @@ async def test_compression_consistent_across_content_types(client): if ratios: avg_ratio = sum(ratios) / len(ratios) - print(f"\nAverage compression ratio across endpoints: {avg_ratio:.1f}%") - print(f"Individual ratios: {[f'{r:.1f}%' for r in ratios]}") + print(f'\nAverage compression ratio across endpoints: {avg_ratio:.1f}%') + print(f'Individual ratios: {[f"{r:.1f}%" for r in ratios]}') # Mark all tests diff --git a/backend-services/tests/test_compression_cpu_impact.py b/backend-services/tests/test_compression_cpu_impact.py index 8565325..4a3f241 100644 --- a/backend-services/tests/test_compression_cpu_impact.py +++ b/backend-services/tests/test_compression_cpu_impact.py @@ -8,12 +8,12 @@ Tests realistic scenarios: 4. Memory allocation patterns during compression """ -import pytest import gzip -import json import io +import json import time -import os + +import pytest def create_realistic_response(size_category): @@ -23,7 +23,7 @@ def create_realistic_response(size_category): return { 'status': 'success', 'data': {'id': 123, 'name': 'John Doe'}, - 'timestamp': '2025-01-18T10:30:00Z' + 'timestamp': '2025-01-18T10:30:00Z', } elif size_category == 'medium': # Typical REST API response (1-10 KB) @@ -35,11 +35,11 @@ def create_realistic_response(size_category): 'name': f'User {i}', 'email': f'user{i}@example.com', 'role': 'developer', - 'created_at': '2025-01-15T10:00:00Z' + 'created_at': '2025-01-15T10:00:00Z', } for i in range(50) ], - 'pagination': {'page': 1, 'total': 500} + 'pagination': {'page': 1, 'total': 500}, } elif size_category == 'large': # Large API list (10-50 KB) @@ -58,12 +58,12 @@ def create_realistic_response(size_category): 'updated_at': '2025-01-18T15:30:00Z', 'views': 1234, 'likes': 567, - 'reviews_count': 42 - } + 'reviews_count': 42, + }, } for i in range(200) ], - 'pagination': {'page': 1, 'per_page': 200, 'total': 2000} + 'pagination': {'page': 1, 'per_page': 200, 'total': 2000}, } else: # very_large # Very large response (50-100 KB) @@ -74,10 +74,10 @@ def create_realistic_response(size_category): 'id': i, 'name': f'Item {i}', 'description': 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. ' * 5, - 'attributes': {f'attr_{j}': f'value_{j}' for j in range(20)} + 'attributes': {f'attr_{j}': f'value_{j}' for j in range(20)}, } for i in range(500) - ] + ], } @@ -111,7 +111,7 @@ def benchmark_compression(data, level, iterations=1000): 'compressed_size': avg_compressed_size, 'compression_ratio': compression_ratio, 'avg_time_ms': avg_time_ms, - 'throughput_per_core': 1000 / avg_time_ms if avg_time_ms > 0 else 0 + 'throughput_per_core': 1000 / avg_time_ms if avg_time_ms > 0 else 0, } @@ -119,53 +119,59 @@ def benchmark_compression(data, level, iterations=1000): async def test_cpu_impact_by_response_size(): """Measure CPU impact across different response sizes""" - print(f"\n{'='*80}") - print(f"CPU IMPACT ANALYSIS - BY RESPONSE SIZE") - print(f"{'='*80}\n") + print(f'\n{"=" * 80}') + print('CPU IMPACT ANALYSIS - BY RESPONSE SIZE') + print(f'{"=" * 80}\n') sizes = ['small', 'medium', 'large', 'very_large'] levels = [1, 4, 6, 9] for size in sizes: data = create_realistic_response(size) - print(f"\n{size.upper()} Response:") - print(f"{'-'*80}") + print(f'\n{size.upper()} Response:') + print(f'{"-" * 80}') for level in levels: result = benchmark_compression(data, level, iterations=500) - print(f"\n Level {level}:") - print(f" Size: {result['uncompressed_size']:,} → {result['compressed_size']:,} bytes") - print(f" Compression ratio: {result['compression_ratio']:.1f}%") - print(f" CPU time: {result['avg_time_ms']:.3f} ms/request") - print(f" Max throughput: {result['throughput_per_core']:.0f} req/sec (single core)") + print(f'\n Level {level}:') + print( + f' Size: {result["uncompressed_size"]:,} → {result["compressed_size"]:,} bytes' + ) + print(f' Compression ratio: {result["compression_ratio"]:.1f}%') + print(f' CPU time: {result["avg_time_ms"]:.3f} ms/request') + print( + f' Max throughput: {result["throughput_per_core"]:.0f} req/sec (single core)' + ) @pytest.mark.asyncio async def test_cpu_overhead_on_total_request_time(): """Calculate compression overhead as % of total request time""" - print(f"\n{'='*80}") - print(f"COMPRESSION OVERHEAD AS % OF TOTAL REQUEST TIME") - print(f"{'='*80}\n") + print(f'\n{"=" * 80}') + print('COMPRESSION OVERHEAD AS % OF TOTAL REQUEST TIME') + print(f'{"=" * 80}\n') # Realistic request times for different operations base_times = { - 'health_check': 2, # Very fast - 'auth': 50, # JWT generation/verification - 'simple_query': 30, # Database lookup - 'list_query': 80, # Multiple DB queries - 'complex_query': 150, # Joins, aggregations + 'health_check': 2, # Very fast + 'auth': 50, # JWT generation/verification + 'simple_query': 30, # Database lookup + 'list_query': 80, # Multiple DB queries + 'complex_query': 150, # Joins, aggregations 'upstream_proxy': 200, # Proxying to upstream API } data = create_realistic_response('medium') - print(f"Medium response size: {len(json.dumps(data, separators=(',', ':')).encode('utf-8')):,} bytes\n") + print( + f'Medium response size: {len(json.dumps(data, separators=(",", ":")).encode("utf-8")):,} bytes\n' + ) for operation, base_time_ms in base_times.items(): - print(f"\n{operation.replace('_', ' ').title()} (base: {base_time_ms}ms):") - print(f"{'-'*80}") + print(f'\n{operation.replace("_", " ").title()} (base: {base_time_ms}ms):') + print(f'{"-" * 80}') for level in [1, 4, 6, 9]: result = benchmark_compression(data, level, iterations=500) @@ -174,37 +180,39 @@ async def test_cpu_overhead_on_total_request_time(): overhead_pct = (compression_time / total_time) * 100 throughput_reduction = (compression_time / base_time_ms) * 100 - print(f" Level {level}: {compression_time:.3f}ms → " - f"total {total_time:.1f}ms " - f"({overhead_pct:.1f}% overhead, " - f"{throughput_reduction:.1f}% slower)") + print( + f' Level {level}: {compression_time:.3f}ms → ' + f'total {total_time:.1f}ms ' + f'({overhead_pct:.1f}% overhead, ' + f'{throughput_reduction:.1f}% slower)' + ) @pytest.mark.asyncio async def test_realistic_production_scenario(): """Simulate production workload with compression""" - print(f"\n{'='*80}") - print(f"REALISTIC PRODUCTION SCENARIO") - print(f"{'='*80}\n") + print(f'\n{"=" * 80}') + print('REALISTIC PRODUCTION SCENARIO') + print(f'{"=" * 80}\n') # Realistic traffic mix workload = [ - ('small', 0.30, 10), # 30% small responses (health checks, simple GETs) - ('medium', 0.50, 40), # 50% medium responses (typical API calls) - ('large', 0.15, 100), # 15% large responses (list endpoints) - ('very_large', 0.05, 200) # 5% very large (export/reports) + ('small', 0.30, 10), # 30% small responses (health checks, simple GETs) + ('medium', 0.50, 40), # 50% medium responses (typical API calls) + ('large', 0.15, 100), # 15% large responses (list endpoints) + ('very_large', 0.05, 200), # 5% very large (export/reports) ] - print("Traffic Mix:") + print('Traffic Mix:') for size, percentage, base_time in workload: - print(f" {size:12s}: {percentage*100:>5.1f}% (base processing: {base_time}ms)") + print(f' {size:12s}: {percentage * 100:>5.1f}% (base processing: {base_time}ms)') - print(f"\n{'='*80}") + print(f'\n{"=" * 80}') for level in [1, 4, 6, 9]: - print(f"\nCompression Level {level}:") - print(f"{'-'*80}") + print(f'\nCompression Level {level}:') + print(f'{"-" * 80}') total_time_without_compression = 0 total_time_with_compression = 0 @@ -230,43 +238,49 @@ async def test_realistic_production_scenario(): total_time_with_compression += weighted_base_time + weighted_compression_time total_bytes_saved += weighted_bytes_saved - overhead_pct = ((total_time_with_compression - total_time_without_compression) / - total_time_without_compression) * 100 + overhead_pct = ( + (total_time_with_compression - total_time_without_compression) + / total_time_without_compression + ) * 100 # Calculate max RPS reduction rps_without = 1000 / total_time_without_compression rps_with = 1000 / total_time_with_compression rps_reduction_pct = ((rps_without - rps_with) / rps_without) * 100 - print(f" Avg request time: {total_time_without_compression:.1f}ms → {total_time_with_compression:.1f}ms") - print(f" CPU overhead: {overhead_pct:.1f}%") - print(f" Max RPS (1 core): {rps_without:.1f} → {rps_with:.1f} ({rps_reduction_pct:.1f}% reduction)") - print(f" Avg bytes saved: {total_bytes_saved:.0f} bytes/request") + print( + f' Avg request time: {total_time_without_compression:.1f}ms → {total_time_with_compression:.1f}ms' + ) + print(f' CPU overhead: {overhead_pct:.1f}%') + print( + f' Max RPS (1 core): {rps_without:.1f} → {rps_with:.1f} ({rps_reduction_pct:.1f}% reduction)' + ) + print(f' Avg bytes saved: {total_bytes_saved:.0f} bytes/request') @pytest.mark.asyncio async def test_two_vcpu_capacity_analysis(): """Calculate realistic capacity for 2 vCPU instance""" - print(f"\n{'='*80}") - print(f"2 vCPU AWS LIGHTSAIL CAPACITY ANALYSIS") - print(f"{'='*80}\n") + print(f'\n{"=" * 80}') + print('2 vCPU AWS LIGHTSAIL CAPACITY ANALYSIS') + print(f'{"=" * 80}\n') # Workload parameters workload = [ - ('small', 0.30, 10, False), # Not compressed - ('medium', 0.50, 40, True), # Compressed - ('large', 0.15, 100, True), # Compressed - ('very_large', 0.05, 200, True) # Compressed + ('small', 0.30, 10, False), # Not compressed + ('medium', 0.50, 40, True), # Compressed + ('large', 0.15, 100, True), # Compressed + ('very_large', 0.05, 200, True), # Compressed ] - print("AWS Lightsail 1GB RAM, 2 vCPUs") - print("Single worker mode (MEM_OR_EXTERNAL=MEM)") - print(f"\n{'='*80}\n") + print('AWS Lightsail 1GB RAM, 2 vCPUs') + print('Single worker mode (MEM_OR_EXTERNAL=MEM)') + print(f'\n{"=" * 80}\n') for level in [1, 4, 6, 9]: - print(f"Compression Level {level}:") - print(f"{'-'*80}") + print(f'Compression Level {level}:') + print(f'{"-" * 80}') total_cpu_time = 0 total_bytes_uncompressed = 0 @@ -302,24 +316,28 @@ async def test_two_vcpu_capacity_analysis(): realistic_rps = max_rps_two_cores * 1.3 # Async bonus # Transfer limits - avg_transfer_per_request = (total_bytes_uncompressed + total_bytes_compressed) / 2 # In + Out + avg_transfer_per_request = ( + total_bytes_uncompressed + total_bytes_compressed + ) / 2 # In + Out monthly_requests_at_1tb = (1000 * 1024 * 1024 * 1024) / avg_transfer_per_request monthly_rps_limit = monthly_requests_at_1tb / (30 * 24 * 60 * 60) compression_ratio = (1 - total_bytes_compressed / total_bytes_uncompressed) * 100 - print(f" CPU time per request: {total_cpu_time:.1f} ms") - print(f" Max RPS (CPU-limited): {realistic_rps:.1f} RPS") - print(f" Avg response size: {total_bytes_uncompressed:.0f} → {total_bytes_compressed:.0f} bytes") - print(f" Compression ratio: {compression_ratio:.1f}%") - print(f" Transfer per request: {avg_transfer_per_request:.0f} bytes (req+resp)") - print(f" Max RPS (transfer-limited): {monthly_rps_limit:.1f} RPS") - print(f" Monthly capacity (1TB): {monthly_requests_at_1tb/1_000_000:.1f}M requests") + print(f' CPU time per request: {total_cpu_time:.1f} ms') + print(f' Max RPS (CPU-limited): {realistic_rps:.1f} RPS') + print( + f' Avg response size: {total_bytes_uncompressed:.0f} → {total_bytes_compressed:.0f} bytes' + ) + print(f' Compression ratio: {compression_ratio:.1f}%') + print(f' Transfer per request: {avg_transfer_per_request:.0f} bytes (req+resp)') + print(f' Max RPS (transfer-limited): {monthly_rps_limit:.1f} RPS') + print(f' Monthly capacity (1TB): {monthly_requests_at_1tb / 1_000_000:.1f}M requests') if monthly_rps_limit < realistic_rps: - print(f" ⚠️ BOTTLENECK: Transfer (CPU can handle {realistic_rps:.1f} RPS)") + print(f' ⚠️ BOTTLENECK: Transfer (CPU can handle {realistic_rps:.1f} RPS)') else: - print(f" ⚠️ BOTTLENECK: CPU (transfer allows {monthly_rps_limit:.1f} RPS)") + print(f' ⚠️ BOTTLENECK: CPU (transfer allows {monthly_rps_limit:.1f} RPS)') print() @@ -327,17 +345,17 @@ async def test_two_vcpu_capacity_analysis(): async def test_recommended_production_level(): """Determine optimal compression level for production""" - print(f"\n{'='*80}") - print(f"PRODUCTION COMPRESSION LEVEL RECOMMENDATION") - print(f"{'='*80}\n") + print(f'\n{"=" * 80}') + print('PRODUCTION COMPRESSION LEVEL RECOMMENDATION') + print(f'{"=" * 80}\n') data = create_realistic_response('medium') - print("Criteria:") - print(" 1. Maximize bandwidth savings") - print(" 2. Minimize CPU overhead") - print(" 3. Balance throughput vs. transfer") - print(f"\n{'='*80}\n") + print('Criteria:') + print(' 1. Maximize bandwidth savings') + print(' 2. Minimize CPU overhead') + print(' 3. Balance throughput vs. transfer') + print(f'\n{"=" * 80}\n') recommendations = [] @@ -348,26 +366,28 @@ async def test_recommended_production_level(): # Higher compression ratio is good, lower CPU time is good efficiency = result['compression_ratio'] / result['avg_time_ms'] - recommendations.append({ - 'level': level, - 'ratio': result['compression_ratio'], - 'time': result['avg_time_ms'], - 'efficiency': efficiency - }) + recommendations.append( + { + 'level': level, + 'ratio': result['compression_ratio'], + 'time': result['avg_time_ms'], + 'efficiency': efficiency, + } + ) - print(f"Level {level}:") - print(f" Compression ratio: {result['compression_ratio']:.1f}%") - print(f" CPU time: {result['avg_time_ms']:.3f} ms") - print(f" Efficiency score: {efficiency:.1f}") + print(f'Level {level}:') + print(f' Compression ratio: {result["compression_ratio"]:.1f}%') + print(f' CPU time: {result["avg_time_ms"]:.3f} ms') + print(f' Efficiency score: {efficiency:.1f}') print() # Find best efficiency best = max(recommendations, key=lambda x: x['efficiency']) - print(f"{'='*80}") - print(f"RECOMMENDATION: Level {best['level']}") - print(f" Best efficiency score: {best['efficiency']:.1f}") - print(f" Compression: {best['ratio']:.1f}%") - print(f" CPU cost: {best['time']:.3f} ms") + print(f'{"=" * 80}') + print(f'RECOMMENDATION: Level {best["level"]}') + print(f' Best efficiency score: {best["efficiency"]:.1f}') + print(f' Compression: {best["ratio"]:.1f}%') + print(f' CPU cost: {best["time"]:.3f} ms') pytestmark = [pytest.mark.benchmark, pytest.mark.cpu] diff --git a/backend-services/tests/test_compression_size_reduction.py b/backend-services/tests/test_compression_size_reduction.py index f70dd55..8f416ca 100644 --- a/backend-services/tests/test_compression_size_reduction.py +++ b/backend-services/tests/test_compression_size_reduction.py @@ -7,11 +7,12 @@ These tests verify that: 3. Compression settings affect the compression ratio """ -import pytest import gzip +import io import json import os -import io + +import pytest @pytest.mark.asyncio @@ -20,7 +21,7 @@ async def test_json_compression_ratio(client): # Authenticate to get access to endpoints login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -44,15 +45,15 @@ async def test_json_compression_ratio(client): # Calculate compression ratio compression_ratio = (1 - (compressed_size / uncompressed_size)) * 100 - print(f"\nJSON Compression Stats:") - print(f" Uncompressed: {uncompressed_size} bytes") - print(f" Compressed: {compressed_size} bytes") - print(f" Ratio: {compression_ratio:.1f}% reduction") + print('\nJSON Compression Stats:') + print(f' Uncompressed: {uncompressed_size} bytes') + print(f' Compressed: {compressed_size} bytes') + print(f' Ratio: {compression_ratio:.1f}% reduction') # JSON should compress well (typically 60-80%) # But only if response is large enough if uncompressed_size > 500: - assert compression_ratio > 30, f"Expected >30% compression, got {compression_ratio:.1f}%" + assert compression_ratio > 30, f'Expected >30% compression, got {compression_ratio:.1f}%' @pytest.mark.asyncio @@ -61,7 +62,7 @@ async def test_large_list_compression(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -78,11 +79,11 @@ async def test_large_list_compression(client): 'api_allowed_groups': ['ALL'], 'api_servers': ['http://example.com'], 'api_type': 'REST', - 'active': True + 'active': True, } r = await client.post('/platform/api', json=api_payload) if r.status_code in (200, 201): - test_apis.append(f"compression-test-{i}") + test_apis.append(f'compression-test-{i}') # Get the full list r = await client.get('/platform/api') @@ -100,10 +101,10 @@ async def test_large_list_compression(client): compression_ratio = (1 - (compressed_size / uncompressed_size)) * 100 - print(f"\nLarge List Compression Stats:") - print(f" Uncompressed: {uncompressed_size} bytes") - print(f" Compressed: {compressed_size} bytes") - print(f" Ratio: {compression_ratio:.1f}% reduction") + print('\nLarge List Compression Stats:') + print(f' Uncompressed: {uncompressed_size} bytes') + print(f' Compressed: {compressed_size} bytes') + print(f' Ratio: {compression_ratio:.1f}% reduction') # Cleanup for api_name in test_apis: @@ -114,7 +115,7 @@ async def test_large_list_compression(client): # Should achieve good compression on repeated data if uncompressed_size > 1000: - assert compression_ratio > 40, f"Expected >40% compression for large list" + assert compression_ratio > 40, 'Expected >40% compression for large list' @pytest.mark.asyncio @@ -123,19 +124,14 @@ async def test_compression_bandwidth_savings(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) assert r_auth.status_code == 200 # Test multiple endpoints - endpoints = [ - '/platform/api', - '/platform/user', - '/platform/role', - '/platform/group', - ] + endpoints = ['/platform/api', '/platform/user', '/platform/role', '/platform/group'] total_uncompressed = 0 total_compressed = 0 @@ -161,14 +157,14 @@ async def test_compression_bandwidth_savings(client): if total_uncompressed > 0: overall_ratio = (1 - (total_compressed / total_uncompressed)) * 100 - print(f"\nOverall Bandwidth Savings:") - print(f" Total uncompressed: {total_uncompressed} bytes") - print(f" Total compressed: {total_compressed} bytes") - print(f" Overall ratio: {overall_ratio:.1f}% reduction") - print(f" Bandwidth saved: {total_uncompressed - total_compressed} bytes") + print('\nOverall Bandwidth Savings:') + print(f' Total uncompressed: {total_uncompressed} bytes') + print(f' Total compressed: {total_compressed} bytes') + print(f' Overall ratio: {overall_ratio:.1f}% reduction') + print(f' Bandwidth saved: {total_uncompressed - total_compressed} bytes') # Should see significant savings - assert overall_ratio > 0, "Compression should reduce size" + assert overall_ratio > 0, 'Compression should reduce size' @pytest.mark.asyncio @@ -177,7 +173,7 @@ async def test_compression_level_affects_ratio(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -199,15 +195,12 @@ async def test_compression_level_affects_ratio(client): gz.write(json_str.encode('utf-8')) compressed_size = len(compressed_buffer.getvalue()) ratio = (1 - (compressed_size / uncompressed_size)) * 100 - compression_results[level] = { - 'size': compressed_size, - 'ratio': ratio - } + compression_results[level] = {'size': compressed_size, 'ratio': ratio} - print(f"\nCompression Level Comparison:") - print(f" Uncompressed: {uncompressed_size} bytes") + print('\nCompression Level Comparison:') + print(f' Uncompressed: {uncompressed_size} bytes') for level, result in compression_results.items(): - print(f" Level {level}: {result['size']} bytes ({result['ratio']:.1f}% reduction)") + print(f' Level {level}: {result["size"]} bytes ({result["ratio"]:.1f}% reduction)') # Higher compression levels should achieve better (or equal) compression # Level 9 should be <= Level 6 <= Level 1 in size @@ -226,7 +219,7 @@ async def test_minimum_size_threshold(client): response_content = r.content response_size = len(response_content) - print(f"\nSmall Response Size: {response_size} bytes") + print(f'\nSmall Response Size: {response_size} bytes') # If response is smaller than 500 bytes (default minimum_size), # compressing it may not be worth the CPU overhead @@ -239,7 +232,7 @@ async def test_compression_transfer_savings_calculation(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -269,12 +262,12 @@ async def test_compression_transfer_savings_calculation(client): monthly_savings_mb = monthly_savings_bytes / (1024 * 1024) monthly_savings_gb = monthly_savings_mb / 1024 - print(f"\nTransfer Savings Estimate:") - print(f" Compression ratio: {compression_ratio:.1f}%") - print(f" Bytes saved per request: {bytes_saved_per_request}") - print(f" Monthly requests: {requests_per_month:,}") - print(f" Monthly bandwidth saved: {monthly_savings_gb:.2f} GB") - print(f" Annual bandwidth saved: {monthly_savings_gb * 12:.2f} GB") + print('\nTransfer Savings Estimate:') + print(f' Compression ratio: {compression_ratio:.1f}%') + print(f' Bytes saved per request: {bytes_saved_per_request}') + print(f' Monthly requests: {requests_per_month:,}') + print(f' Monthly bandwidth saved: {monthly_savings_gb:.2f} GB') + print(f' Annual bandwidth saved: {monthly_savings_gb * 12:.2f} GB') # Should save significant bandwidth assert bytes_saved_per_request >= 0 diff --git a/backend-services/tests/test_config_import_export_extended.py b/backend-services/tests/test_config_import_export_extended.py index 691d253..2de71e1 100644 --- a/backend-services/tests/test_config_import_export_extended.py +++ b/backend-services/tests/test_config_import_export_extended.py @@ -1,6 +1,6 @@ -import json import pytest + @pytest.mark.asyncio async def test_export_all_basic(authed_client): r = await authed_client.get('/platform/config/export/all') @@ -12,79 +12,132 @@ async def test_export_all_basic(authed_client): assert isinstance(data.get('routings'), list) assert isinstance(data.get('endpoints'), list) + @pytest.mark.asyncio -@pytest.mark.parametrize('path', [ - '/platform/config/export/apis', - '/platform/config/export/roles', - '/platform/config/export/groups', - '/platform/config/export/routings', - '/platform/config/export/endpoints', -]) +@pytest.mark.parametrize( + 'path', + [ + '/platform/config/export/apis', + '/platform/config/export/roles', + '/platform/config/export/groups', + '/platform/config/export/routings', + '/platform/config/export/endpoints', + ], +) async def test_export_lists(authed_client, path): r = await authed_client.get(path) assert r.status_code == 200 + @pytest.mark.asyncio async def test_export_single_api_with_endpoints(authed_client): async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _create_endpoint(c, n, v, m, u): - payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'} + payload = { + 'api_name': n, + 'api_version': v, + 'endpoint_method': m, + 'endpoint_uri': u, + 'endpoint_description': f'{m} {u}', + } rr = await c.post('/platform/endpoint', json=payload) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'exapi', 'v1') await _create_endpoint(authed_client, 'exapi', 'v1', 'GET', '/status') - r = await authed_client.get('/platform/config/export/apis', params={'api_name': 'exapi', 'api_version': 'v1'}) + r = await authed_client.get( + '/platform/config/export/apis', params={'api_name': 'exapi', 'api_version': 'v1'} + ) assert r.status_code == 200 payload = r.json().get('response') or r.json() assert payload.get('api') and payload.get('endpoints') is not None + @pytest.mark.asyncio async def test_export_endpoints_filter(authed_client): async def _create_api(c, n, v): - payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0} + payload = { + 'api_name': n, + 'api_version': v, + 'api_description': f'{n} {v}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://upstream'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + } rr = await c.post('/platform/api', json=payload) assert rr.status_code in (200, 201) + async def _create_endpoint(c, n, v, m, u): - payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'} + payload = { + 'api_name': n, + 'api_version': v, + 'endpoint_method': m, + 'endpoint_uri': u, + 'endpoint_description': f'{m} {u}', + } rr = await c.post('/platform/endpoint', json=payload) assert rr.status_code in (200, 201) + await _create_api(authed_client, 'filterapi', 'v1') await _create_endpoint(authed_client, 'filterapi', 'v1', 'GET', '/x') - r = await authed_client.get('/platform/config/export/endpoints', params={'api_name': 'filterapi', 'api_version': 'v1'}) + r = await authed_client.get( + '/platform/config/export/endpoints', params={'api_name': 'filterapi', 'api_version': 'v1'} + ) assert r.status_code == 200 eps = (r.json().get('response') or r.json()).get('endpoints') assert isinstance(eps, list) and len(eps) >= 1 + @pytest.mark.asyncio -@pytest.mark.parametrize('sections', [ - {'apis': []}, - {'roles': []}, - {'groups': []}, - {'routings': []}, - {'endpoints': []}, - {'apis': [], 'endpoints': []}, - {'roles': [], 'groups': []}, -]) +@pytest.mark.parametrize( + 'sections', + [ + {'apis': []}, + {'roles': []}, + {'groups': []}, + {'routings': []}, + {'endpoints': []}, + {'apis': [], 'endpoints': []}, + {'roles': [], 'groups': []}, + ], +) async def test_import_various_sections(authed_client, sections): r = await authed_client.post('/platform/config/import', json=sections) assert r.status_code == 200 + @pytest.mark.asyncio async def test_security_restart_pid_missing(authed_client): r = await authed_client.post('/platform/security/restart') assert r.status_code in (202, 409, 403) + @pytest.mark.asyncio async def test_audit_called_on_export(monkeypatch, authed_client): calls = [] import utils.audit_util as au + orig = au._logger + class _L: def info(self, msg): calls.append(msg) + au._logger = _L() try: r = await authed_client.get('/platform/config/export/all') diff --git a/backend-services/tests/test_config_import_tolerates_malformed.py b/backend-services/tests/test_config_import_tolerates_malformed.py new file mode 100644 index 0000000..1e1ff86 --- /dev/null +++ b/backend-services/tests/test_config_import_tolerates_malformed.py @@ -0,0 +1,31 @@ +import pytest + + +@pytest.mark.asyncio +async def test_config_import_ignores_malformed_entries(authed_client): + body = { + 'apis': [ + {'api_name': 'x-only'}, # missing version; ignored + {'api_version': 'v1'}, # missing name; ignored + ], + 'endpoints': [ + {'api_name': 'x', 'endpoint_method': 'GET'} # missing api_version/uri + ], + 'roles': [ + {'bad': 'doc'} # missing role_name + ], + 'groups': [ + {'bad': 'doc'} # missing group_name + ], + 'routings': [ + {'bad': 'doc'} # missing client_key + ], + } + r = await authed_client.post('/platform/config/import', json=body) + assert r.status_code == 200 + payload = r.json().get('response', r.json()) + payload.get('imported') or {} + # Import reports how many items were processed, not how many were actually upserted. + # Verify no valid API was created as a result of malformed entries. + bad_get = await authed_client.get('/platform/api/x-only/v1') + assert bad_get.status_code in (400, 404) diff --git a/backend-services/tests/test_config_permission_specific.py b/backend-services/tests/test_config_permission_specific.py new file mode 100644 index 0000000..f441e0d --- /dev/null +++ b/backend-services/tests/test_config_permission_specific.py @@ -0,0 +1,52 @@ +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200, r.text + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_export_apis_allowed_but_roles_forbidden(authed_client): + uname = 'mgr_apis_only' + pwd = 'PerM1ssionsMore!!' + # Create role granting manage_apis only + role_name = 'apis_manager' + cr = await authed_client.post( + '/platform/role', + json={'role_name': role_name, 'role_description': 'Can export APIs', 'manage_apis': True}, + ) + assert cr.status_code in (200, 201), cr.text + + # Create user assigned to that role + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': 'mgr_apis@example.com', + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + + client = await _login('mgr_apis@example.com', pwd) + + # Allowed + ok = await client.get('/platform/config/export/apis') + assert ok.status_code == 200 + + # Forbidden + no = await client.get('/platform/config/export/roles') + assert no.status_code == 403 diff --git a/backend-services/tests/test_config_permission_specific_more.py b/backend-services/tests/test_config_permission_specific_more.py new file mode 100644 index 0000000..b3bffa1 --- /dev/null +++ b/backend-services/tests/test_config_permission_specific_more.py @@ -0,0 +1,137 @@ +import time + +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200, r.text + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_export_roles_allowed_manage_roles_only(authed_client): + uname = f'roles_mgr_{int(time.time())}' + role_name = f'roles_manager_{int(time.time())}' + pwd = 'ManAgeRoleSStrong1!!' + + cr = await authed_client.post( + '/platform/role', json={'role_name': role_name, 'manage_roles': True} + ) + assert cr.status_code in (200, 201) + + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201), cu.text + + client = await _login(f'{uname}@example.com', pwd) + # Allowed + ok = await client.get('/platform/config/export/roles') + assert ok.status_code == 200 + # Forbidden + no = await client.get('/platform/config/export/apis') + assert no.status_code == 403 + + +@pytest.mark.asyncio +async def test_export_groups_allowed_manage_groups_only(authed_client): + uname = f'groups_mgr_{int(time.time())}' + role_name = f'groups_manager_{int(time.time())}' + pwd = 'ManAgeGroupSStrong1!!' + + cr = await authed_client.post( + '/platform/role', json={'role_name': role_name, 'manage_groups': True} + ) + assert cr.status_code in (200, 201) + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + client = await _login(f'{uname}@example.com', pwd) + ok = await client.get('/platform/config/export/groups') + assert ok.status_code == 200 + no = await client.get('/platform/config/export/roles') + assert no.status_code == 403 + + +@pytest.mark.asyncio +async def test_export_routings_allowed_manage_routings_only(authed_client): + uname = f'rout_mgr_{int(time.time())}' + role_name = f'rout_manager_{int(time.time())}' + pwd = 'ManAgeRoutIngS1!!' + + cr = await authed_client.post( + '/platform/role', json={'role_name': role_name, 'manage_routings': True} + ) + assert cr.status_code in (200, 201) + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + client = await _login(f'{uname}@example.com', pwd) + ok = await client.get('/platform/config/export/routings') + assert ok.status_code == 200 + no = await client.get('/platform/config/export/endpoints') + assert no.status_code == 403 + + +@pytest.mark.asyncio +async def test_export_all_and_import_allowed_manage_gateway(authed_client): + uname = f'gate_mgr_{int(time.time())}' + role_name = f'gate_manager_{int(time.time())}' + pwd = 'ManAgeGateWayStrong1!!' + + cr = await authed_client.post( + '/platform/role', json={'role_name': role_name, 'manage_gateway': True} + ) + assert cr.status_code in (200, 201) + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': role_name, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + client = await _login(f'{uname}@example.com', pwd) + ok = await client.get('/platform/config/export/all') + assert ok.status_code == 200 + imp = await client.post('/platform/config/import', json={'apis': []}) + assert imp.status_code == 200 diff --git a/backend-services/tests/test_config_permissions_matrix.py b/backend-services/tests/test_config_permissions_matrix.py new file mode 100644 index 0000000..dd09764 --- /dev/null +++ b/backend-services/tests/test_config_permissions_matrix.py @@ -0,0 +1,61 @@ +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200, r.text + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_config_export_import_forbidden_for_non_admin(authed_client): + # Create a limited user + uname = 'limited_user' + pwd = 'limited-password-1A!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': 'limited@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + # No manage_* permissions + 'ui_access': True, + 'manage_users': False, + 'manage_apis': False, + 'manage_endpoints': False, + 'manage_groups': False, + 'manage_roles': False, + 'manage_routings': False, + 'manage_gateway': False, + }, + ) + assert cu.status_code in (200, 201), cu.text + + user_client = await _login('limited@example.com', pwd) + + # Export endpoints require manage_* permissions – assert 403s + endpoints = [ + '/platform/config/export/all', + '/platform/config/export/apis', + '/platform/config/export/roles', + '/platform/config/export/groups', + '/platform/config/export/routings', + '/platform/config/export/endpoints', + ] + for ep in endpoints: + r = await user_client.get(ep) + assert r.status_code == 403 + + # Import also requires manage_gateway + imp = await user_client.post('/platform/config/import', json={'apis': []}) + assert imp.status_code == 403 diff --git a/backend-services/tests/test_cookie_domain.py b/backend-services/tests/test_cookie_domain.py index d60ebc5..be36a60 100644 --- a/backend-services/tests/test_cookie_domain.py +++ b/backend-services/tests/test_cookie_domain.py @@ -1,9 +1,10 @@ import os + import pytest + @pytest.mark.asyncio async def test_cookie_domain_set_on_login(client, monkeypatch): - monkeypatch.setenv('COOKIE_DOMAIN', 'testserver') resp = await client.post( '/platform/authorization', diff --git a/backend-services/tests/test_cookies_policy.py b/backend-services/tests/test_cookies_policy.py index 6a5cc68..e5d8a3b 100644 --- a/backend-services/tests/test_cookies_policy.py +++ b/backend-services/tests/test_cookies_policy.py @@ -1,7 +1,9 @@ import os import re + import pytest + def _collect_set_cookie_headers(resp): try: return resp.headers.get_list('set-cookie') @@ -9,64 +11,97 @@ def _collect_set_cookie_headers(resp): raw = resp.headers.get('set-cookie') or '' return [h.strip() for h in raw.split(',') if 'Expires=' not in h] or ([raw] if raw else []) + def _find_cookie_lines(lines, name): name_l = name.lower() + '=' return [l for l in lines if name_l in l.lower()] + @pytest.mark.asyncio async def test_default_samesite_strict_and_secure_false(monkeypatch, client): monkeypatch.delenv('COOKIE_SAMESITE', raising=False) monkeypatch.setenv('HTTPS_ONLY', 'false') - r = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']}) + r = await client.post( + '/platform/authorization', + json={ + 'email': os.environ['DOORMAN_ADMIN_EMAIL'], + 'password': os.environ['DOORMAN_ADMIN_PASSWORD'], + }, + ) assert r.status_code == 200 cookies = _collect_set_cookie_headers(r) atk = _find_cookie_lines(cookies, 'access_token_cookie') csrf = _find_cookie_lines(cookies, 'csrf_token') assert atk and csrf + def has_attr(lines, pattern): return any(re.search(pattern, l, flags=re.I) for l in lines) - assert has_attr(atk, r"samesite\s*=\s*strict") - assert has_attr(csrf, r"samesite\s*=\s*strict") - assert not has_attr(atk, r";\s*secure(\s*;|$)") - assert not has_attr(csrf, r";\s*secure(\s*;|$)") + + assert has_attr(atk, r'samesite\s*=\s*strict') + assert has_attr(csrf, r'samesite\s*=\s*strict') + assert not has_attr(atk, r';\s*secure(\s*;|$)') + assert not has_attr(csrf, r';\s*secure(\s*;|$)') + @pytest.mark.asyncio async def test_cookies_samesite_lax_override(monkeypatch, client): monkeypatch.setenv('COOKIE_SAMESITE', 'Lax') monkeypatch.setenv('HTTPS_ONLY', 'false') - r = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']}) + r = await client.post( + '/platform/authorization', + json={ + 'email': os.environ['DOORMAN_ADMIN_EMAIL'], + 'password': os.environ['DOORMAN_ADMIN_PASSWORD'], + }, + ) assert r.status_code == 200 cookies = _collect_set_cookie_headers(r) atk = _find_cookie_lines(cookies, 'access_token_cookie') csrf = _find_cookie_lines(cookies, 'csrf_token') assert atk and csrf + def has_attr(lines, pattern): return any(re.search(pattern, l, flags=re.I) for l in lines) - assert has_attr(atk, r"samesite\s*=\s*lax") - assert has_attr(csrf, r"samesite\s*=\s*lax") + + assert has_attr(atk, r'samesite\s*=\s*lax') + assert has_attr(csrf, r'samesite\s*=\s*lax') + @pytest.mark.asyncio async def test_secure_flag_toggles_with_https(monkeypatch, client): monkeypatch.setenv('COOKIE_SAMESITE', 'None') monkeypatch.setenv('HTTPS_ONLY', 'false') - r1 = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']}) + r1 = await client.post( + '/platform/authorization', + json={ + 'email': os.environ['DOORMAN_ADMIN_EMAIL'], + 'password': os.environ['DOORMAN_ADMIN_PASSWORD'], + }, + ) assert r1.status_code == 200 cookies1 = _collect_set_cookie_headers(r1) atk1 = _find_cookie_lines(cookies1, 'access_token_cookie') csrf1 = _find_cookie_lines(cookies1, 'csrf_token') + def has_attr(lines, pattern): return any(re.search(pattern, l, flags=re.I) for l in lines) - assert not has_attr(atk1, r";\s*secure(\s*;|$)") - assert not has_attr(csrf1, r";\s*secure(\s*;|$)") + + assert not has_attr(atk1, r';\s*secure(\s*;|$)') + assert not has_attr(csrf1, r';\s*secure(\s*;|$)') monkeypatch.setenv('HTTPS_ONLY', 'true') - r2 = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']}) + r2 = await client.post( + '/platform/authorization', + json={ + 'email': os.environ['DOORMAN_ADMIN_EMAIL'], + 'password': os.environ['DOORMAN_ADMIN_PASSWORD'], + }, + ) assert r2.status_code == 200 cookies2 = _collect_set_cookie_headers(r2) atk2 = _find_cookie_lines(cookies2, 'access_token_cookie') csrf2 = _find_cookie_lines(cookies2, 'csrf_token') - assert has_attr(atk2, r";\s*secure(\s*;|$)") - assert has_attr(csrf2, r";\s*secure(\s*;|$)") - + assert has_attr(atk2, r';\s*secure(\s*;|$)') + assert has_attr(csrf2, r';\s*secure(\s*;|$)') diff --git a/backend-services/tests/test_credit_definition_masking.py b/backend-services/tests/test_credit_definition_masking.py new file mode 100644 index 0000000..5dad242 --- /dev/null +++ b/backend-services/tests/test_credit_definition_masking.py @@ -0,0 +1,35 @@ +import pytest + + +@pytest.mark.asyncio +async def test_credit_definition_masking(authed_client): + group = 'maskgroup' + create = await authed_client.post( + '/platform/credit', + json={ + 'api_credit_group': group, + 'api_key': 'VERY-SECRET-KEY', + 'api_key_header': 'x-api-key', + 'credit_tiers': [ + { + 'tier_name': 'default', + 'credits': 5, + 'input_limit': 0, + 'output_limit': 0, + 'reset_frequency': 'monthly', + } + ], + }, + ) + assert create.status_code in (200, 201), create.text + + r = await authed_client.get(f'/platform/credit/defs/{group}') + assert r.status_code == 200, r.text + body = r.json().get('response', r.json()) + + # Masking rules + assert body.get('api_credit_group') == group + assert body.get('api_key_header') == 'x-api-key' + assert body.get('api_key_present') is True + # Under no circumstance should the API key material be returned + assert 'api_key' not in body diff --git a/backend-services/tests/test_credit_key_rotation.py b/backend-services/tests/test_credit_key_rotation.py new file mode 100644 index 0000000..881ef7e --- /dev/null +++ b/backend-services/tests/test_credit_key_rotation.py @@ -0,0 +1,36 @@ +from datetime import UTC, datetime, timedelta + +import pytest + + +@pytest.mark.asyncio +async def test_credit_key_rotation_logic_resolves_header_and_keys(): + from utils.credit_util import get_credit_api_header + from utils.database import credit_def_collection + + group = 'rotgrp' + credit_def_collection.delete_one({'api_credit_group': group}) + # Insert with rotation in the future + credit_def_collection.insert_one( + { + 'api_credit_group': group, + 'api_key_header': 'x-api-key', + 'api_key': 'old-key', + 'api_key_new': 'new-key', + 'api_key_rotation_expires': datetime.now(UTC) + timedelta(hours=1), + } + ) + + hdr = await get_credit_api_header(group) + assert hdr and hdr[0] == 'x-api-key' + assert isinstance(hdr[1], list) + assert hdr[1][0] == 'old-key' and hdr[1][1] == 'new-key' + + # After rotation expiry, only new key should be returned + credit_def_collection.update_one( + {'api_credit_group': group}, + {'$set': {'api_key_rotation_expires': datetime.now(UTC) - timedelta(seconds=1)}}, + ) + hdr2 = await get_credit_api_header(group) + assert hdr2 and hdr2[0] == 'x-api-key' + assert hdr2[1] == 'new-key' diff --git a/backend-services/tests/test_credits_injection_and_deduction.py b/backend-services/tests/test_credits_injection_and_deduction.py index a15d653..3dfa852 100644 --- a/backend-services/tests/test_credits_injection_and_deduction.py +++ b/backend-services/tests/test_credits_injection_and_deduction.py @@ -1,8 +1,17 @@ import pytest - from tests.test_gateway_routing_limits import _FakeAsyncClient -async def _setup_api_with_credits(client, name='cr', ver='v1', public=False, group='g1', header='X-API-Key', def_key='DEFKEY', enable=True): + +async def _setup_api_with_credits( + client, + name='cr', + ver='v1', + public=False, + group='g1', + header='X-API-Key', + def_key='DEFKEY', + enable=True, +): payload = { 'api_name': name, 'api_version': ver, @@ -18,94 +27,123 @@ async def _setup_api_with_credits(client, name='cr', ver='v1', public=False, gro } r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/p', - 'endpoint_description': 'p' - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/p', + 'endpoint_description': 'p', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) from utils.database import credit_def_collection + credit_def_collection.delete_one({'api_credit_group': group}) - credit_def_collection.insert_one({ - 'api_credit_group': group, - 'api_key_header': header, - 'api_key': def_key, - }) + credit_def_collection.insert_one( + {'api_credit_group': group, 'api_key_header': header, 'api_key': def_key} + ) return name, ver + def _set_user_credits(group: str, available: int, user_key: str | None = None): from utils.database import user_credit_collection + user_credit_collection.delete_one({'username': 'admin'}) doc = { 'username': 'admin', 'users_credits': { - group: { - 'tier_name': 't', - 'available_credits': available, - 'user_api_key': user_key, - } - } + group: {'tier_name': 't', 'available_credits': available, 'user_api_key': user_key} + }, } user_credit_collection.insert_one(doc) + @pytest.mark.asyncio async def test_credit_header_injected_when_enabled(monkeypatch, authed_client): - name, ver = await _setup_api_with_credits(authed_client, name='cr1', public=False, group='g1', header='X-API-Key', def_key='DEF1') + name, ver = await _setup_api_with_credits( + authed_client, name='cr1', public=False, group='g1', header='X-API-Key', def_key='DEF1' + ) _set_user_credits('g1', available=10, user_key=None) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 assert r.json().get('headers', {}).get('X-API-Key') == 'DEF1' + @pytest.mark.asyncio async def test_credit_user_specific_overrides_default_key(monkeypatch, authed_client): - name, ver = await _setup_api_with_credits(authed_client, name='cr2', public=False, group='g2', header='X-API-Key', def_key='DEF2') + name, ver = await _setup_api_with_credits( + authed_client, name='cr2', public=False, group='g2', header='X-API-Key', def_key='DEF2' + ) _set_user_credits('g2', available=10, user_key='USERK') await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 assert r.json().get('headers', {}).get('X-API-Key') == 'USERK' + @pytest.mark.asyncio async def test_credit_not_deducted_for_public_api(monkeypatch, authed_client): - name, ver = await _setup_api_with_credits(authed_client, name='cr3', public=True, group='g3', header='X-API-Key', def_key='DEF3', enable=False) + name, ver = await _setup_api_with_credits( + authed_client, + name='cr3', + public=True, + group='g3', + header='X-API-Key', + def_key='DEF3', + enable=False, + ) _set_user_credits('g3', available=5, user_key='U3') await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 from utils.database import user_credit_collection + doc = user_credit_collection.find_one({'username': 'admin'}) assert doc['users_credits']['g3']['available_credits'] == 5 + @pytest.mark.asyncio async def test_credit_deducted_for_private_api_authenticated(monkeypatch, authed_client): - name, ver = await _setup_api_with_credits(authed_client, name='cr4', public=False, group='g4', header='X-API-Key', def_key='DEF4') + name, ver = await _setup_api_with_credits( + authed_client, name='cr4', public=False, group='g4', header='X-API-Key', def_key='DEF4' + ) _set_user_credits('g4', available=3, user_key='U4') await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 from utils.database import user_credit_collection + doc = user_credit_collection.find_one({'username': 'admin'}) assert doc['users_credits']['g4']['available_credits'] == 2 + @pytest.mark.asyncio async def test_credit_deduction_insufficient_credits_blocks(monkeypatch, authed_client): - name, ver = await _setup_api_with_credits(authed_client, name='cr5', public=False, group='g5', header='X-API-Key', def_key='DEF5') + name, ver = await _setup_api_with_credits( + authed_client, name='cr5', public=False, group='g5', header='X-API-Key', def_key='DEF5' + ) _set_user_credits('g5', available=0, user_key=None) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 401 diff --git a/backend-services/tests/test_endpoint_crud_failures.py b/backend-services/tests/test_endpoint_crud_failures.py index d40c1c9..734db2a 100644 --- a/backend-services/tests/test_endpoint_crud_failures.py +++ b/backend-services/tests/test_endpoint_crud_failures.py @@ -1,13 +1,13 @@ import pytest + @pytest.mark.asyncio async def test_endpoint_create_requires_fields(authed_client): - c = await authed_client.post('/platform/endpoint', json={'api_name': 'x'}) assert c.status_code in (400, 422) + @pytest.mark.asyncio async def test_endpoint_get_nonexistent(authed_client): g = await authed_client.get('/platform/endpoint/GET/na/v1/does/not/exist') assert g.status_code in (400, 404) - diff --git a/backend-services/tests/test_endpoint_validation.py b/backend-services/tests/test_endpoint_validation.py index 217e038..ea62534 100644 --- a/backend-services/tests/test_endpoint_validation.py +++ b/backend-services/tests/test_endpoint_validation.py @@ -1,5 +1,6 @@ import pytest + async def _ensure_api_and_endpoint(client, api_name, version, method, uri): c = await client.post( '/platform/api', @@ -30,6 +31,7 @@ async def _ensure_api_and_endpoint(client, api_name, version, method, uri): assert g.status_code == 200 return g.json().get('endpoint_id') or g.json().get('response', {}).get('endpoint_id') + @pytest.mark.asyncio async def test_endpoint_validation_crud(authed_client): eid = await _ensure_api_and_endpoint(authed_client, 'valapi', 'v1', 'POST', '/do') diff --git a/backend-services/tests/test_endpoint_validation_crud.py b/backend-services/tests/test_endpoint_validation_crud.py new file mode 100644 index 0000000..0c1ba97 --- /dev/null +++ b/backend-services/tests/test_endpoint_validation_crud.py @@ -0,0 +1,61 @@ +import time + +import pytest + + +@pytest.mark.asyncio +async def test_endpoint_validation_crud(authed_client): + name, ver = f'valapi_{int(time.time())}', 'v1' + # Create API and endpoint + ca = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'validation api', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'REST', + 'active': True, + }, + ) + assert ca.status_code in (200, 201) + ce = await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/payload', + 'endpoint_description': 'payload', + }, + ) + assert ce.status_code in (200, 201) + + # Resolve endpoint_id via GET + ge = await authed_client.get(f'/platform/endpoint/POST/{name}/{ver}/payload') + assert ge.status_code == 200 + eid = ge.json().get('endpoint_id') or ge.json().get('response', {}).get('endpoint_id') + assert eid + + schema = {'validation_schema': {'id': {'required': True, 'type': 'string'}}} + # Create validation + cv = await authed_client.post( + '/platform/endpoint/endpoint/validation', + json={'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': schema}, + ) + assert cv.status_code in (200, 201) + + # Get validation + gv = await authed_client.get(f'/platform/endpoint/endpoint/validation/{eid}') + assert gv.status_code in (200, 400, 500) + # Update validation + uv = await authed_client.put( + f'/platform/endpoint/endpoint/validation/{eid}', + json={'validation_enabled': True, 'validation_schema': schema}, + ) + assert uv.status_code in (200, 400, 500) + # Delete validation + dv = await authed_client.delete(f'/platform/endpoint/endpoint/validation/{eid}') + assert dv.status_code == 200 diff --git a/backend-services/tests/test_gateway_body_size_limit.py b/backend-services/tests/test_gateway_body_size_limit.py index 2526cb6..b2ece0a 100644 --- a/backend-services/tests/test_gateway_body_size_limit.py +++ b/backend-services/tests/test_gateway_body_size_limit.py @@ -1,14 +1,27 @@ import pytest + @pytest.mark.asyncio async def test_request_exceeding_max_body_size_returns_413(monkeypatch, authed_client): from conftest import create_endpoint - import services.gateway_service as gs from tests.test_gateway_routing_limits import _FakeAsyncClient - await authed_client.post('/platform/api', json={ - 'api_name': 'bpub', 'api_version': 'v1', 'api_description': 'b', - 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://up'], 'api_type': 'REST', 'api_allowed_retry_count': 0, 'api_public': True - }) + + import services.gateway_service as gs + + await authed_client.post( + '/platform/api', + json={ + 'api_name': 'bpub', + 'api_version': 'v1', + 'api_description': 'b', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) await create_endpoint(authed_client, 'bpub', 'v1', 'POST', '/p') monkeypatch.setenv('MAX_BODY_SIZE_BYTES', '10') headers = {'Content-Type': 'application/json', 'Content-Length': '11'} @@ -18,11 +31,14 @@ async def test_request_exceeding_max_body_size_returns_413(monkeypatch, authed_c body = r.json() assert body.get('error_code') == 'REQ001' + @pytest.mark.asyncio async def test_request_at_limit_is_allowed(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self - import services.gateway_service as gs from tests.test_gateway_routing_limits import _FakeAsyncClient + + import services.gateway_service as gs + name, ver = 'bsz', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/p') @@ -33,6 +49,7 @@ async def test_request_at_limit_is_allowed(monkeypatch, authed_client): r = await authed_client.post(f'/api/rest/{name}/{ver}/p', headers=headers, content='1234567890') assert r.status_code == 200 + @pytest.mark.asyncio async def test_request_without_content_length_is_allowed(monkeypatch, authed_client): """Test that GET requests (no body, no Content-Length) are allowed regardless of limit. @@ -41,8 +58,10 @@ async def test_request_without_content_length_is_allowed(monkeypatch, authed_cli so we test with a GET request instead which naturally has no Content-Length. """ from conftest import create_api, create_endpoint, subscribe_self - import services.gateway_service as gs from tests.test_gateway_routing_limits import _FakeAsyncClient + + import services.gateway_service as gs + name, ver = 'bsz2', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/p') diff --git a/backend-services/tests/test_gateway_caches_permissions.py b/backend-services/tests/test_gateway_caches_permissions.py index ba8668c..497f021 100644 --- a/backend-services/tests/test_gateway_caches_permissions.py +++ b/backend-services/tests/test_gateway_caches_permissions.py @@ -1,22 +1,15 @@ import pytest + @pytest.mark.asyncio async def test_clear_caches_requires_manage_gateway(authed_client): - - rd = await authed_client.put( - '/platform/role/admin', - json={'manage_gateway': False}, - ) + rd = await authed_client.put('/platform/role/admin', json={'manage_gateway': False}) assert rd.status_code in (200, 201) deny = await authed_client.delete('/api/caches') assert deny.status_code == 403 - re = await authed_client.put( - '/platform/role/admin', - json={'manage_gateway': True}, - ) + re = await authed_client.put('/platform/role/admin', json={'manage_gateway': True}) assert re.status_code in (200, 201) ok = await authed_client.delete('/api/caches') assert ok.status_code == 200 - diff --git a/backend-services/tests/test_gateway_cors_preflight_negatives.py b/backend-services/tests/test_gateway_cors_preflight_negatives.py new file mode 100644 index 0000000..4667734 --- /dev/null +++ b/backend-services/tests/test_gateway_cors_preflight_negatives.py @@ -0,0 +1,53 @@ +import pytest + + +@pytest.mark.asyncio +async def test_rest_preflight_header_mismatch_does_not_echo_origin(authed_client): + name, ver = 'corsneg', 'v1' + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'CORS negative', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'REST', + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/x', + 'endpoint_description': 'x', + }, + ) + + r = await authed_client.options( + f'/api/rest/{name}/{ver}/x', + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-Random-Header', + }, + ) + assert r.status_code == 204 + # Current gateway behavior: echoes ACAO based on origin allow-list, even if headers mismatch. + acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get( + 'access-control-allow-origin' + ) + assert acao == 'http://ok.example' + ach = ( + r.headers.get('Access-Control-Allow-Headers') + or r.headers.get('access-control-allow-headers') + or '' + ) + assert 'Content-Type' in ach diff --git a/backend-services/tests/test_gateway_enforcement_and_paths.py b/backend-services/tests/test_gateway_enforcement_and_paths.py index 344d86b..27419d9 100644 --- a/backend-services/tests/test_gateway_enforcement_and_paths.py +++ b/backend-services/tests/test_gateway_enforcement_and_paths.py @@ -1,5 +1,6 @@ import pytest + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None, text_body=None, headers=None): self.status_code = status_code @@ -12,10 +13,12 @@ class _FakeHTTPResponse: def json(self): import json as _json + if self._json_body is None: return _json.loads(self.text or '{}') return self._json_body + class _FakeAsyncClient: def __init__(self, *args, **kwargs): pass @@ -45,22 +48,68 @@ class _FakeAsyncClient: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) async def get(self, url, params=None, headers=None, **kwargs): - return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'GET', + 'url': url, + 'params': dict(params or {}), + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'params': dict(params or {}), 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'POST', + 'url': url, + 'params': dict(params or {}), + 'body': body, + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'params': dict(params or {}), 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'PUT', + 'url': url, + 'params': dict(params or {}), + 'body': body, + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs): - return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'DELETE', + 'url': url, + 'params': dict(params or {}), + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) + @pytest.mark.asyncio async def test_subscription_required_blocks_without_subscription(monkeypatch, authed_client): - name, ver = 'nosub', 'v1' await authed_client.post( '/platform/api', @@ -87,13 +136,14 @@ async def test_subscription_required_blocks_without_subscription(monkeypatch, au ) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/x') assert r.status_code == 403 + @pytest.mark.asyncio async def test_group_required_blocks_when_disallowed_group(monkeypatch, authed_client): - name, ver = 'nogroup', 'v1' await authed_client.post( '/platform/api', @@ -120,17 +170,22 @@ async def test_group_required_blocks_when_disallowed_group(monkeypatch, authed_c ) import routes.gateway_routes as gr + async def _pass_sub(req): return {'sub': 'admin'} + monkeypatch.setattr(gr, 'subscription_required', _pass_sub) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/y') assert r.status_code == 401 + @pytest.mark.asyncio async def test_path_template_matching(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'pathapi', 'v1' await create_api(authed_client, name, ver) @@ -138,33 +193,35 @@ async def test_path_template_matching(monkeypatch, authed_client): await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res/abc123') assert r.status_code == 200 assert r.json().get('url', '').endswith('/res/abc123') + @pytest.mark.asyncio async def test_text_body_forwarding(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'textapi', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/echo') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) payload = b'hello-world' r = await authed_client.post( - f'/api/rest/{name}/{ver}/echo', - headers={'Content-Type': 'text/plain'}, - content=payload, + f'/api/rest/{name}/{ver}/echo', headers={'Content-Type': 'text/plain'}, content=payload ) assert r.status_code == 200 assert r.json().get('body') == payload.decode('utf-8') + @pytest.mark.asyncio async def test_response_header_filtering_excludes_unlisted(monkeypatch, authed_client): - name, ver = 'hdrfilter', 'v1' await authed_client.post( '/platform/api', @@ -196,15 +253,16 @@ async def test_response_header_filtering_excludes_unlisted(monkeypatch, authed_c ) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 assert r.headers.get('X-Upstream') is None + @pytest.mark.asyncio async def test_authorization_field_swap(monkeypatch, authed_client): - name, ver = 'authswap', 'v1' await authed_client.post( '/platform/api', @@ -237,6 +295,7 @@ async def test_authorization_field_swap(monkeypatch, authed_client): ) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/s', headers={'x-token': 'ABC123'}) assert r.status_code == 200 diff --git a/backend-services/tests/test_gateway_flows.py b/backend-services/tests/test_gateway_flows.py index fc7f348..5740a74 100644 --- a/backend-services/tests/test_gateway_flows.py +++ b/backend-services/tests/test_gateway_flows.py @@ -1,19 +1,23 @@ import json as _json -from types import SimpleNamespace + import pytest + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None, text_body=None, headers=None): self.status_code = status_code self._json_body = json_body self.text = text_body if text_body is not None else ('' if json_body is not None else 'OK') - self.headers = headers or {'Content-Type': 'application/json' if json_body is not None else 'text/plain'} + self.headers = headers or { + 'Content-Type': 'application/json' if json_body is not None else 'text/plain' + } def json(self): if self._json_body is None: return _json.loads(self.text or '{}') return self._json_body + class _FakeAsyncClient: def __init__(self, *args, **kwargs): self.kwargs = kwargs @@ -43,22 +47,36 @@ class _FakeAsyncClient: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) async def get(self, url, params=None, headers=None, **kwargs): - return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': params or {}, 'ok': True}) + return _FakeHTTPResponse( + 200, json_body={'method': 'GET', 'url': url, 'params': params or {}, 'ok': True} + ) async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'body': body, 'ok': True}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, json_body={'method': 'POST', 'url': url, 'body': body, 'ok': True} + ) async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'body': body, 'ok': True}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, json_body={'method': 'PUT', 'url': url, 'body': body, 'ok': True} + ) async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs): return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'ok': True}) + @pytest.mark.asyncio async def test_gateway_rest_happy_path(monkeypatch, authed_client): - api_payload = { 'api_name': 'echo', 'api_version': 'v1', @@ -91,11 +109,12 @@ async def test_gateway_rest_happy_path(monkeypatch, authed_client): ) assert sub.status_code in (200, 201) + import routes.gateway_routes as gr import services.gateway_service as gs - import routes.gateway_routes as gr async def _no_limit(request): return None + monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) @@ -104,14 +123,15 @@ async def test_gateway_rest_happy_path(monkeypatch, authed_client): data = gw.json() assert data.get('ok') is True + @pytest.mark.asyncio async def test_gateway_clear_caches(authed_client): r = await authed_client.delete('/api/caches') assert r.status_code == 200 + @pytest.mark.asyncio async def test_gateway_graphql_grpc_soap(monkeypatch, authed_client): - for name in ('graph', 'grpcapi', 'soapapi'): c = await authed_client.post( '/platform/api', @@ -136,22 +156,27 @@ async def test_gateway_graphql_grpc_soap(monkeypatch, authed_client): missing = await authed_client.post('/api/graphql/graph', json={'query': '{ping}'}) assert missing.status_code == 400 - import services.gateway_service as gs import routes.gateway_routes as gr + import services.gateway_service as gs + async def _no_limit2(request): return None + monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit2) async def fake_graphql_gateway(username, request, request_id, start_time, path): from models.response_model import ResponseModel + return ResponseModel(status_code=200, response={'data': {'ping': 'pong'}}).dict() async def fake_grpc_gateway(username, request, request_id, start_time, path): from models.response_model import ResponseModel + return ResponseModel(status_code=200, response={'ok': True}).dict() async def fake_soap_gateway(username, request, request_id, start_time, path): from models.response_model import ResponseModel + return ResponseModel(status_code=200, response='true').dict() monkeypatch.setattr(gs.GatewayService, 'graphql_gateway', staticmethod(fake_graphql_gateway)) @@ -160,15 +185,15 @@ async def test_gateway_graphql_grpc_soap(monkeypatch, authed_client): async def _pass_sub(req): return {'sub': 'admin'} + async def _pass_group(req: object, full_path: str = None, user_to_subscribe=None): return {'sub': 'admin'} + monkeypatch.setattr(gr, 'subscription_required', _pass_sub) monkeypatch.setattr(gr, 'group_required', _pass_group) g = await authed_client.post( - '/api/graphql/graph', - headers={'X-API-Version': 'v1'}, - json={'query': '{ ping }'}, + '/api/graphql/graph', headers={'X-API-Version': 'v1'}, json={'query': '{ ping }'} ) assert g.status_code == 200 diff --git a/backend-services/tests/test_gateway_missing_headers.py b/backend-services/tests/test_gateway_missing_headers.py index 8a46c20..48daabb 100644 --- a/backend-services/tests/test_gateway_missing_headers.py +++ b/backend-services/tests/test_gateway_missing_headers.py @@ -1,14 +1,13 @@ import pytest + @pytest.mark.asyncio async def test_grpc_requires_version_header(authed_client): - r = await authed_client.post('/api/grpc/service/do', json={'data': '{}'}) assert r.status_code == 400 + @pytest.mark.asyncio async def test_graphql_requires_version_header(authed_client): - r = await authed_client.post('/api/graphql/graph', json={'query': '{ ping }'}) assert r.status_code == 400 - diff --git a/backend-services/tests/test_gateway_routing_limits.py b/backend-services/tests/test_gateway_routing_limits.py index 0f75d23..d3dcfad 100644 --- a/backend-services/tests/test_gateway_routing_limits.py +++ b/backend-services/tests/test_gateway_routing_limits.py @@ -1,5 +1,6 @@ import pytest + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None, text_body=None, headers=None): self.status_code = status_code @@ -13,10 +14,12 @@ class _FakeHTTPResponse: def json(self): import json as _json + if self._json_body is None: return _json.loads(self.text or '{}') return self._json_body + class _FakeAsyncClient: def __init__(self, *args, **kwargs): self.kwargs = kwargs @@ -50,30 +53,67 @@ class _FakeAsyncClient: qp = dict(params or {}) except Exception: qp = {} - return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': qp, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={'method': 'GET', 'url': url, 'params': qp, 'headers': headers or {}}, + headers={'X-Upstream': 'yes'}, + ) async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) try: qp = dict(params or {}) except Exception: qp = {} - return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'params': qp, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'POST', + 'url': url, + 'params': qp, + 'body': body, + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) try: qp = dict(params or {}) except Exception: qp = {} - return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'params': qp, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'PUT', + 'url': url, + 'params': qp, + 'body': body, + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs): try: qp = dict(params or {}) except Exception: qp = {} - return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'params': qp, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={'method': 'DELETE', 'url': url, 'params': qp, 'headers': headers or {}}, + headers={'X-Upstream': 'yes'}, + ) + class _NotFoundAsyncClient(_FakeAsyncClient): async def get(self, url, params=None, headers=None, **kwargs): @@ -81,12 +121,15 @@ class _NotFoundAsyncClient(_FakeAsyncClient): qp = dict(params or {}) except Exception: qp = {} - return _FakeHTTPResponse(404, json_body={'ok': False, 'url': url, 'params': qp}, headers={'X-Upstream': 'no'}) + return _FakeHTTPResponse( + 404, json_body={'ok': False, 'url': url, 'params': qp}, headers={'X-Upstream': 'no'} + ) + @pytest.mark.asyncio async def test_routing_precedence_and_round_robin(monkeypatch, authed_client): + from conftest import create_api, subscribe_self - from conftest import create_api, create_endpoint, subscribe_self name, ver = 'routeapi', 'v1' await create_api(authed_client, name, ver) @@ -116,6 +159,7 @@ async def test_routing_precedence_and_round_robin(monkeypatch, authed_client): assert r.status_code in (200, 201) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) a = await authed_client.get(f'/api/rest/{name}/{ver}/ping', headers={'client-key': 'client-1'}) @@ -127,9 +171,11 @@ async def test_routing_precedence_and_round_robin(monkeypatch, authed_client): assert b1.json().get('url', '').startswith('http://ep-a') assert b2.json().get('url', '').startswith('http://ep-b') + @pytest.mark.asyncio async def test_client_routing_round_robin(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'clientrr', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/p') @@ -145,15 +191,18 @@ async def test_client_routing_round_robin(monkeypatch, authed_client): }, ) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r1 = await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={'client-key': 'c1'}) r2 = await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={'client-key': 'c1'}) assert r1.json().get('url', '').startswith('http://r1') assert r2.json().get('url', '').startswith('http://r2') + @pytest.mark.asyncio async def test_api_level_round_robin_when_no_endpoint_servers(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + from conftest import create_endpoint, subscribe_self + name, ver = 'apiround', 'v1' payload = { @@ -173,6 +222,7 @@ async def test_api_level_round_robin_when_no_endpoint_servers(monkeypatch, authe await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r1 = await authed_client.get(f'/api/rest/{name}/{ver}/rr') @@ -180,21 +230,27 @@ async def test_api_level_round_robin_when_no_endpoint_servers(monkeypatch, authe assert r1.json().get('url', '').startswith('http://api-a') assert r2.json().get('url', '').startswith('http://api-b') + @pytest.mark.asyncio async def test_rate_limit_exceeded_returns_429(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'ratelimit', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/ping') await subscribe_self(authed_client, name, ver) from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}}) + + user_collection.update_one( + {'username': 'admin'}, + {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}}, + ) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) ok = await authed_client.get(f'/api/rest/{name}/{ver}/ping') @@ -202,28 +258,34 @@ async def test_rate_limit_exceeded_returns_429(monkeypatch, authed_client): too_many = await authed_client.get(f'/api/rest/{name}/{ver}/ping') assert too_many.status_code == 429 + @pytest.mark.asyncio async def test_throttle_queue_limit_exceeded_returns_429(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'throttleq', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/t') await subscribe_self(authed_client, name, ver) from utils.database import user_collection + user_collection.update_one({'username': 'admin'}, {'$set': {'throttle_queue_limit': 1}}) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) - r1 = await authed_client.get(f'/api/rest/{name}/{ver}/t') + await authed_client.get(f'/api/rest/{name}/{ver}/t') r2 = await authed_client.get(f'/api/rest/{name}/{ver}/t') assert r2.status_code == 429 + @pytest.mark.asyncio async def test_query_params_and_headers_forwarding(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + from conftest import create_endpoint, subscribe_self + name, ver = 'params', 'v1' payload = { 'api_name': name, @@ -241,10 +303,21 @@ async def test_query_params_and_headers_forwarding(monkeypatch, authed_client): await subscribe_self(authed_client, name, ver) from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1000, 'rate_limit_duration_type': 'second', 'throttle_queue_limit': 1000}}) + + user_collection.update_one( + {'username': 'admin'}, + { + '$set': { + 'rate_limit_duration': 1000, + 'rate_limit_duration_type': 'second', + 'throttle_queue_limit': 1000, + } + }, + ) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/q?foo=1&bar=2') @@ -254,15 +327,18 @@ async def test_query_params_and_headers_forwarding(monkeypatch, authed_client): assert r.headers.get('X-Upstream') == 'yes' + @pytest.mark.asyncio async def test_post_body_forwarding_json(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'postfwd', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/echo') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) payload = {'a': 1, 'b': 2} @@ -270,9 +346,9 @@ async def test_post_body_forwarding_json(monkeypatch, authed_client): assert r.status_code == 200 assert r.json().get('body') == payload + @pytest.mark.asyncio async def test_credit_header_injection_and_user_override(monkeypatch, authed_client): - credit_group = 'inject-group' rc = await authed_client.post( '/platform/credit', @@ -281,7 +357,13 @@ async def test_credit_header_injection_and_user_override(monkeypatch, authed_cli 'api_key': 'GROUP-KEY', 'api_key_header': 'x-api-key', 'credit_tiers': [ - {'tier_name': 'default', 'credits': 999, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly'} + { + 'tier_name': 'default', + 'credits': 999, + 'input_limit': 0, + 'output_limit': 0, + 'reset_frequency': 'monthly', + } ], }, ) @@ -319,39 +401,47 @@ async def test_credit_header_injection_and_user_override(monkeypatch, authed_cli ) ur = await authed_client.post( - f'/platform/credit/admin', + '/platform/credit/admin', json={ 'username': 'admin', 'users_credits': { - credit_group: {'tier_name': 'default', 'available_credits': 1, 'user_api_key': 'USER-KEY'} + credit_group: { + 'tier_name': 'default', + 'available_credits': 1, + 'user_api_key': 'USER-KEY', + } }, }, ) assert ur.status_code in (200, 201), ur.text import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/h') assert r.status_code == 200 headers_seen = r.json().get('headers') or {} assert headers_seen.get('x-api-key') == 'USER-KEY' + @pytest.mark.asyncio async def test_gateway_sets_request_id_headers(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'reqid', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/x') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/x') assert r.status_code == 200 assert r.headers.get('X-Request-ID') and r.headers.get('request_id') + @pytest.mark.asyncio async def test_api_disabled_returns_403(monkeypatch, authed_client): - name, ver = 'disabled', 'v1' await authed_client.post( '/platform/api', @@ -384,31 +474,38 @@ async def test_api_disabled_returns_403(monkeypatch, authed_client): ) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/x') assert r.status_code == 403 + @pytest.mark.asyncio async def test_upstream_404_maps_to_gtw005(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'nfupstream', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/z') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _NotFoundAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/z') assert r.status_code == 404 + @pytest.mark.asyncio async def test_endpoint_not_found_returns_404_code_gtw003(monkeypatch, authed_client): from conftest import create_api, subscribe_self + name, ver = 'noep', 'v1' await create_api(authed_client, name, ver) await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/missing') diff --git a/backend-services/tests/test_gateway_validation.py b/backend-services/tests/test_gateway_validation.py index 03729db..ed61724 100644 --- a/backend-services/tests/test_gateway_validation.py +++ b/backend-services/tests/test_gateway_validation.py @@ -1,9 +1,10 @@ import pytest + @pytest.mark.asyncio async def test_rest_payload_validation_blocks_bad_request(authed_client): - from conftest import create_api, create_endpoint, subscribe_self + api_name = 'valrest' version = 'v1' await create_api(authed_client, api_name, version) @@ -26,15 +27,14 @@ async def test_rest_payload_validation_blocks_bad_request(authed_client): ) assert cv.status_code in (200, 201, 400) - r = await authed_client.post( - f'/api/rest/{api_name}/{version}/do', - json={'user': {'name': 'A'}}, - ) + r = await authed_client.post(f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'A'}}) assert r.status_code == 400 + @pytest.mark.asyncio async def test_graphql_payload_validation_blocks_bad_request(authed_client): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'valgql' version = 'v1' await create_api(authed_client, api_name, version) @@ -49,7 +49,6 @@ async def test_graphql_payload_validation_blocks_bad_request(authed_client): schema = { 'validation_schema': { - 'CreateUser.input.name': {'required': True, 'type': 'string', 'min': 2, 'max': 50} } } @@ -68,9 +67,11 @@ async def test_graphql_payload_validation_blocks_bad_request(authed_client): ) assert r.status_code == 400 + @pytest.mark.asyncio async def test_soap_payload_validation_blocks_bad_request(authed_client): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'valsoap' version = 'v1' await create_api(authed_client, api_name, version) @@ -94,8 +95,8 @@ async def test_soap_payload_validation_blocks_bad_request(authed_client): assert cv.status_code in (200, 201, 400) envelope = ( - '' - '' + '' + '' '' 'A' '' @@ -108,9 +109,11 @@ async def test_soap_payload_validation_blocks_bad_request(authed_client): ) assert r.status_code == 400 + @pytest.mark.asyncio async def test_grpc_payload_validation_blocks_bad_request(authed_client): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'valgrpc' version = 'v1' await create_api(authed_client, api_name, version) @@ -141,9 +144,11 @@ async def test_grpc_payload_validation_blocks_bad_request(authed_client): ) assert r.status_code == 400 + @pytest.mark.asyncio async def test_rest_payload_validation_allows_good_request(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'okrest' version = 'v1' await create_api(authed_client, api_name, version) @@ -182,14 +187,19 @@ async def test_rest_payload_validation_allows_good_request(monkeypatch, authed_c return FakeResp() import services.gateway_service as gw + monkeypatch.setattr(gw.httpx, 'AsyncClient', FakeClient) - r = await authed_client.post(f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'Ab'}}) + r = await authed_client.post( + f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'Ab'}} + ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'oksoap' version = 'v1' await create_api(authed_client, api_name, version) @@ -224,11 +234,12 @@ async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_c return FakeResp() import services.gateway_service as gw + monkeypatch.setattr(gw.httpx, 'AsyncClient', FakeClient) envelope = ( - '' - '' + '' + '' '' 'Ab' '' @@ -241,9 +252,11 @@ async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_c ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_graphql_payload_validation_allows_good_request(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'okgql' version = 'v1' await create_api(authed_client, api_name, version) @@ -252,7 +265,11 @@ async def test_graphql_payload_validation_allows_good_request(monkeypatch, authe g = await authed_client.get(f'/platform/endpoint/POST/{api_name}/{version}/graphql') eid = g.json().get('endpoint_id') or g.json().get('response', {}).get('endpoint_id') - schema = {'validation_schema': {'CreateUser.input.name': {'required': True, 'type': 'string', 'min': 2}}} + schema = { + 'validation_schema': { + 'CreateUser.input.name': {'required': True, 'type': 'string', 'min': 2} + } + } await authed_client.post( '/platform/endpoint/endpoint/validation', json={'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': schema}, @@ -279,6 +296,7 @@ async def test_graphql_payload_validation_allows_good_request(monkeypatch, authe return False import services.gateway_service as gw + monkeypatch.setattr(gw, 'Client', FakeClient) query = 'mutation CreateUser($input: UserInput!){ createUser(input: $input){ id } }' @@ -290,9 +308,13 @@ async def test_graphql_payload_validation_allows_good_request(monkeypatch, authe ) assert r.status_code == 200 + @pytest.mark.asyncio -async def test_grpc_payload_validation_allows_good_request_progresses(monkeypatch, authed_client, tmp_path): +async def test_grpc_payload_validation_allows_good_request_progresses( + monkeypatch, authed_client, tmp_path +): from conftest import create_api, create_endpoint, subscribe_self + api_name = 'okgrpc' version = 'v1' await create_api(authed_client, api_name, version) @@ -308,7 +330,9 @@ async def test_grpc_payload_validation_allows_good_request_progresses(monkeypatc ) import os as _os + import services.gateway_service as gw + project_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__))) proto_dir = _os.path.join(project_root, 'proto') _os.makedirs(proto_dir, exist_ok=True) @@ -318,7 +342,9 @@ async def test_grpc_payload_validation_allows_good_request_progresses(monkeypatc def fake_import(name): raise ImportError('fake') - monkeypatch.setattr(gw.importlib, 'import_module', lambda n: (_ for _ in ()).throw(ImportError('fake'))) + monkeypatch.setattr( + gw.importlib, 'import_module', lambda n: (_ for _ in ()).throw(ImportError('fake')) + ) payload = {'method': 'Service.Method', 'message': {'user': {'name': 'Ab'}}} r = await authed_client.post( diff --git a/backend-services/tests/test_ghost_admin.py b/backend-services/tests/test_ghost_admin.py index 394b20e..90171f3 100644 --- a/backend-services/tests/test_ghost_admin.py +++ b/backend-services/tests/test_ghost_admin.py @@ -5,9 +5,11 @@ Super admin should: - Be completely hidden from all user list/get endpoints - Be completely protected from modification/deletion via API """ -import pytest + import os +import pytest + @pytest.mark.asyncio async def test_bootstrap_admin_can_authenticate(client): @@ -15,10 +17,7 @@ async def test_bootstrap_admin_can_authenticate(client): email = os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev') password = os.getenv('DOORMAN_ADMIN_PASSWORD', 'SecPassword!12345') - r = await client.post('/platform/authorization', json={ - 'email': email, - 'password': password - }) + r = await client.post('/platform/authorization', json={'email': email, 'password': password}) assert r.status_code == 200, 'Super admin should be able to authenticate' data = r.json() assert 'access_token' in (data.get('response') or data), 'Should receive access token' @@ -31,7 +30,11 @@ async def test_bootstrap_admin_hidden_from_user_list(authed_client): assert r.status_code == 200 users = r.json() - user_list = users if isinstance(users, list) else (users.get('users') or users.get('response', {}).get('users') or []) + user_list = ( + users + if isinstance(users, list) + else (users.get('users') or users.get('response', {}).get('users') or []) + ) usernames = {u.get('username') for u in user_list} assert 'admin' not in usernames, 'Super admin should not appear in user list' @@ -55,9 +58,7 @@ async def test_bootstrap_admin_get_by_email_returns_404(authed_client): @pytest.mark.asyncio async def test_bootstrap_admin_cannot_be_updated(authed_client): """PUT /platform/user/admin should be blocked.""" - r = await authed_client.put('/platform/user/admin', json={ - 'email': 'new-email@example.com' - }) + r = await authed_client.put('/platform/user/admin', json={'email': 'new-email@example.com'}) assert r.status_code == 403, 'Super admin should not be modifiable' data = r.json() assert 'USR020' in str(data.get('error_code')), 'Should return USR020 error code' @@ -77,13 +78,11 @@ async def test_bootstrap_admin_cannot_be_deleted(authed_client): @pytest.mark.asyncio async def test_bootstrap_admin_password_cannot_be_changed(authed_client): """PUT /platform/user/admin/update-password should be blocked.""" - r = await authed_client.put('/platform/user/admin/update-password', json={ - 'current_password': 'anything', - 'new_password': 'NewPassword!123' - }) + r = await authed_client.put( + '/platform/user/admin/update-password', + json={'current_password': 'anything', 'new_password': 'NewPassword!123'}, + ) assert r.status_code == 403, 'Super admin password should not be changeable via API' data = r.json() assert 'USR022' in str(data.get('error_code')), 'Should return USR022 error code' assert 'super' in str(data.get('error_message')).lower(), 'Error message should mention super' - - diff --git a/backend-services/tests/test_graceful_shutdown.py b/backend-services/tests/test_graceful_shutdown.py index e41e324..d441c21 100644 --- a/backend-services/tests/test_graceful_shutdown.py +++ b/backend-services/tests/test_graceful_shutdown.py @@ -1,12 +1,15 @@ -import os -import pytest import asyncio import logging +import os from io import StringIO +import pytest + + @pytest.mark.asyncio async def test_graceful_shutdown_allows_inflight_completion(monkeypatch): from services.user_service import UserService + original = UserService.check_password_return_user async def _slow_check(email, password): @@ -21,9 +24,10 @@ async def test_graceful_shutdown_allows_inflight_completion(monkeypatch): logger.addHandler(handler) try: - from doorman import doorman, app_lifespan from httpx import AsyncClient + from doorman import app_lifespan, doorman + async with app_lifespan(doorman): client = AsyncClient(app=doorman, base_url='http://testserver') creds = { @@ -41,4 +45,3 @@ async def test_graceful_shutdown_allows_inflight_completion(monkeypatch): assert 'Waiting for in-flight requests to complete' in logs finally: logger.removeHandler(handler) - diff --git a/backend-services/tests/test_graphql_client_and_envelope.py b/backend-services/tests/test_graphql_client_and_envelope.py index 00fcdda..0b1f482 100644 --- a/backend-services/tests/test_graphql_client_and_envelope.py +++ b/backend-services/tests/test_graphql_client_and_envelope.py @@ -1,5 +1,6 @@ import pytest + async def _setup_graphql(client, name, ver, allowed_headers=None): payload = { 'api_name': name, @@ -15,20 +16,26 @@ async def _setup_graphql(client, name, ver, allowed_headers=None): payload['api_allowed_headers'] = allowed_headers r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/graphql', - 'endpoint_description': 'graphql', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'graphql', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gqlclient', 'v1' await _setup_graphql(authed_client, name, ver) @@ -37,8 +44,10 @@ async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client class FakeSession: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def execute(self, query, variable_values=None): calls['query'] = query calls['vars'] = variable_values @@ -47,8 +56,10 @@ async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client class FakeClient: def __init__(self, transport=None, fetch_schema_from_transport=False): pass + async def __aenter__(self): return FakeSession() + async def __aexit__(self, exc_type, exc, tb): return False @@ -63,30 +74,37 @@ async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client assert body.get('ok') is True and body.get('from') == 'client' assert calls.get('vars') == {'a': 1} + @pytest.mark.asyncio async def test_graphql_fallback_to_httpx_when_client_unavailable(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gqlhttpx', 'v1' await _setup_graphql(authed_client, name, ver) # Make Client unusable for async context class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'ok': True, 'from': 'httpx', 'url': url}) @@ -100,32 +118,40 @@ async def test_graphql_fallback_to_httpx_when_client_unavailable(monkeypatch, au body = r.json() assert body.get('from') == 'httpx' + @pytest.mark.asyncio async def test_graphql_errors_returned_in_errors_array(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gqlerrors', 'v1' await _setup_graphql(authed_client, name, ver) class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'errors': [{'message': 'boom'}]}) monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) + # Force HTTPX path class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) r = await authed_client.post( @@ -137,33 +163,41 @@ async def test_graphql_errors_returned_in_errors_array(monkeypatch, authed_clien body = r.json() assert isinstance(body.get('errors'), list) and body['errors'][0]['message'] == 'boom' + @pytest.mark.asyncio async def test_graphql_strict_envelope_wraps_response(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gqlstrict', 'v1' await _setup_graphql(authed_client, name, ver) class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'data': {'ok': True}}) monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) + # Use httpx path by disabling Client class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) r = await authed_client.post( @@ -175,33 +209,41 @@ async def test_graphql_strict_envelope_wraps_response(monkeypatch, authed_client body = r.json() assert body.get('status_code') == 200 and isinstance(body.get('response'), dict) + @pytest.mark.asyncio async def test_graphql_loose_envelope_returns_raw_response(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gqlloose', 'v1' await _setup_graphql(authed_client, name, ver) class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'data': {'ok': True}}) monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False) monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) + # Use httpx path by disabling Client class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) r = await authed_client.post( @@ -212,4 +254,3 @@ async def test_graphql_loose_envelope_returns_raw_response(monkeypatch, authed_c assert r.status_code == 200 body = r.json() assert body.get('data', {}).get('ok') is True - diff --git a/backend-services/tests/test_graphql_error_flow_and_status.py b/backend-services/tests/test_graphql_error_flow_and_status.py index 24ceb99..a0afa2c 100644 --- a/backend-services/tests/test_graphql_error_flow_and_status.py +++ b/backend-services/tests/test_graphql_error_flow_and_status.py @@ -1,111 +1,156 @@ import pytest + async def _setup_graphql_api(client, name='geflow', ver='v1'): - await client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://gql.up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/graphql', - 'endpoint_description': 'gql' - }) + await client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://gql.up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'gql', + }, + ) return name, ver + @pytest.mark.asyncio async def test_graphql_upstream_error_returns_errors_array_200(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = await _setup_graphql_api(authed_client, name='ge1', ver='v1') class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p + class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'errors': [{'message': 'upstream boom'}]}) + class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) - r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}}) + r = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ q }', 'variables': {}}, + ) assert r.status_code == 200 assert isinstance(r.json().get('errors'), list) + @pytest.mark.asyncio async def test_graphql_upstream_http_error_maps_to_errors_with_status(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = await _setup_graphql_api(authed_client, name='ge2', ver='v1') class FakeHTTPResp: def __init__(self, payload, status): self._p = payload self.status_code = status + def json(self): return self._p + class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'errors': [{'message': 'http fail', 'status': 500}]}, 500) + class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) - r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}}) + r = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ q }', 'variables': {}}, + ) assert r.status_code == 200 errs = r.json().get('errors') assert isinstance(errs, list) and errs[0].get('status') == 500 + @pytest.mark.asyncio async def test_graphql_strict_envelope_contains_status_code_field(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = await _setup_graphql_api(authed_client, name='ge3', ver='v1') class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p + class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'data': {'ok': True}}) + class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') - r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ strict }', 'variables': {}}) + r = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ strict }', 'variables': {}}, + ) assert r.status_code == 200 body = r.json() assert body.get('status_code') == 200 and isinstance(body.get('response'), dict) - diff --git a/backend-services/tests/test_graphql_preflight_positive.py b/backend-services/tests/test_graphql_preflight_positive.py new file mode 100644 index 0000000..fc14f00 --- /dev/null +++ b/backend-services/tests/test_graphql_preflight_positive.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.mark.asyncio +async def test_graphql_preflight_positive_allows(authed_client): + name, ver = 'gqlpos', 'v1' + cr = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'graphql preflight positive', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'GRAPHQL', + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['POST'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_allowed_retry_count': 0, + }, + ) + assert cr.status_code in (200, 201) + + r = await authed_client.options( + f'/api/graphql/{name}', + headers={ + 'X-API-Version': ver, + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'POST', + 'Access-Control-Request-Headers': 'Content-Type', + }, + ) + assert r.status_code == 204 + acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get( + 'access-control-allow-origin' + ) + assert acao == 'http://ok.example' + ach = ( + r.headers.get('Access-Control-Allow-Headers') + or r.headers.get('access-control-allow-headers') + or '' + ) + assert 'Content-Type' in ach diff --git a/backend-services/tests/test_graphql_soap_grpc_extended.py b/backend-services/tests/test_graphql_soap_grpc_extended.py index c298afc..d9095af 100644 --- a/backend-services/tests/test_graphql_soap_grpc_extended.py +++ b/backend-services/tests/test_graphql_soap_grpc_extended.py @@ -1,5 +1,6 @@ import pytest + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None, text_body=None, headers=None): self.status_code = status_code @@ -13,10 +14,12 @@ class _FakeHTTPResponse: def json(self): import json as _json + if self._json_body is None: return _json.loads(self.text or '{}') return self._json_body + class _FakeAsyncClient: def __init__(self, *args, **kwargs): pass @@ -46,26 +49,43 @@ class _FakeAsyncClient: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) async def get(self, url, params=None, headers=None, **kwargs): - return _FakeHTTPResponse(200, json_body={'ok': True, 'url': url}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, json_body={'ok': True, 'url': url}, headers={'X-Upstream': 'yes'} + ) async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'} + ) async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'} + ) async def delete(self, url, **kwargs): return _FakeHTTPResponse(200, json_body={'ok': True}, headers={'X-Upstream': 'yes'}) + class _NotFoundAsyncClient(_FakeAsyncClient): async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): return _FakeHTTPResponse(404, json_body={'ok': False}, headers={'X-Upstream': 'no'}) + @pytest.mark.asyncio async def test_grpc_missing_version_header_returns_400(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'grpcver', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/grpc') @@ -73,10 +93,11 @@ async def test_grpc_missing_version_header_returns_400(monkeypatch, authed_clien r = await authed_client.post(f'/api/grpc/{name}', json={'method': 'Svc.M', 'message': {}}) assert r.status_code == 400 + @pytest.mark.asyncio async def test_graphql_lowercase_version_header_works(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'gqllower', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/graphql') @@ -85,20 +106,25 @@ async def test_graphql_lowercase_version_header_works(monkeypatch, authed_client class FakeSession: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def execute(self, *args, **kwargs): return {'ping': 'pong'} class FakeClient: def __init__(self, transport=None, fetch_schema_from_transport=False): pass + async def __aenter__(self): return FakeSession() + async def __aexit__(self, exc_type, exc, tb): return False import services.gateway_service as gw + monkeypatch.setattr(gw, 'Client', FakeClient) r = await authed_client.post( @@ -108,9 +134,11 @@ async def test_graphql_lowercase_version_header_works(monkeypatch, authed_client ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_soap_text_xml_validation_allows_good_request(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'soaptext', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/call') @@ -126,8 +154,8 @@ async def test_soap_text_xml_validation_allows_good_request(monkeypatch, authed_ ) envelope = ( - '' - '' + '' + '' '' 'Ab' '' @@ -144,30 +172,35 @@ async def test_soap_text_xml_validation_allows_good_request(monkeypatch, authed_ class _FakeXMLClient: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, content=None, params=None, headers=None): return _FakeXMLResponse() import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeXMLClient) r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', - headers={'Content-Type': 'text/xml'}, - content=envelope, + f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'text/xml'}, content=envelope ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_soap_upstream_404_maps_to_404(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'soap404', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/call') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _NotFoundAsyncClient) r = await authed_client.post( f'/api/soap/{name}/{ver}/call', @@ -176,14 +209,17 @@ async def test_soap_upstream_404_maps_to_404(monkeypatch, authed_client): ) assert r.status_code == 404 + @pytest.mark.asyncio async def test_grpc_upstream_404_maps_to_404(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'grpc404', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/grpc') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _NotFoundAsyncClient) r = await authed_client.post( f'/api/grpc/{name}', @@ -192,14 +228,16 @@ async def test_grpc_upstream_404_maps_to_404(monkeypatch, authed_client): ) assert r.status_code == 404 + @pytest.mark.asyncio async def test_grpc_subscription_required(monkeypatch, authed_client): - from conftest import create_api, create_endpoint + name, ver = 'grpcsub', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/grpc') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post( f'/api/grpc/{name}', @@ -208,9 +246,9 @@ async def test_grpc_subscription_required(monkeypatch, authed_client): ) assert r.status_code == 403 + @pytest.mark.asyncio async def test_graphql_group_enforcement(monkeypatch, authed_client): - name, ver = 'gqlgrp', 'v1' await authed_client.post( '/platform/api', @@ -237,8 +275,10 @@ async def test_graphql_group_enforcement(monkeypatch, authed_client): ) import routes.gateway_routes as gr + async def _pass_sub(req): return {'sub': 'admin'} + monkeypatch.setattr(gr, 'subscription_required', _pass_sub) await authed_client.delete('/api/caches') @@ -246,20 +286,25 @@ async def test_graphql_group_enforcement(monkeypatch, authed_client): class FakeSession: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def execute(self, *args, **kwargs): return {'ok': True} class FakeClient: def __init__(self, transport=None, fetch_schema_from_transport=False): pass + async def __aenter__(self): return FakeSession() + async def __aexit__(self, exc_type, exc, tb): return False import services.gateway_service as gw + monkeypatch.setattr(gw, 'Client', FakeClient) r = await authed_client.post( diff --git a/backend-services/tests/test_graphql_soap_preflight_negatives.py b/backend-services/tests/test_graphql_soap_preflight_negatives.py new file mode 100644 index 0000000..c5c399a --- /dev/null +++ b/backend-services/tests/test_graphql_soap_preflight_negatives.py @@ -0,0 +1,75 @@ +import pytest + + +@pytest.mark.asyncio +async def test_graphql_preflight_header_mismatch_removes_acao(authed_client): + name, ver = 'gqlneg', 'v1' + # Create GraphQL API config + cr = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'graphql preflight', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'GRAPHQL', + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['POST'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_allowed_retry_count': 0, + }, + ) + assert cr.status_code in (200, 201) + + r = await authed_client.options( + f'/api/graphql/{name}', + headers={ + 'X-API-Version': ver, + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'POST', + 'Access-Control-Request-Headers': 'X-Not-Allowed', + }, + ) + assert r.status_code == 204 + acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get( + 'access-control-allow-origin' + ) + assert acao in (None, '') + + +@pytest.mark.asyncio +async def test_soap_preflight_header_mismatch_removes_acao(authed_client): + name, ver = 'soapneg', 'v1' + cr = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'soap preflight', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'SOAP', + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['POST'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_allowed_retry_count': 0, + }, + ) + assert cr.status_code in (200, 201) + + r = await authed_client.options( + f'/api/soap/{name}/{ver}/x', + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'POST', + 'Access-Control-Request-Headers': 'X-Not-Allowed', + }, + ) + assert r.status_code == 204 + acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get( + 'access-control-allow-origin' + ) + assert acao in (None, '') diff --git a/backend-services/tests/test_group_crud_permissions.py b/backend-services/tests/test_group_crud_permissions.py new file mode 100644 index 0000000..f6cfd64 --- /dev/null +++ b/backend-services/tests/test_group_crud_permissions.py @@ -0,0 +1,75 @@ +import time + +import pytest + + +@pytest.mark.asyncio +async def test_group_crud_permissions(authed_client): + # Limited user + uname = f'grp_limited_{int(time.time())}' + pwd = 'GrpLimitStrongPass1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + + from httpx import AsyncClient + + from doorman import doorman + + limited = AsyncClient(app=doorman, base_url='http://testserver') + r = await limited.post( + '/platform/authorization', json={'email': f'{uname}@example.com', 'password': pwd} + ) + assert r.status_code == 200 + gname = f'g_{int(time.time())}' + c403 = await limited.post( + '/platform/group', json={'group_name': gname, 'group_description': 'x'} + ) + assert c403.status_code == 403 + + # Role with manage_groups + rname = f'grpmgr_{int(time.time())}' + cr = await authed_client.post( + '/platform/role', json={'role_name': rname, 'manage_groups': True} + ) + assert cr.status_code in (200, 201) + uname2 = f'grp_mgr_user_{int(time.time())}' + cu2 = await authed_client.post( + '/platform/user', + json={ + 'username': uname2, + 'email': f'{uname2}@example.com', + 'password': pwd, + 'role': rname, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu2.status_code in (200, 201) + mgr = AsyncClient(app=doorman, base_url='http://testserver') + r2 = await mgr.post( + '/platform/authorization', json={'email': f'{uname2}@example.com', 'password': pwd} + ) + assert r2.status_code == 200 + + # Create + c = await mgr.post('/platform/group', json={'group_name': gname, 'group_description': 'x'}) + assert c.status_code in (200, 201) + # Get + g = await mgr.get(f'/platform/group/{gname}') + assert g.status_code == 200 + # Update + u = await mgr.put(f'/platform/group/{gname}', json={'group_description': 'y'}) + assert u.status_code == 200 + # Delete + d = await mgr.delete(f'/platform/group/{gname}') + assert d.status_code == 200 diff --git a/backend-services/tests/test_group_role_not_found.py b/backend-services/tests/test_group_role_not_found.py index 33523b5..dd77f0b 100644 --- a/backend-services/tests/test_group_role_not_found.py +++ b/backend-services/tests/test_group_role_not_found.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_group_and_role_not_found(authed_client): gg = await authed_client.get('/platform/group/not-a-group') @@ -13,4 +14,3 @@ async def test_group_and_role_not_found(authed_client): dr = await authed_client.delete('/platform/role/not-a-role') assert dr.status_code in (400, 404) - diff --git a/backend-services/tests/test_grpc_allowlist.py b/backend-services/tests/test_grpc_allowlist.py index a748c2c..b2211f3 100644 --- a/backend-services/tests/test_grpc_allowlist.py +++ b/backend-services/tests/test_grpc_allowlist.py @@ -1,6 +1,9 @@ import pytest -async def _setup_api_with_allowlist(client, name, ver, allowed_pkgs=None, allowed_svcs=None, allowed_methods=None): + +async def _setup_api_with_allowlist( + client, name, ver, allowed_pkgs=None, allowed_svcs=None, allowed_methods=None +): payload = { 'api_name': name, 'api_version': ver, @@ -19,17 +22,22 @@ async def _setup_api_with_allowlist(client, name, ver, allowed_pkgs=None, allowe payload['api_grpc_allowed_methods'] = allowed_methods r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201), r.text - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201), r2.text from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio async def test_grpc_service_not_in_allowlist_returns_403(authed_client): name, ver = 'gallow1', 'v1' @@ -43,6 +51,7 @@ async def test_grpc_service_not_in_allowlist_returns_403(authed_client): body = r.json() assert body.get('error_code') == 'GTW013' + @pytest.mark.asyncio async def test_grpc_method_not_in_allowlist_returns_403(authed_client): name, ver = 'gallow2', 'v1' @@ -56,6 +65,7 @@ async def test_grpc_method_not_in_allowlist_returns_403(authed_client): body = r.json() assert body.get('error_code') == 'GTW013' + @pytest.mark.asyncio async def test_grpc_package_not_in_allowlist_returns_403(authed_client): name, ver = 'gallow3', 'v1' @@ -69,6 +79,7 @@ async def test_grpc_package_not_in_allowlist_returns_403(authed_client): body = r.json() assert body.get('error_code') == 'GTW013' + @pytest.mark.asyncio async def test_grpc_invalid_traversal_rejected_400(authed_client): name, ver = 'gallow4', 'v1' @@ -81,4 +92,3 @@ async def test_grpc_invalid_traversal_rejected_400(authed_client): assert r.status_code == 400 body = r.json() assert body.get('error_code') == 'GTW011' - diff --git a/backend-services/tests/test_grpc_client_and_bidi_streaming.py b/backend-services/tests/test_grpc_client_and_bidi_streaming.py index 443e90a..c34ea43 100644 --- a/backend-services/tests/test_grpc_client_and_bidi_streaming.py +++ b/backend-services/tests/test_grpc_client_and_bidi_streaming.py @@ -1,5 +1,6 @@ import pytest + async def _setup_api(client, name, ver): payload = { 'api_name': name, @@ -13,41 +14,58 @@ async def _setup_api(client, name, ver): } r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + def _fake_import(grpc_module_name: str): def _imp(n): if n.endswith('_pb2'): mod = type('PB2', (), {}) - setattr(mod, 'MRequest', type('Req', (), {})) + mod.MRequest = type('Req', (), {}) + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() - def __init__(self, ok=True): self.ok = ok + + def __init__(self, ok=True): + self.ok = ok + @staticmethod - def FromString(b): return Reply(True) - setattr(mod, 'MReply', Reply) + def FromString(b): + return Reply(True) + + mod.MReply = Reply return mod if n.endswith('_pb2_grpc'): + class Stub: - def __init__(self, ch): pass + def __init__(self, ch): + pass + return type('SVC', (), {'SvcStub': Stub}) raise ImportError(n) + return _imp + @pytest.mark.asyncio async def test_grpc_client_streaming(monkeypatch, authed_client): name, ver = 'gclstr', 'v1' await _setup_api(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs')) class Chan: @@ -56,25 +74,41 @@ async def test_grpc_client_streaming(monkeypatch, authed_client): count = 0 async for _ in req_iter: count += 1 + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() ok = True + return Reply() + return _call + class _Aio: @staticmethod - def insecure_channel(url): return Chan() - monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})) + def insecure_channel(url): + return Chan() + + monkeypatch.setattr( + gs, + 'grpc', + type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}), + ) body = {'method': 'Svc.M', 'message': {}, 'stream': 'client', 'messages': [{}, {}, {}]} - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json=body, + ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_grpc_client_streaming_field_mapping(monkeypatch, authed_client): name, ver = 'gclmap', 'v1' await _setup_api(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs')) class Chan: @@ -86,28 +120,46 @@ async def test_grpc_client_streaming_field_mapping(monkeypatch, authed_client): total += int(getattr(req, 'val', 0)) except Exception: pass + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'sum'})()]})() - def __init__(self, sum): self.sum = sum + + def __init__(self, sum): + self.sum = sum + return Reply(total) + return _call + class _Aio: @staticmethod - def insecure_channel(url): return Chan() - monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})) + def insecure_channel(url): + return Chan() + + monkeypatch.setattr( + gs, + 'grpc', + type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}), + ) msgs = [{'val': 1}, {'val': 2}, {'val': 3}] body = {'method': 'Svc.M', 'message': {}, 'stream': 'client', 'messages': msgs} - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json=body, + ) assert r.status_code == 200 data = r.json().get('response') or r.json() assert int(data.get('sum', 0)) == 6 + @pytest.mark.asyncio async def test_grpc_bidi_streaming(monkeypatch, authed_client): name, ver = 'gbidi', 'v1' await _setup_api(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs')) class Chan: @@ -116,25 +168,46 @@ async def test_grpc_bidi_streaming(monkeypatch, authed_client): class Msg: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() ok = True + async for _ in req_iter: yield Msg() + return _call + class _Aio: @staticmethod - def insecure_channel(url): return Chan() - monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})) + def insecure_channel(url): + return Chan() - body = {'method': 'Svc.M', 'message': {}, 'stream': 'bidi', 'messages': [{}, {}, {}], 'max_items': 2} - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body) + monkeypatch.setattr( + gs, + 'grpc', + type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}), + ) + + body = { + 'method': 'Svc.M', + 'message': {}, + 'stream': 'bidi', + 'messages': [{}, {}, {}], + 'max_items': 2, + } + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json=body, + ) assert r.status_code == 200 data = r.json().get('response') or r.json() assert isinstance(data.get('items'), list) and len(data['items']) == 2 + @pytest.mark.asyncio async def test_grpc_bidi_streaming_field_echo(monkeypatch, authed_client): name, ver = 'gbidimap', 'v1' await _setup_api(authed_client, name, ver) import services.gateway_service as gs + monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs')) class Chan: @@ -143,18 +216,32 @@ async def test_grpc_bidi_streaming_field_echo(monkeypatch, authed_client): class Msg: def __init__(self, v): self.val = v + DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'val'})()]})() + async for req in req_iter: yield Msg(getattr(req, 'val', None)) + return _call + class _Aio: @staticmethod - def insecure_channel(url): return Chan() - monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})) + def insecure_channel(url): + return Chan() + + monkeypatch.setattr( + gs, + 'grpc', + type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}), + ) msgs = [{'val': 7}, {'val': 8}] body = {'method': 'Svc.M', 'message': {}, 'stream': 'bidi', 'messages': msgs, 'max_items': 10} - r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body) + r = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json=body, + ) assert r.status_code == 200 data = r.json().get('response') or r.json() vals = [it.get('val') for it in (data.get('items') or [])] diff --git a/backend-services/tests/test_grpc_errors_and_retries.py b/backend-services/tests/test_grpc_errors_and_retries.py index 77c9a02..36a661e 100644 --- a/backend-services/tests/test_grpc_errors_and_retries.py +++ b/backend-services/tests/test_grpc_errors_and_retries.py @@ -1,5 +1,6 @@ import pytest + async def _setup_api(client, name, ver, retry=0): payload = { 'api_name': name, @@ -13,31 +14,41 @@ async def _setup_api(client, name, ver, retry=0): } r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + def _fake_pb2_module(method_name='M'): class Req: pass + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() + def __init__(self, ok=True): self.ok = ok + @staticmethod def FromString(b): return Reply(True) - setattr(Req, '__name__', f'{method_name}Request') - setattr(Reply, '__name__', f'{method_name}Reply') + + Req.__name__ = f'{method_name}Request' + Reply.__name__ = f'{method_name}Reply' return Req, Reply + def _make_import_module_recorder(record, pb2_map): def _imp(name): record.append(name) @@ -46,27 +57,42 @@ def _make_import_module_recorder(record, pb2_map): mapping = pb2_map.get(name) if mapping is None: req_cls, rep_cls = _fake_pb2_module('M') - setattr(mod, 'MRequest', req_cls) - setattr(mod, 'MReply', rep_cls) + mod.MRequest = req_cls + mod.MReply = rep_cls else: req_cls, rep_cls = mapping if req_cls: - setattr(mod, 'MRequest', req_cls) + mod.MRequest = req_cls if rep_cls: - setattr(mod, 'MReply', rep_cls) + mod.MReply = rep_cls return mod if name.endswith('_pb2_grpc'): + class Stub: def __init__(self, ch): self._ch = ch + async def M(self, req): - return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})() + return type( + 'R', + (), + { + 'DESCRIPTOR': type( + 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]} + )(), + 'ok': True, + }, + )() + return type('SVC', (), {'SvcStub': Stub}) raise ImportError(name) + return _imp + def _make_fake_grpc_unary(sequence_codes, grpc_mod): counter = {'i': 0} + class Chan: def unary_unary(self, method, request_serializer=None, response_deserializer=None): async def _call(req, metadata=None): @@ -74,29 +100,50 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod): code = sequence_codes[idx] counter['i'] += 1 if code is None: - return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})() + return type( + 'R', + (), + { + 'DESCRIPTOR': type( + 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]} + )(), + 'ok': True, + }, + )() + class E(grpc_mod.RpcError): def code(self): return code + def details(self): return f'{code.name}' + raise E() + return _call + class aio: @staticmethod def insecure_channel(url): return Chan() + return type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception}) + @pytest.mark.asyncio async def test_grpc_status_mappings_basic(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gmap', 'v1' await _setup_api(authed_client, name, ver, retry=0) rec = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)})) + monkeypatch.setattr( + gs.importlib, + 'import_module', + _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}), + ) cases = [ (gs.grpc.StatusCode.UNAUTHENTICATED, 401), @@ -110,89 +157,142 @@ async def test_grpc_status_mappings_basic(monkeypatch, authed_client): fake = _make_fake_grpc_unary([code], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == expect + @pytest.mark.asyncio async def test_grpc_unavailable_with_retry_still_fails_maps_503(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gunav', 'v1' await _setup_api(authed_client, name, ver, retry=2) rec = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)})) - fake = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNAVAILABLE], gs.grpc) + monkeypatch.setattr( + gs.importlib, + 'import_module', + _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}), + ) + fake = _make_fake_grpc_unary( + [ + gs.grpc.StatusCode.UNAVAILABLE, + gs.grpc.StatusCode.UNAVAILABLE, + gs.grpc.StatusCode.UNAVAILABLE, + ], + gs.grpc, + ) monkeypatch.setattr(gs, 'grpc', fake) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 503 + @pytest.mark.asyncio async def test_grpc_alt_method_fallback_succeeds(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'galt', 'v1' await _setup_api(authed_client, name, ver, retry=0) rec = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)})) + monkeypatch.setattr( + gs.importlib, + 'import_module', + _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}), + ) fake = _make_fake_grpc_unary([gs.grpc.StatusCode.ABORTED, None], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_grpc_non_retryable_error_returns_500_no_retry(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gnr', 'v1' await _setup_api(authed_client, name, ver, retry=2) rec = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)})) - fake = _make_fake_grpc_unary([gs.grpc.StatusCode.INVALID_ARGUMENT, gs.grpc.StatusCode.INVALID_ARGUMENT], gs.grpc) + monkeypatch.setattr( + gs.importlib, + 'import_module', + _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}), + ) + fake = _make_fake_grpc_unary( + [gs.grpc.StatusCode.INVALID_ARGUMENT, gs.grpc.StatusCode.INVALID_ARGUMENT], gs.grpc + ) monkeypatch.setattr(gs, 'grpc', fake) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 400 assert r.json().get('error_code') == 'GTW006' + @pytest.mark.asyncio async def test_grpc_deadline_exceeded_maps_to_504(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gdl', 'v1' await _setup_api(authed_client, name, ver, retry=1) rec = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)})) + monkeypatch.setattr( + gs.importlib, + 'import_module', + _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}), + ) fake = _make_fake_grpc_unary([gs.grpc.StatusCode.DEADLINE_EXCEEDED], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 504 assert r.json().get('error_code') == 'GTW006' + @pytest.mark.asyncio async def test_grpc_unavailable_then_unimplemented_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gretry', 'v1' await _setup_api(authed_client, name, ver, retry=3) rec = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)})) - fake = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc) + monkeypatch.setattr( + gs.importlib, + 'import_module', + _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}), + ) + fake = _make_fake_grpc_unary( + [gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc + ) monkeypatch.setattr(gs, 'grpc', fake) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 200 diff --git a/backend-services/tests/test_grpc_package_resolution_and_errors.py b/backend-services/tests/test_grpc_package_resolution_and_errors.py index c53db1b..5192032 100644 --- a/backend-services/tests/test_grpc_package_resolution_and_errors.py +++ b/backend-services/tests/test_grpc_package_resolution_and_errors.py @@ -1,5 +1,6 @@ import pytest + async def _setup_api(client, name, ver, retry=0, api_pkg=None): payload = { 'api_name': name, @@ -15,31 +16,41 @@ async def _setup_api(client, name, ver, retry=0, api_pkg=None): payload['api_grpc_package'] = api_pkg r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + def _fake_pb2_module(method_name='M'): class Req: pass + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() + def __init__(self, ok=True): self.ok = ok + @staticmethod def FromString(b): return Reply(True) - setattr(Req, '__name__', f'{method_name}Request') - setattr(Reply, '__name__', f'{method_name}Reply') + + Req.__name__ = f'{method_name}Request' + Reply.__name__ = f'{method_name}Reply' return Req, Reply + def _make_import_module_recorder(record, pb2_map): def _imp(name): record.append(name) @@ -48,32 +59,47 @@ def _make_import_module_recorder(record, pb2_map): mapping = pb2_map.get(name) if mapping is None: req_cls, rep_cls = _fake_pb2_module('M') - setattr(mod, 'MRequest', req_cls) - setattr(mod, 'MReply', rep_cls) + mod.MRequest = req_cls + mod.MReply = rep_cls else: req_cls, rep_cls = mapping if req_cls: - setattr(mod, 'MRequest', req_cls) + mod.MRequest = req_cls if rep_cls: - setattr(mod, 'MReply', rep_cls) + mod.MReply = rep_cls return mod if name.endswith('_pb2_grpc'): # service module with Stub class class Stub: def __init__(self, ch): self._ch = ch + async def M(self, req): - return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})() + return type( + 'R', + (), + { + 'DESCRIPTOR': type( + 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]} + )(), + 'ok': True, + }, + )() + mod = type('SVC', (), {'SvcStub': Stub}) return mod raise ImportError(name) + return _imp + def _make_fake_grpc_unary(sequence_codes, grpc_mod): counter = {'i': 0} + class AioChan: async def channel_ready(self): return True + class Chan(AioChan): def unary_unary(self, method, request_serializer=None, response_deserializer=None): async def _call(req): @@ -81,166 +107,231 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod): code = sequence_codes[idx] counter['i'] += 1 if code is None: - return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})() + return type( + 'R', + (), + { + 'DESCRIPTOR': type( + 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]} + )(), + 'ok': True, + }, + )() + # Raise RpcError-like class E(Exception): def code(self): return code + def details(self): return 'err' + raise E() + return _call + class aio: @staticmethod def insecure_channel(url): return Chan() + fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception}) return fake + @pytest.mark.asyncio async def test_grpc_uses_api_grpc_package_over_request(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gpack1', 'v1' await _setup_api(authed_client, name, ver, api_pkg='api.pkg') record = [] req_cls, rep_cls = _fake_pb2_module('M') - pb2_map = { 'api.pkg_pb2': (req_cls, rep_cls) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {'api.pkg_pb2': (req_cls, rep_cls)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs.os.path, 'exists', lambda p: True) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}, ) assert r.status_code == 200 assert any(n == 'api.pkg_pb2' for n in record) + @pytest.mark.asyncio async def test_grpc_uses_request_package_when_no_api_package(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gpack2', 'v1' await _setup_api(authed_client, name, ver, api_pkg=None) record = [] req_cls, rep_cls = _fake_pb2_module('M') - pb2_map = { 'req.pkg_pb2': (req_cls, rep_cls) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {'req.pkg_pb2': (req_cls, rep_cls)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs.os.path, 'exists', lambda p: True) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}, ) assert r.status_code == 200 assert any(n == 'req.pkg_pb2' for n in record) + @pytest.mark.asyncio async def test_grpc_uses_default_package_when_no_overrides(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gpack3', 'v1' await _setup_api(authed_client, name, ver, api_pkg=None) record = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - pb2_map = { default_pkg: (req_cls, rep_cls) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {default_pkg: (req_cls, rep_cls)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs.os.path, 'exists', lambda p: True) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 200 assert any(n.endswith(default_pkg) for n in record) + @pytest.mark.asyncio async def test_grpc_unavailable_then_success_with_retry(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gunavail', 'v1' await _setup_api(authed_client, name, ver, retry=1) record = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - pb2_map = { default_pkg: (req_cls, rep_cls) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {default_pkg: (req_cls, rep_cls)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, None], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake_grpc) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_grpc_unimplemented_then_success_with_retry(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gunimpl', 'v1' await _setup_api(authed_client, name, ver, retry=1) record = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - pb2_map = { default_pkg: (req_cls, rep_cls) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {default_pkg: (req_cls, rep_cls)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake_grpc) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 200 + @pytest.mark.asyncio async def test_grpc_not_found_maps_to_500_error(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gnotfound', 'v1' await _setup_api(authed_client, name, ver) record = [] - pb2_map = { f'{name}_{ver}_pb2': (None, None) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {f'{name}_{ver}_pb2': (None, None)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc)) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 500 body = r.json() assert body.get('error_code') == 'GTW006' + @pytest.mark.asyncio async def test_grpc_unknown_maps_to_500_error(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gunk', 'v1' await _setup_api(authed_client, name, ver) record = [] req_cls, rep_cls = _fake_pb2_module('M') default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2' - pb2_map = { default_pkg: (req_cls, rep_cls) } - monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)) + pb2_map = {default_pkg: (req_cls, rep_cls)} + monkeypatch.setattr( + gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map) + ) fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNKNOWN], gs.grpc) monkeypatch.setattr(gs, 'grpc', fake_grpc) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 500 + @pytest.mark.asyncio async def test_grpc_rejects_traversal_in_package(authed_client): name, ver = 'gtrv', 'v1' await _setup_api(authed_client, name, ver) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, - json={'method': 'Svc.M', 'message': {}, 'package': '../evil'} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}, 'package': '../evil'}, ) assert r.status_code == 400 body = r.json() assert body.get('error_code') == 'GTW011' + @pytest.mark.asyncio async def test_grpc_proto_missing_returns_404_gtw012(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gproto404', 'v1' await _setup_api(authed_client, name, ver) + # Make on-demand proto generation fail by raising on import grpc_tools def _imp_fail(name): if name.startswith('grpc_tools'): raise ImportError('no tools') raise ImportError(name) + monkeypatch.setattr(gs.importlib, 'import_module', _imp_fail) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 404 body = r.json() diff --git a/backend-services/tests/test_grpc_streaming_and_metadata.py b/backend-services/tests/test_grpc_streaming_and_metadata.py index 43fe266..8d60333 100644 --- a/backend-services/tests/test_grpc_streaming_and_metadata.py +++ b/backend-services/tests/test_grpc_streaming_and_metadata.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_grpc_server_streaming(monkeypatch, authed_client): name, ver = 'gstr', 'v1' @@ -15,35 +16,50 @@ async def test_grpc_server_streaming(monkeypatch, authed_client): } r = await authed_client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + # Fake pb2 modules def _imp(n): if n.endswith('_pb2'): mod = type('PB2', (), {}) - setattr(mod, 'MRequest', type('Req', (), {})) + mod.MRequest = type('Req', (), {}) + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() - def __init__(self, ok=True): self.ok = ok + + def __init__(self, ok=True): + self.ok = ok + @staticmethod - def FromString(b): return Reply(True) - setattr(mod, 'MReply', Reply) + def FromString(b): + return Reply(True) + + mod.MReply = Reply return mod if n.endswith('_pb2_grpc'): + class Stub: - def __init__(self, ch): pass + def __init__(self, ch): + pass + return type('SVC', (), {'SvcStub': Stub}) raise ImportError(n) + monkeypatch.setattr(gs.importlib, 'import_module', _imp) class Chan: @@ -52,20 +68,34 @@ async def test_grpc_server_streaming(monkeypatch, authed_client): class Msg: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() ok = True + for _ in range(2): yield Msg() + return _aiter + class _Aio: @staticmethod - def insecure_channel(url): return Chan() - monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})) + def insecure_channel(url): + return Chan() + + monkeypatch.setattr( + gs, + 'grpc', + type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}), + ) body = {'method': 'Svc.M', 'message': {}, 'stream': 'server', 'max_items': 2} - resp = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body) + resp = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json=body, + ) assert resp.status_code == 200 data = resp.json().get('response') or resp.json() assert isinstance(data.get('items'), list) and len(data['items']) == 2 + @pytest.mark.asyncio async def test_grpc_metadata_pass_through(monkeypatch, authed_client): name, ver = 'gmeta', 'v1' @@ -82,51 +112,77 @@ async def test_grpc_metadata_pass_through(monkeypatch, authed_client): } r = await authed_client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + def _imp(n): if n.endswith('_pb2'): mod = type('PB2', (), {}) - setattr(mod, 'MRequest', type('Req', (), {})) + mod.MRequest = type('Req', (), {}) + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() + @staticmethod - def FromString(b): return Reply() - setattr(mod, 'MReply', Reply) + def FromString(b): + return Reply() + + mod.MReply = Reply return mod if n.endswith('_pb2_grpc'): + class Stub: - def __init__(self, ch): pass + def __init__(self, ch): + pass + return type('SVC', (), {'SvcStub': Stub}) raise ImportError(n) + monkeypatch.setattr(gs.importlib, 'import_module', _imp) captured = {'md': None} + class Chan: def unary_unary(self, method, request_serializer=None, response_deserializer=None): async def _call(req, metadata=None): captured['md'] = list(metadata or []) + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() ok = True + return Reply() + return _call + class _Aio: @staticmethod - def insecure_channel(url): return Chan() - monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})) + def insecure_channel(url): + return Chan() + + monkeypatch.setattr( + gs, + 'grpc', + type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}), + ) headers = {'X-API-Version': ver, 'Content-Type': 'application/json', 'X-Meta-One': 'alpha'} - r = await authed_client.post(f'/api/grpc/{name}', headers=headers, json={'method': 'Svc.M', 'message': {}}) + r = await authed_client.post( + f'/api/grpc/{name}', headers=headers, json={'method': 'Svc.M', 'message': {}} + ) assert r.status_code == 200 assert ('x-meta-one', 'alpha') in [(k.lower(), v) for k, v in (captured['md'] or [])] diff --git a/backend-services/tests/test_grpc_subscription_and_metrics.py b/backend-services/tests/test_grpc_subscription_and_metrics.py index c101fbf..ed87c46 100644 --- a/backend-services/tests/test_grpc_subscription_and_metrics.py +++ b/backend-services/tests/test_grpc_subscription_and_metrics.py @@ -1,6 +1,8 @@ import json + import pytest + async def _setup_api(client, name, ver, public=False): payload = { 'api_name': name, @@ -15,49 +17,65 @@ async def _setup_api(client, name, ver, public=False): } r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) + @pytest.mark.asyncio async def test_grpc_requires_subscription_when_not_public(monkeypatch, authed_client): name, ver = 'gsub', 'v1' await _setup_api(authed_client, name, ver, public=False) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code == 403 + @pytest.mark.asyncio async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client): name, ver = 'gmet', 'v1' await _setup_api(authed_client, name, ver, public=False) from conftest import subscribe_self + await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + def _imp(name): if name.endswith('_pb2'): mod = type('PB2', (), {}) - setattr(mod, 'MRequest', type('Req', (), {}) ) + mod.MRequest = type('Req', (), {}) + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() + @staticmethod def FromString(b): return Reply() - setattr(mod, 'MReply', Reply) + + mod.MReply = Reply return mod if name.endswith('_pb2_grpc'): + class Stub: - def __init__(self, ch): pass + def __init__(self, ch): + pass + return type('SVC', (), {'SvcStub': Stub}) raise ImportError(name) + monkeypatch.setattr(gs.importlib, 'import_module', _imp) class Chan: @@ -67,13 +85,19 @@ async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client): class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() ok = True + return Reply() + return _call + class _Aio: @staticmethod def insecure_channel(url): return Chan() - fake_grpc = type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}) + + fake_grpc = type( + 'G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception} + ) monkeypatch.setattr(gs, 'grpc', fake_grpc) m0 = await authed_client.get('/platform/monitor/metrics') @@ -83,7 +107,11 @@ async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client): body_obj = {'method': 'Svc.M', 'message': {}} raw = json.dumps(body_obj) - headers = {'Content-Type': 'application/json', 'X-API-Version': ver, 'Content-Length': str(len(raw))} + headers = { + 'Content-Type': 'application/json', + 'X-API-Version': ver, + 'Content-Length': str(len(raw)), + } r = await authed_client.post(f'/api/grpc/{name}', headers=headers, content=raw) assert r.status_code in (200, 500, 501, 503) @@ -93,4 +121,3 @@ async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client): tout1 = int(j1.get('total_bytes_out', 0)) assert tin1 - tin0 >= len(raw) assert tout1 >= tout0 - diff --git a/backend-services/tests/test_grpcs_tls_and_deadlines.py b/backend-services/tests/test_grpcs_tls_and_deadlines.py index ed1cfd5..38616c8 100644 --- a/backend-services/tests/test_grpcs_tls_and_deadlines.py +++ b/backend-services/tests/test_grpcs_tls_and_deadlines.py @@ -1,5 +1,6 @@ import pytest + async def _setup_api(client, name, ver, url): payload = { 'api_name': name, @@ -13,20 +14,26 @@ async def _setup_api(client, name, ver, url): } r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio async def test_grpcs_misconfig_returns_500(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'gtls', 'v1' await _setup_api(authed_client, name, ver, url='grpcs://example:50051') @@ -34,19 +41,26 @@ async def test_grpcs_misconfig_returns_500(monkeypatch, authed_client): def _imp(name): if name.endswith('_pb2'): mod = type('PB2', (), {}) - setattr(mod, 'MRequest', type('Req', (), {}) ) + mod.MRequest = type('Req', (), {}) + class Reply: DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})() + @staticmethod def FromString(b): return Reply() - setattr(mod, 'MReply', Reply) + + mod.MReply = Reply return mod if name.endswith('_pb2_grpc'): + class Stub: - def __init__(self, ch): pass + def __init__(self, ch): + pass + return type('SVC', (), {'SvcStub': Stub}) raise ImportError(name) + monkeypatch.setattr(gs.importlib, 'import_module', _imp) # Simulate insecure_channel rejecting grpcs:// URL @@ -56,11 +70,15 @@ async def test_grpcs_misconfig_returns_500(monkeypatch, authed_client): if str(url).startswith('grpcs://'): raise RuntimeError('TLS required') return object() - fake_grpc = type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}) + + fake_grpc = type( + 'G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception} + ) monkeypatch.setattr(gs, 'grpc', fake_grpc) r = await authed_client.post( - f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}} + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, ) assert r.status_code in (500, 501, 503) - diff --git a/backend-services/tests/test_health_status.py b/backend-services/tests/test_health_status.py index a14605e..2e477eb 100644 --- a/backend-services/tests/test_health_status.py +++ b/backend-services/tests/test_health_status.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_public_health_probe_ok(client): r = await client.get('/api/health') @@ -7,6 +8,7 @@ async def test_public_health_probe_ok(client): body = r.json().get('response', r.json()) assert body.get('status') in ('online', 'healthy', 'ready') + @pytest.mark.asyncio async def test_status_requires_auth(client): try: diff --git a/backend-services/tests/test_http_circuit_breaker.py b/backend-services/tests/test_http_circuit_breaker.py index b18d129..735e315 100644 --- a/backend-services/tests/test_http_circuit_breaker.py +++ b/backend-services/tests/test_http_circuit_breaker.py @@ -1,15 +1,16 @@ import asyncio -import os -from typing import Callable +from collections.abc import Callable import httpx import pytest -from utils.http_client import request_with_resilience, circuit_manager, CircuitOpenError +from utils.http_client import CircuitOpenError, circuit_manager, request_with_resilience + def _mock_transport(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.MockTransport: return httpx.MockTransport(lambda req: handler(req)) + @pytest.mark.asyncio async def test_retries_on_503_then_success(monkeypatch): calls = {'n': 0} @@ -27,17 +28,23 @@ async def test_retries_on_503_then_success(monkeypatch): monkeypatch.setenv('CIRCUIT_BREAKER_THRESHOLD', '5') resp = await request_with_resilience( - client, 'GET', 'http://upstream.test/ok', - api_key='test-api/v1', retries=2, api_config=None, + client, + 'GET', + 'http://upstream.test/ok', + api_key='test-api/v1', + retries=2, + api_config=None, ) assert resp.status_code == 200 assert resp.json() == {'ok': True} assert calls['n'] == 3 + @pytest.mark.asyncio async def test_circuit_opens_after_failures_and_half_open(monkeypatch): calls = {'n': 0} + # Always return 503 def handler(req: httpx.Request) -> httpx.Response: calls['n'] += 1 @@ -53,15 +60,23 @@ async def test_circuit_opens_after_failures_and_half_open(monkeypatch): api_key = 'breaker-api/v1' circuit_manager._states.clear() - resp = await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=1) + resp = await request_with_resilience( + client, 'GET', 'http://u.test/err', api_key=api_key, retries=1 + ) assert resp.status_code == 503 with pytest.raises(CircuitOpenError): - await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=0) + await request_with_resilience( + client, 'GET', 'http://u.test/err', api_key=api_key, retries=0 + ) await asyncio.sleep(0.11) - resp2 = await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=0) + resp2 = await request_with_resilience( + client, 'GET', 'http://u.test/err', api_key=api_key, retries=0 + ) assert resp2.status_code == 503 with pytest.raises(CircuitOpenError): - await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=0) + await request_with_resilience( + client, 'GET', 'http://u.test/err', api_key=api_key, retries=0 + ) diff --git a/backend-services/tests/test_ip_filter_platform.py b/backend-services/tests/test_ip_filter_platform.py index 76ea34e..be707d4 100644 --- a/backend-services/tests/test_ip_filter_platform.py +++ b/backend-services/tests/test_ip_filter_platform.py @@ -1,17 +1,23 @@ import json + import pytest + async def _ensure_manage_security(authed_client): await authed_client.put('/platform/role/admin', json={'manage_security': True}) + async def _update_security(authed_client, settings: dict, headers: dict | None = None): await _ensure_manage_security(authed_client) r = await authed_client.put('/platform/security/settings', json=settings, headers=headers or {}) assert r.status_code == 200, r.text return r + @pytest.mark.asyncio -async def test_global_whitelist_blocks_non_whitelisted_with_trusted_proxy(monkeypatch, authed_client, client): +async def test_global_whitelist_blocks_non_whitelisted_with_trusted_proxy( + monkeypatch, authed_client, client +): await _update_security( authed_client, settings={ @@ -24,17 +30,25 @@ async def test_global_whitelist_blocks_non_whitelisted_with_trusted_proxy(monkey ) try: - r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'}) + r = await client.get( + '/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'} + ) assert r.status_code == 403 body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} assert (body.get('error_code') or body.get('response', {}).get('error_code')) == 'SEC010' finally: await _update_security( authed_client, - settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []}, + settings={ + 'ip_whitelist': [], + 'ip_blacklist': [], + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + }, headers={'X-Forwarded-For': '198.51.100.10'}, ) + @pytest.mark.asyncio async def test_global_blacklist_blocks_with_trusted_proxy(monkeypatch, authed_client, client): await _update_security( @@ -49,17 +63,25 @@ async def test_global_blacklist_blocks_with_trusted_proxy(monkeypatch, authed_cl ) try: - r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'}) + r = await client.get( + '/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'} + ) assert r.status_code == 403 body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} assert (body.get('error_code') or body.get('response', {}).get('error_code')) == 'SEC011' finally: await _update_security( authed_client, - settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []}, + settings={ + 'ip_whitelist': [], + 'ip_blacklist': [], + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + }, headers={'X-Forwarded-For': '198.51.100.10'}, ) + @pytest.mark.asyncio async def test_xff_ignored_when_proxy_not_trusted(monkeypatch, authed_client, client): await _update_security( @@ -74,7 +96,9 @@ async def test_xff_ignored_when_proxy_not_trusted(monkeypatch, authed_client, cl ) try: - r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '198.51.100.10'}) + r = await client.get( + '/platform/monitor/liveness', headers={'X-Forwarded-For': '198.51.100.10'} + ) assert r.status_code == 403 body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} assert (body.get('error_code') or body.get('response', {}).get('error_code')) == 'SEC010' @@ -83,13 +107,22 @@ async def test_xff_ignored_when_proxy_not_trusted(monkeypatch, authed_client, cl try: await _update_security( authed_client, - settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': [], 'allow_localhost_bypass': False}, + settings={ + 'ip_whitelist': [], + 'ip_blacklist': [], + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + 'allow_localhost_bypass': False, + }, ) finally: monkeypatch.delenv('LOCAL_HOST_IP_BYPASS', raising=False) + @pytest.mark.asyncio -async def test_localhost_bypass_enabled_allows_without_forwarding_headers(monkeypatch, authed_client, client): +async def test_localhost_bypass_enabled_allows_without_forwarding_headers( + monkeypatch, authed_client, client +): await _update_security( authed_client, settings={ @@ -106,11 +139,20 @@ async def test_localhost_bypass_enabled_allows_without_forwarding_headers(monkey finally: await _update_security( authed_client, - settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': [], 'allow_localhost_bypass': False}, + settings={ + 'ip_whitelist': [], + 'ip_blacklist': [], + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + 'allow_localhost_bypass': False, + }, ) + @pytest.mark.asyncio -async def test_localhost_bypass_disabled_blocks_without_forwarding_headers(monkeypatch, authed_client, client): +async def test_localhost_bypass_disabled_blocks_without_forwarding_headers( + monkeypatch, authed_client, client +): monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false') await _update_security( @@ -133,20 +175,29 @@ async def test_localhost_bypass_disabled_blocks_without_forwarding_headers(monke try: await _update_security( authed_client, - settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []}, + settings={ + 'ip_whitelist': [], + 'ip_blacklist': [], + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + }, ) finally: monkeypatch.delenv('LOCAL_HOST_IP_BYPASS', raising=False) + class _AuditSpy: def __init__(self): self.calls = [] + def info(self, msg): self.calls.append(msg) + @pytest.mark.asyncio async def test_audit_logged_on_global_deny(monkeypatch, authed_client, client): import utils.audit_util as au + orig = au._logger spy = _AuditSpy() au._logger = spy @@ -162,7 +213,9 @@ async def test_audit_logged_on_global_deny(monkeypatch, authed_client, client): }, ) - r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'}) + r = await client.get( + '/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'} + ) assert r.status_code == 403 assert any('ip.global_deny' in str(c) for c in spy.calls) parsed = [json.loads(c) for c in spy.calls if isinstance(c, str)] @@ -171,6 +224,11 @@ async def test_audit_logged_on_global_deny(monkeypatch, authed_client, client): au._logger = orig await _update_security( authed_client, - settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []}, + settings={ + 'ip_whitelist': [], + 'ip_blacklist': [], + 'trust_x_forwarded_for': False, + 'xff_trusted_proxies': [], + }, headers={'X-Forwarded-For': '198.51.100.10'}, ) diff --git a/backend-services/tests/test_ip_policy_allow_deny_cidr.py b/backend-services/tests/test_ip_policy_allow_deny_cidr.py index 9bf0005..8c625c3 100644 --- a/backend-services/tests/test_ip_policy_allow_deny_cidr.py +++ b/backend-services/tests/test_ip_policy_allow_deny_cidr.py @@ -1,7 +1,7 @@ import pytest - from tests.test_gateway_routing_limits import _FakeAsyncClient + async def _setup_api_public(client, name, ver, mode='allow_all', wl=None, bl=None): payload = { 'api_name': name, @@ -21,70 +21,96 @@ async def _setup_api_public(client, name, ver, mode='allow_all', wl=None, bl=Non payload['api_ip_blacklist'] = bl r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/res', - 'endpoint_description': 'res' - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/res', + 'endpoint_description': 'res', + }, + ) assert r2.status_code in (200, 201) return name, ver + @pytest.mark.asyncio async def test_ip_policy_allows_exact_ip(monkeypatch, authed_client): import services.gateway_service as gs - name, ver = await _setup_api_public(authed_client, 'ipok1', 'v1', mode='whitelist', wl=['127.0.0.1']) + + name, ver = await _setup_api_public( + authed_client, 'ipok1', 'v1', mode='whitelist', wl=['127.0.0.1'] + ) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res') assert r.status_code == 200 + @pytest.mark.asyncio async def test_ip_policy_denies_exact_ip(monkeypatch, authed_client): import services.gateway_service as gs + monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false') - name, ver = await _setup_api_public(authed_client, 'ipdeny1', 'v1', mode='allow_all', bl=['127.0.0.1']) + name, ver = await _setup_api_public( + authed_client, 'ipdeny1', 'v1', mode='allow_all', bl=['127.0.0.1'] + ) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res') assert r.status_code == 403 body = r.json() assert body.get('error_code') == 'API011' + @pytest.mark.asyncio async def test_ip_policy_allows_cidr(monkeypatch, authed_client): import services.gateway_service as gs - name, ver = await _setup_api_public(authed_client, 'ipok2', 'v1', mode='whitelist', wl=['127.0.0.0/24']) + + name, ver = await _setup_api_public( + authed_client, 'ipok2', 'v1', mode='whitelist', wl=['127.0.0.0/24'] + ) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res') assert r.status_code == 200 + @pytest.mark.asyncio async def test_ip_policy_denies_cidr(monkeypatch, authed_client): import services.gateway_service as gs + monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false') - name, ver = await _setup_api_public(authed_client, 'ipdeny2', 'v1', mode='allow_all', bl=['127.0.0.0/24']) + name, ver = await _setup_api_public( + authed_client, 'ipdeny2', 'v1', mode='allow_all', bl=['127.0.0.0/24'] + ) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res') assert r.status_code == 403 assert r.json().get('error_code') == 'API011' + @pytest.mark.asyncio async def test_ip_policy_denylist_precedence_over_allowlist(monkeypatch, authed_client): import services.gateway_service as gs + monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false') - name, ver = await _setup_api_public(authed_client, 'ipdeny3', 'v1', mode='whitelist', wl=['127.0.0.1'], bl=['127.0.0.1']) + name, ver = await _setup_api_public( + authed_client, 'ipdeny3', 'v1', mode='whitelist', wl=['127.0.0.1'], bl=['127.0.0.1'] + ) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res') assert r.status_code == 403 assert r.json().get('error_code') == 'API011' + @pytest.mark.asyncio async def test_ip_policy_enforced_early_returns_http_error(monkeypatch, authed_client): import services.gateway_service as gs + monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false') - name, ver = await _setup_api_public(authed_client, 'ipdeny4', 'v1', mode='whitelist', wl=['203.0.113.5']) + name, ver = await _setup_api_public( + authed_client, 'ipdeny4', 'v1', mode='whitelist', wl=['203.0.113.5'] + ) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/res') assert r.status_code == 403 assert r.json().get('error_code') == 'API010' - diff --git a/backend-services/tests/test_jwt_config.py b/backend-services/tests/test_jwt_config.py index 6290d62..51e321f 100644 --- a/backend-services/tests/test_jwt_config.py +++ b/backend-services/tests/test_jwt_config.py @@ -1,12 +1,11 @@ -import os - from utils.auth_util import is_jwt_configured + def test_is_jwt_configured_true(monkeypatch): monkeypatch.setenv('JWT_SECRET_KEY', 'abc123') assert is_jwt_configured() is True + def test_is_jwt_configured_false(monkeypatch): monkeypatch.delenv('JWT_SECRET_KEY', raising=False) assert is_jwt_configured() is False - diff --git a/backend-services/tests/test_lifespan_failures.py b/backend-services/tests/test_lifespan_failures.py index f3ec265..2193a57 100644 --- a/backend-services/tests/test_lifespan_failures.py +++ b/backend-services/tests/test_lifespan_failures.py @@ -1,30 +1,37 @@ import pytest + @pytest.mark.asyncio async def test_production_guard_causes_startup_failure_direct(monkeypatch): monkeypatch.setenv('ENV', 'production') monkeypatch.setenv('HTTPS_ONLY', 'false') - from doorman import app_lifespan, doorman import pytest as _pytest + + from doorman import app_lifespan, doorman + with _pytest.raises(RuntimeError): async with app_lifespan(doorman): pass + @pytest.mark.asyncio async def test_lifespan_failure_raises_with_fresh_app_testclient(monkeypatch): monkeypatch.setenv('ENV', 'production') monkeypatch.setenv('HTTPS_ONLY', 'false') from fastapi import FastAPI + from doorman import app_lifespan + app = FastAPI(lifespan=app_lifespan) @app.get('/ping') async def ping(): return {'ok': True} - from starlette.testclient import TestClient import pytest as _pytest + from starlette.testclient import TestClient + with _pytest.raises(RuntimeError): with TestClient(app) as client: client.get('/ping') diff --git a/backend-services/tests/test_lists_and_paging.py b/backend-services/tests/test_lists_and_paging.py index bf10205..9d6daa9 100644 --- a/backend-services/tests/test_lists_and_paging.py +++ b/backend-services/tests/test_lists_and_paging.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_list_endpoints_roles_groups_apis(authed_client): - await authed_client.post( '/platform/api', json={ @@ -17,12 +17,10 @@ async def test_list_endpoints_roles_groups_apis(authed_client): }, ) await authed_client.post( - '/platform/group', - json={'group_name': 'glist', 'group_description': 'gd', 'api_access': []}, + '/platform/group', json={'group_name': 'glist', 'group_description': 'gd', 'api_access': []} ) await authed_client.post( - '/platform/role', - json={'role_name': 'rlist', 'role_description': 'rd'}, + '/platform/role', json={'role_name': 'rlist', 'role_description': 'rd'} ) ra = await authed_client.get('/platform/api/all?page=1&page_size=5') @@ -31,4 +29,3 @@ async def test_list_endpoints_roles_groups_apis(authed_client): assert rg.status_code == 200 rr = await authed_client.get('/platform/role/all?page=1&page_size=5') assert rr.status_code == 200 - diff --git a/backend-services/tests/test_logging_permissions.py b/backend-services/tests/test_logging_permissions.py index ba6e32c..3bfe7f0 100644 --- a/backend-services/tests/test_logging_permissions.py +++ b/backend-services/tests/test_logging_permissions.py @@ -1,39 +1,98 @@ +import time + import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200, r.text + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + @pytest.mark.asyncio -async def test_logging_requires_permissions(authed_client): - - r = await authed_client.put( - '/platform/role/admin', - json={'view_logs': False, 'export_logs': False}, +async def test_logging_routes_permissions(authed_client): + # Limited user without view_logs/export_logs + uname = f'log_limited_{int(time.time())}' + pwd = 'LogLimitStrongPass1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, ) - assert r.status_code in (200, 201) + assert cu.status_code in (200, 201) + limited = await _login(f'{uname}@example.com', pwd) + for ep in ( + '/platform/logging/logs', + '/platform/logging/logs/files', + '/platform/logging/logs/statistics', + ): + r = await limited.get(ep) + assert r.status_code == 403 + exp = await limited.get('/platform/logging/logs/export') + assert exp.status_code == 403 - logs = await authed_client.get('/platform/logging/logs?limit=5') - assert logs.status_code == 403 - - files = await authed_client.get('/platform/logging/logs/files') - assert files.status_code == 403 - - stats = await authed_client.get('/platform/logging/logs/statistics') - assert stats.status_code == 403 - - export = await authed_client.get('/platform/logging/logs/export?format=json') - assert export.status_code == 403 - - download = await authed_client.get('/platform/logging/logs/download?format=csv') - assert download.status_code == 403 - - r2 = await authed_client.put( - '/platform/role/admin', - json={'view_logs': True, 'export_logs': True}, + # Role with view_logs only + rname = f'view_logs_{int(time.time())}' + cr = await authed_client.post('/platform/role', json={'role_name': rname, 'view_logs': True}) + assert cr.status_code in (200, 201) + vuser = f'log_viewer_{int(time.time())}' + cu2 = await authed_client.post( + '/platform/user', + json={ + 'username': vuser, + 'email': f'{vuser}@example.com', + 'password': pwd, + 'role': rname, + 'groups': ['ALL'], + 'ui_access': True, + }, ) - assert r2.status_code in (200, 201) - - logs2 = await authed_client.get('/platform/logging/logs?limit=1') - assert logs2.status_code == 200 - files2 = await authed_client.get('/platform/logging/logs/files') - assert files2.status_code == 200 - export2 = await authed_client.get('/platform/logging/logs/download?format=json') - assert export2.status_code == 200 + assert cu2.status_code in (200, 201) + viewer = await _login(f'{vuser}@example.com', pwd) + for ep in ( + '/platform/logging/logs', + '/platform/logging/logs/files', + '/platform/logging/logs/statistics', + ): + r = await viewer.get(ep) + assert r.status_code == 200 + # But export still forbidden without export_logs + exp2 = await viewer.get('/platform/logging/logs/export') + assert exp2.status_code == 403 + # Role with export_logs + rname2 = f'export_logs_{int(time.time())}' + cr2 = await authed_client.post( + '/platform/role', json={'role_name': rname2, 'export_logs': True} + ) + assert cr2.status_code in (200, 201) + euser = f'log_exporter_{int(time.time())}' + cu3 = await authed_client.post( + '/platform/user', + json={ + 'username': euser, + 'email': f'{euser}@example.com', + 'password': pwd, + 'role': rname2, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu3.status_code in (200, 201) + exporter = await _login(f'{euser}@example.com', pwd) + exp3 = await exporter.get('/platform/logging/logs/export') + assert exp3.status_code == 200 diff --git a/backend-services/tests/test_logging_redaction.py b/backend-services/tests/test_logging_redaction.py index a699134..f68741d 100644 --- a/backend-services/tests/test_logging_redaction.py +++ b/backend-services/tests/test_logging_redaction.py @@ -1,8 +1,7 @@ import logging -from types import SimpleNamespace + def test_logging_redaction_filters_sensitive_values(): - logger = logging.getLogger('doorman.gateway') filt = None for h in logger.handlers: @@ -13,8 +12,13 @@ def test_logging_redaction_filters_sensitive_values(): secret = 'supersecretvalue' record = logging.LogRecord( - name='doorman.gateway', level=logging.INFO, pathname=__file__, lineno=1, - msg=f'Authorization: Bearer {secret}; password=\"{secret}\"; access_token=\"{secret}\"', args=(), exc_info=None + name='doorman.gateway', + level=logging.INFO, + pathname=__file__, + lineno=1, + msg=f'Authorization: Bearer {secret}; password="{secret}"; access_token="{secret}"', + args=(), + exc_info=None, ) ok = filt.filter(record) assert ok is True diff --git a/backend-services/tests/test_logging_redaction_extended.py b/backend-services/tests/test_logging_redaction_extended.py index 2939400..c076017 100644 --- a/backend-services/tests/test_logging_redaction_extended.py +++ b/backend-services/tests/test_logging_redaction_extended.py @@ -1,5 +1,6 @@ import logging + def test_redaction_handles_cookies_csrf_and_mixed_cases(): logger = logging.getLogger('doorman.gateway') filt = None @@ -12,15 +13,19 @@ def test_redaction_handles_cookies_csrf_and_mixed_cases(): secret = 'S3cr3t!' msg = ( f'authorization: Bearer {secret}; Authorization: Bearer {secret}; ' - f'cookie: session={secret}; x-csrf-token: {secret}; PASSWORD=\"{secret}\"' + f'cookie: session={secret}; x-csrf-token: {secret}; PASSWORD="{secret}"' ) rec = logging.LogRecord( - name='doorman.gateway', level=logging.INFO, pathname=__file__, lineno=1, - msg=msg, args=(), exc_info=None + name='doorman.gateway', + level=logging.INFO, + pathname=__file__, + lineno=1, + msg=msg, + args=(), + exc_info=None, ) assert filt.filter(rec) is True out = str(rec.msg) assert secret not in out assert out.lower().count('[redacted]') >= 3 - diff --git a/backend-services/tests/test_logging_redaction_new_patterns.py b/backend-services/tests/test_logging_redaction_new_patterns.py index 4bcd66c..4d0b204 100644 --- a/backend-services/tests/test_logging_redaction_new_patterns.py +++ b/backend-services/tests/test_logging_redaction_new_patterns.py @@ -1,6 +1,7 @@ import logging from io import StringIO + def _capture(logger_name: str, message: str) -> str: logger = logging.getLogger(logger_name) stream = StringIO() @@ -15,16 +16,19 @@ def _capture(logger_name: str, message: str) -> str: logger.removeHandler(h) return stream.getvalue() + def test_redacts_set_cookie_and_x_api_key(): - msg = 'Set-Cookie: access_token_cookie=abc123; Path=/; HttpOnly; Secure; X-API-Key: my-secret-key' + msg = ( + 'Set-Cookie: access_token_cookie=abc123; Path=/; HttpOnly; Secure; X-API-Key: my-secret-key' + ) out = _capture('doorman.gateway', msg) assert 'Set-Cookie: [REDACTED]' in out or 'set-cookie: [REDACTED]' in out.lower() assert 'X-API-Key: [REDACTED]' in out or 'x-api-key: [REDACTED]' in out.lower() + def test_redacts_bearer_and_basic_tokens(): msg = 'Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhIn0.sgn; authorization: basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==' out = _capture('doorman.gateway', msg) low = out.lower() assert 'authorization: [redacted]' in low assert 'basic [redacted]' in low or 'authorization: [redacted]' in low - diff --git a/backend-services/tests/test_login_ip_rate_limit_flow.py b/backend-services/tests/test_login_ip_rate_limit_flow.py index 2d5406f..8aeba22 100644 --- a/backend-services/tests/test_login_ip_rate_limit_flow.py +++ b/backend-services/tests/test_login_ip_rate_limit_flow.py @@ -1,6 +1,8 @@ import os + import pytest + @pytest.mark.asyncio async def test_login_ip_rate_limit_returns_429_and_headers(monkeypatch, client): monkeypatch.setenv('LOGIN_IP_RATE_LIMIT', '2') @@ -9,7 +11,7 @@ async def test_login_ip_rate_limit_returns_429_and_headers(monkeypatch, client): creds = { 'email': os.environ.get('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD', 'Password123!Password') + 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD', 'Password123!Password'), } r1 = await client.post('/platform/authorization', json=creds) diff --git a/backend-services/tests/test_memory_dump_and_sigusr1.py b/backend-services/tests/test_memory_dump_and_sigusr1.py index d48fe60..d0661f0 100644 --- a/backend-services/tests/test_memory_dump_and_sigusr1.py +++ b/backend-services/tests/test_memory_dump_and_sigusr1.py @@ -1,30 +1,37 @@ -import pytest import os +import pytest + + @pytest.mark.asyncio async def test_memory_dump_writes_file_when_memory_mode(monkeypatch, tmp_path): monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'test-secret-123') monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'd' / 'dump.bin')) from utils.memory_dump_util import dump_memory_to_file, find_latest_dump_path + path = dump_memory_to_file(None) assert os.path.exists(path) latest = find_latest_dump_path(str(tmp_path / 'd' / '')) assert latest == path + def test_dump_requires_encryption_key_logs_error(tmp_path, monkeypatch): monkeypatch.delenv('MEM_ENCRYPTION_KEY', raising=False) monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'x' / 'memory_dump.bin')) from utils import memory_dump_util as md + with pytest.raises(ValueError): md.dump_memory_to_file(None) + def test_sigusr1_handler_registered_on_unix(monkeypatch, capsys): - import importlib import doorman as appmod + if hasattr(appmod.signal, 'SIGUSR1'): assert hasattr(appmod.signal, 'SIGUSR1') + def test_sigusr1_ignored_when_not_memory_mode(monkeypatch): import doorman as appmod - assert hasattr(appmod.signal, 'SIGUSR1') + assert hasattr(appmod.signal, 'SIGUSR1') diff --git a/backend-services/tests/test_memory_dump_restore_permissions.py b/backend-services/tests/test_memory_dump_restore_permissions.py new file mode 100644 index 0000000..6408ce0 --- /dev/null +++ b/backend-services/tests/test_memory_dump_restore_permissions.py @@ -0,0 +1,77 @@ +import os +import time + +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200, r.text + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_memory_dump_restore_permissions(authed_client, monkeypatch, tmp_path): + monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'test-encryption-key-32-characters-min') + + # Limited user cannot dump/restore + uname = f'mem_limited_{int(time.time())}' + pwd = 'MemoryUserStrong1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + limited = await _login(f'{uname}@example.com', pwd) + d = await limited.post('/platform/memory/dump', json={}) + assert d.status_code == 403 + r = await limited.post('/platform/memory/restore', json={}) + assert r.status_code == 403 + + # Role with manage_security can dump/restore + rname = f'sec_manager_{int(time.time())}' + cr = await authed_client.post( + '/platform/role', json={'role_name': rname, 'manage_security': True} + ) + assert cr.status_code in (200, 201) + uname2 = f'mem_mgr_{int(time.time())}' + cu2 = await authed_client.post( + '/platform/user', + json={ + 'username': uname2, + 'email': f'{uname2}@example.com', + 'password': pwd, + 'role': rname, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu2.status_code in (200, 201) + mgr = await _login(f'{uname2}@example.com', pwd) + + # Dump to temp + target = str(tmp_path / 'memdump.bin') + dump = await mgr.post('/platform/memory/dump', json={'path': target}) + assert dump.status_code == 200 + body = dump.json().get('response', dump.json()) + path = body.get('response', {}).get('path') or body.get('path') + assert path and os.path.exists(path) + + # Restore + res = await mgr.post('/platform/memory/restore', json={'path': path}) + assert res.status_code == 200 diff --git a/backend-services/tests/test_memory_dump_util_extended.py b/backend-services/tests/test_memory_dump_util_extended.py index ffae850..08f66af 100644 --- a/backend-services/tests/test_memory_dump_util_extended.py +++ b/backend-services/tests/test_memory_dump_util_extended.py @@ -1,12 +1,12 @@ import os import time -import json from pathlib import Path + import pytest + @pytest.mark.asyncio async def test_dump_file_naming_and_dir_creation(monkeypatch, tmp_path): - monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-key-12345') @@ -16,16 +16,14 @@ async def test_dump_file_naming_and_dir_creation(monkeypatch, tmp_path): dump_path = dump_memory_to_file(str(hint_file)) assert Path(dump_path).exists() - assert Path(dump_path).name.startswith('mydump-') and dump_path.endswith( - '.bin' - ) + assert Path(dump_path).name.startswith('mydump-') and dump_path.endswith('.bin') latest = find_latest_dump_path(str(tmp_path / 'custom' / 'mydump.bin')) assert latest == dump_path + @pytest.mark.asyncio async def test_dump_with_directory_hint_uses_default_stem(monkeypatch, tmp_path): - monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-key-99999') @@ -37,6 +35,7 @@ async def test_dump_with_directory_hint_uses_default_stem(monkeypatch, tmp_path) assert Path(dump_path).exists() assert Path(dump_path).name.startswith('memory_dump-') + def test_find_latest_prefers_newest_by_stem(tmp_path): from utils.memory_dump_util import find_latest_dump_path @@ -54,6 +53,7 @@ def test_find_latest_prefers_newest_by_stem(tmp_path): latest = find_latest_dump_path(str(d / 'memory_dump.bin')) assert latest and latest.endswith(b.name) + def test_find_latest_ignores_other_stems_when_dir_hint(tmp_path): from utils.memory_dump_util import find_latest_dump_path @@ -70,8 +70,8 @@ def test_find_latest_ignores_other_stems_when_dir_hint(tmp_path): latest = find_latest_dump_path(str(d)) assert latest and Path(latest).name.startswith('memory_dump-') -def test_find_latest_uses_default_when_no_hint(monkeypatch, tmp_path): +def test_find_latest_uses_default_when_no_hint(monkeypatch, tmp_path): import utils.memory_dump_util as md base = tmp_path / 'default' @@ -89,6 +89,7 @@ def test_find_latest_uses_default_when_no_hint(monkeypatch, tmp_path): latest = md.find_latest_dump_path(None) assert latest and latest.endswith(b.name) + def test_encrypt_decrypt_roundtrip(monkeypatch): import utils.memory_dump_util as md @@ -99,15 +100,16 @@ def test_encrypt_decrypt_roundtrip(monkeypatch): out = md._decrypt_blob(blob, key) assert out == pt + def test_encrypt_requires_sufficient_key(monkeypatch): import utils.memory_dump_util as md with pytest.raises(ValueError): md._encrypt_blob(b'data', 'short') + @pytest.mark.asyncio async def test_dump_and_restore_roundtrip_with_bytes(monkeypatch, tmp_path): - monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-key-abcde') monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'mem' / 'memory_dump.bin')) @@ -115,12 +117,9 @@ async def test_dump_and_restore_roundtrip_with_bytes(monkeypatch, tmp_path): import utils.memory_dump_util as md from utils.database import database - database.db.settings.insert_one({ - '_id': 'cfg', - 'blob': b'\x00\x01', - 'tuple': (1, 2, 3), - 'aset': {'a', 'b'}, - }) + database.db.settings.insert_one( + {'_id': 'cfg', 'blob': b'\x00\x01', 'tuple': (1, 2, 3), 'aset': {'a', 'b'}} + ) dump_path = md.dump_memory_to_file(None) assert Path(dump_path).exists() @@ -134,15 +133,16 @@ async def test_dump_and_restore_roundtrip_with_bytes(monkeypatch, tmp_path): assert set(restored.get('aset')) == {'a', 'b'} assert list(restored.get('tuple')) == [1, 2, 3] + def test_restore_nonexistent_file_raises(tmp_path): import utils.memory_dump_util as md with pytest.raises(FileNotFoundError): md.restore_memory_from_file(str(tmp_path / 'nope.bin')) + @pytest.mark.asyncio async def test_dump_fails_with_short_key(monkeypatch, tmp_path): - monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'short') monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'd' / 'dump.bin')) diff --git a/backend-services/tests/test_memory_routes_and_tools_e2e.py b/backend-services/tests/test_memory_routes_and_tools_e2e.py index ea4badc..a95e771 100644 --- a/backend-services/tests/test_memory_routes_and_tools_e2e.py +++ b/backend-services/tests/test_memory_routes_and_tools_e2e.py @@ -1,13 +1,17 @@ import os + import pytest + @pytest.mark.asyncio async def test_memory_dump_requires_key_then_succeeds(monkeypatch, authed_client, tmp_path): - monkeypatch.delenv('MEM_ENCRYPTION_KEY', raising=False) r1 = await authed_client.post('/platform/memory/dump') assert r1.status_code == 400 - assert (r1.json().get('error_code') or r1.json().get('response', {}).get('error_code')) in ('MEM002', 'MEM002') + assert (r1.json().get('error_code') or r1.json().get('response', {}).get('error_code')) in ( + 'MEM002', + 'MEM002', + ) monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'route-key-123456') monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'r' / 'dump.bin')) @@ -17,17 +21,22 @@ async def test_memory_dump_requires_key_then_succeeds(monkeypatch, authed_client path = body.get('response', {}).get('path') or body.get('path') assert path and path.endswith('.bin') + @pytest.mark.asyncio async def test_memory_restore_404_missing(monkeypatch, authed_client, tmp_path): monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'route-key-987654') - r = await authed_client.post('/platform/memory/restore', json={'path': str(tmp_path / 'nope.bin')}) + r = await authed_client.post( + '/platform/memory/restore', json={'path': str(tmp_path / 'nope.bin')} + ) assert r.status_code == 404 data = r.json() - assert data.get('error_code') == 'MEM003' or data.get('response', {}).get('error_code') == 'MEM003' + assert ( + data.get('error_code') == 'MEM003' or data.get('response', {}).get('error_code') == 'MEM003' + ) + @pytest.mark.asyncio async def test_memory_dump_then_restore_flow(monkeypatch, authed_client, tmp_path): - monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'route-key-abcdef') monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'd' / 'dump.bin')) @@ -43,7 +52,9 @@ async def test_memory_dump_then_restore_flow(monkeypatch, authed_client, tmp_pat ) assert create.status_code in (200, 201), create.text - d = await authed_client.post('/platform/memory/dump', json={'path': str(tmp_path / 'd' / 'dump.bin')}) + d = await authed_client.post( + '/platform/memory/dump', json={'path': str(tmp_path / 'd' / 'dump.bin')} + ) assert d.status_code == 200 dump_body = d.json() dump_path = dump_body.get('response', {}).get('path') or dump_body.get('path') @@ -58,6 +69,7 @@ async def test_memory_dump_then_restore_flow(monkeypatch, authed_client, tmp_pat check = await authed_client.get('/platform/user/e2euser') assert check.status_code == 200 + @pytest.mark.asyncio async def test_cors_wildcard_without_credentials_allows(monkeypatch, authed_client): monkeypatch.setenv('ALLOWED_ORIGINS', '*') @@ -69,6 +81,7 @@ async def test_cors_wildcard_without_credentials_allows(monkeypatch, authed_clie data = r.json() assert data.get('actual', {}).get('allowed') is True + @pytest.mark.asyncio async def test_cors_wildcard_with_credentials_strict_blocks(monkeypatch, authed_client): monkeypatch.setenv('ALLOWED_ORIGINS', '*') @@ -80,23 +93,35 @@ async def test_cors_wildcard_with_credentials_strict_blocks(monkeypatch, authed_ data = r.json() assert data.get('actual', {}).get('allowed') is False + @pytest.mark.asyncio async def test_cors_checker_implicitly_allows_options(monkeypatch, authed_client): monkeypatch.setenv('ALLOW_METHODS', 'GET,POST') body = {'origin': 'http://localhost:3000', 'method': 'OPTIONS'} r = await authed_client.post('/platform/tools/cors/check', json=body) assert r.status_code == 200 - assert r.json().get('preflight', {}).get('response_headers', {}).get('Access-Control-Allow-Methods') + assert ( + r.json() + .get('preflight', {}) + .get('response_headers', {}) + .get('Access-Control-Allow-Methods') + ) + @pytest.mark.asyncio async def test_cors_headers_case_insensitive(monkeypatch, authed_client): monkeypatch.setenv('ALLOW_HEADERS', 'Content-Type,Authorization') - body = {'origin': 'http://localhost:3000', 'method': 'GET', 'request_headers': ['content-type', 'authorization']} + body = { + 'origin': 'http://localhost:3000', + 'method': 'GET', + 'request_headers': ['content-type', 'authorization'], + } r = await authed_client.post('/platform/tools/cors/check', json=body) assert r.status_code == 200 data = r.json() assert data.get('preflight', {}).get('allowed') is True + @pytest.mark.asyncio async def test_cors_checker_vary_origin_present(monkeypatch, authed_client): body = {'origin': 'http://localhost:3000', 'method': 'GET'} @@ -107,6 +132,7 @@ async def test_cors_checker_vary_origin_present(monkeypatch, authed_client): actual = data.get('actual', {}).get('response_headers', {}) assert pre.get('Vary') == 'Origin' and actual.get('Vary') == 'Origin' + @pytest.mark.asyncio async def test_cors_checker_disallows_unknown_origin(monkeypatch, authed_client): monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000') diff --git a/backend-services/tests/test_metrics_persistence.py b/backend-services/tests/test_metrics_persistence.py index ab82cfb..366dd00 100644 --- a/backend-services/tests/test_metrics_persistence.py +++ b/backend-services/tests/test_metrics_persistence.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_metrics_persist_and_restore(tmp_path, authed_client): r1 = await authed_client.get('/api/status') @@ -7,6 +8,7 @@ async def test_metrics_persist_and_restore(tmp_path, authed_client): assert r1.status_code == 200 and r2.status_code == 200 from utils.metrics_util import metrics_store + before = metrics_store.to_dict() assert before.get('total_requests', 0) >= 1 @@ -29,6 +31,7 @@ async def test_metrics_persist_and_restore(tmp_path, authed_client): if isinstance(b2.get('status_counts'), dict): b2['status_counts'] = {str(k): v for k, v in b2['status_counts'].items()} return b2 + out = dict(d) if isinstance(out.get('status_counts'), dict): out['status_counts'] = {str(k): v for k, v in out['status_counts'].items()} diff --git a/backend-services/tests/test_metrics_ranges_extended.py b/backend-services/tests/test_metrics_ranges_extended.py index f2594ef..4f3437c 100644 --- a/backend-services/tests/test_metrics_ranges_extended.py +++ b/backend-services/tests/test_metrics_ranges_extended.py @@ -1,25 +1,37 @@ import pytest + @pytest.mark.asyncio async def test_metrics_range_parameters(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'mrange', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/p') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + class _FakeHTTPResponse: def __init__(self): self.status_code = 200 self.headers = {'Content-Type': 'application/json'} self.text = '{}' self.content = b'{}' - def json(self): return {'ok': True} + + def json(self): + return {'ok': True} + class _FakeAsyncClient: - def __init__(self, timeout=None, limits=None, http2=False): pass - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + def __init__(self, timeout=None, limits=None, http2=False): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -37,10 +49,19 @@ async def test_metrics_range_parameters(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405) - async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse() - async def post(self, url, **kwargs): return _FakeHTTPResponse() - async def put(self, url, **kwargs): return _FakeHTTPResponse() - async def delete(self, url, **kwargs): return _FakeHTTPResponse() + + async def get(self, url, params=None, headers=None, **kwargs): + return _FakeHTTPResponse() + + async def post(self, url, **kwargs): + return _FakeHTTPResponse() + + async def put(self, url, **kwargs): + return _FakeHTTPResponse() + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse() + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) await authed_client.get(f'/api/rest/{name}/{ver}/p') diff --git a/backend-services/tests/test_metrics_symmetry_envelope_ids.py b/backend-services/tests/test_metrics_symmetry_envelope_ids.py index d2ded75..551d98f 100644 --- a/backend-services/tests/test_metrics_symmetry_envelope_ids.py +++ b/backend-services/tests/test_metrics_symmetry_envelope_ids.py @@ -1,16 +1,20 @@ import json import re + import pytest + @pytest.mark.asyncio async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'msym', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'POST', '/echo') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + resp_body = b'{"ok":true,"pad":"' + b'Z' * 15 + b'"}' class _FakeHTTPResponse: @@ -19,13 +23,20 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client): self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))} self.text = body.decode('utf-8') self.content = body + def json(self): return json.loads(self.text) class _FakeAsyncClient: - def __init__(self, timeout=None, limits=None, http2=False): pass - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + def __init__(self, timeout=None, limits=None, http2=False): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -43,10 +54,18 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405) - async def get(self, url, **kwargs): return _FakeHTTPResponse(200) - async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): return _FakeHTTPResponse(200) - async def put(self, url, **kwargs): return _FakeHTTPResponse(200) - async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) + + async def get(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): + return _FakeHTTPResponse(200) + + async def put(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse(200) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) @@ -68,12 +87,15 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client): assert tin1 - tin0 >= len(payload) assert tout1 - tout0 >= len(resp_body) + @pytest.mark.asyncio async def test_response_envelope_for_non_json_error(monkeypatch, client): monkeypatch.setenv('MAX_BODY_SIZE_BYTES', '10') payload = 'x' * 100 - r = await client.post('/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'}) + r = await client.post( + '/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'} + ) assert r.status_code == 413 assert r.headers.get('content-type', '').lower().startswith('application/json') body = r.json() @@ -82,11 +104,14 @@ async def test_response_envelope_for_non_json_error(monkeypatch, client): msg = body.get('error_message') or (body.get('response') or {}).get('error_message') assert isinstance(msg, str) and msg + def _get_operation_id(spec: dict, path: str, method: str) -> str: return spec['paths'][path][method.lower()]['operationId'] + def test_unique_route_ids_are_stable(): from doorman import doorman as app + spec1 = app.openapi() spec2 = app.openapi() diff --git a/backend-services/tests/test_monitor_dashboard.py b/backend-services/tests/test_monitor_dashboard.py index 04320fe..a3ca33a 100644 --- a/backend-services/tests/test_monitor_dashboard.py +++ b/backend-services/tests/test_monitor_dashboard.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_monitor_metrics_and_dashboard(authed_client): - d = await authed_client.get('/platform/dashboard') assert d.status_code == 200 dj = d.json() @@ -13,6 +13,7 @@ async def test_monitor_metrics_and_dashboard(authed_client): mj = m.json() assert isinstance(mj, dict) + @pytest.mark.asyncio async def test_liveness_and_readiness(client): l = await client.get('/platform/monitor/liveness') diff --git a/backend-services/tests/test_monitor_endpoints.py b/backend-services/tests/test_monitor_endpoints.py new file mode 100644 index 0000000..ae8c31a --- /dev/null +++ b/backend-services/tests/test_monitor_endpoints.py @@ -0,0 +1,14 @@ +import pytest + + +@pytest.mark.asyncio +async def test_monitor_liveness_readiness_metrics(authed_client): + # Liveness + l = await authed_client.get('/platform/monitor/liveness') + assert l.status_code in (200, 204) + # Readiness + r = await authed_client.get('/platform/monitor/readiness') + assert r.status_code in (200, 204) + # Metrics + m = await authed_client.get('/platform/monitor/metrics') + assert m.status_code in (200, 204) diff --git a/backend-services/tests/test_monitor_metrics_extended.py b/backend-services/tests/test_monitor_metrics_extended.py index 7ca3796..13e64d1 100644 --- a/backend-services/tests/test_monitor_metrics_extended.py +++ b/backend-services/tests/test_monitor_metrics_extended.py @@ -1,7 +1,6 @@ -import asyncio -import time import pytest + @pytest.mark.asyncio async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self @@ -19,16 +18,20 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client) self.headers = {'Content-Type': 'application/json'} self.text = '{}' self.content = b'{}' + def json(self): return {'ok': True} class _FakeAsyncClient: def __init__(self, timeout=None, limits=None, http2=False): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -46,12 +49,16 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client) return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405) + async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200) + async def post(self, url, **kwargs): return _FakeHTTPResponse(200) + async def put(self, url, **kwargs): return _FakeHTTPResponse(200) + async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) @@ -70,9 +77,11 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client) series = body.get('series') or [] assert isinstance(series, list) + @pytest.mark.asyncio async def test_metrics_top_apis_aggregate(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'mapi3', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/x') @@ -86,12 +95,20 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client): self.headers = {'Content-Type': 'application/json'} self.text = '{}' self.content = b'{}' - def json(self): return {'ok': True} + + def json(self): + return {'ok': True} class _FakeAsyncClient: - def __init__(self, timeout=None, limits=None, http2=False): pass - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + def __init__(self, timeout=None, limits=None, http2=False): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -109,10 +126,18 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405) - async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200) - async def post(self, url, **kwargs): return _FakeHTTPResponse(200) - async def put(self, url, **kwargs): return _FakeHTTPResponse(200) - async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) + + async def get(self, url, params=None, headers=None, **kwargs): + return _FakeHTTPResponse(200) + + async def post(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def put(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse(200) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) @@ -126,6 +151,7 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client): assert any(isinstance(a, list) and a[0].startswith('rest:') for a in top_apis) + @pytest.mark.asyncio async def test_monitor_liveness_and_readiness(authed_client): live = await authed_client.get('/platform/monitor/liveness') @@ -137,28 +163,38 @@ async def test_monitor_liveness_and_readiness(authed_client): status = (ready.json() or {}).get('status') assert status in ('ready', 'degraded') + @pytest.mark.asyncio async def test_monitor_report_csv(monkeypatch, authed_client): - from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'mapi4', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/r') await subscribe_self(authed_client, name, ver) import services.gateway_service as gs + class _FakeHTTPResponse: def __init__(self): self.status_code = 200 self.headers = {'Content-Type': 'application/json'} self.text = '{}' self.content = b'{}' - def json(self): return {'ok': True} + + def json(self): + return {'ok': True} class _FakeAsyncClient: - def __init__(self, timeout=None, limits=None, http2=False): pass - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + def __init__(self, timeout=None, limits=None, http2=False): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -176,15 +212,24 @@ async def test_monitor_report_csv(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse() - async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse() - async def post(self, url, **kwargs): return _FakeHTTPResponse() - async def put(self, url, **kwargs): return _FakeHTTPResponse() - async def delete(self, url, **kwargs): return _FakeHTTPResponse() + + async def get(self, url, params=None, headers=None, **kwargs): + return _FakeHTTPResponse() + + async def post(self, url, **kwargs): + return _FakeHTTPResponse() + + async def put(self, url, **kwargs): + return _FakeHTTPResponse() + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse() monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) await authed_client.get(f'/api/rest/{name}/{ver}/r') from datetime import datetime + now = datetime.utcnow() start = now.strftime('%Y-%m-%dT%H:%M') end = start @@ -192,5 +237,6 @@ async def test_monitor_report_csv(monkeypatch, authed_client): assert csvr.status_code == 200 text = csvr.text - assert 'Report' in text and 'Overview' in text and 'Status Codes' in text and 'API Usage' in text - + assert ( + 'Report' in text and 'Overview' in text and 'Status Codes' in text and 'API Usage' in text + ) diff --git a/backend-services/tests/test_multi_onboarding.py b/backend-services/tests/test_multi_onboarding.py index 9ffa588..27ec723 100644 --- a/backend-services/tests/test_multi_onboarding.py +++ b/backend-services/tests/test_multi_onboarding.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_multi_endpoints_per_api_and_listing(authed_client): - c = await authed_client.post( '/platform/api', json={ @@ -18,11 +18,7 @@ async def test_multi_endpoints_per_api_and_listing(authed_client): ) assert c.status_code in (200, 201) - endpoints = [ - ('GET', '/a'), - ('POST', '/b'), - ('PUT', '/c'), - ] + endpoints = [('GET', '/a'), ('POST', '/b'), ('PUT', '/c')] for method, uri in endpoints: ep = await authed_client.post( '/platform/endpoint', diff --git a/backend-services/tests/test_multi_worker_semantics.py b/backend-services/tests/test_multi_worker_semantics.py index 15037cb..8d7756f 100644 --- a/backend-services/tests/test_multi_worker_semantics.py +++ b/backend-services/tests/test_multi_worker_semantics.py @@ -1,5 +1,6 @@ import pytest + @pytest.mark.asyncio async def test_mem_multi_worker_guard_raises(monkeypatch): monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') @@ -9,6 +10,7 @@ async def test_mem_multi_worker_guard_raises(monkeypatch): with pytest.raises(RuntimeError): validate_token_revocation_config() + @pytest.mark.asyncio async def test_mem_single_worker_allowed(monkeypatch): monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') @@ -17,6 +19,7 @@ async def test_mem_single_worker_allowed(monkeypatch): validate_token_revocation_config() + @pytest.mark.asyncio async def test_redis_multi_worker_allowed(monkeypatch): monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS') @@ -24,4 +27,3 @@ async def test_redis_multi_worker_allowed(monkeypatch): from doorman import validate_token_revocation_config validate_token_revocation_config() - diff --git a/backend-services/tests/test_pagination_caps.py b/backend-services/tests/test_pagination_caps.py index 48cc33a..78f39f0 100644 --- a/backend-services/tests/test_pagination_caps.py +++ b/backend-services/tests/test_pagination_caps.py @@ -1,6 +1,6 @@ -import os import pytest + @pytest.mark.asyncio async def test_max_page_size_boundary_api_list(authed_client, monkeypatch): monkeypatch.setenv('MAX_PAGE_SIZE', '5') @@ -13,6 +13,7 @@ async def test_max_page_size_boundary_api_list(authed_client, monkeypatch): body = r_bad.json() assert 'error_message' in body + @pytest.mark.asyncio async def test_max_page_size_boundary_users_list(authed_client, monkeypatch): monkeypatch.setenv('MAX_PAGE_SIZE', '3') @@ -23,6 +24,7 @@ async def test_max_page_size_boundary_users_list(authed_client, monkeypatch): r_bad = await authed_client.get('/platform/user/all?page=1&page_size=4') assert r_bad.status_code == 400, r_bad.text + @pytest.mark.asyncio async def test_invalid_page_values(authed_client, monkeypatch): monkeypatch.setenv('MAX_PAGE_SIZE', '10') @@ -32,4 +34,3 @@ async def test_invalid_page_values(authed_client, monkeypatch): r2 = await authed_client.get('/platform/group/all?page=1&page_size=0') assert r2.status_code == 400 - diff --git a/backend-services/tests/test_permissions_extended.py b/backend-services/tests/test_permissions_extended.py index 0c01acf..d7150e1 100644 --- a/backend-services/tests/test_permissions_extended.py +++ b/backend-services/tests/test_permissions_extended.py @@ -1,13 +1,12 @@ import pytest from httpx import AsyncClient + async def _login(email: str, password: str) -> AsyncClient: from doorman import doorman + client = AsyncClient(app=doorman, base_url='http://testserver') - r = await client.post( - '/platform/authorization', - json={'email': email, 'password': password}, - ) + r = await client.post('/platform/authorization', json={'email': email, 'password': password}) assert r.status_code == 200, r.text body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} @@ -16,9 +15,9 @@ async def _login(email: str, password: str) -> AsyncClient: client.cookies.set('access_token_cookie', token, domain='testserver', path='/') return client + @pytest.mark.asyncio async def test_non_admin_role_cannot_access_monitor_or_credits(monkeypatch, authed_client): - rrole = await authed_client.post( '/platform/role', json={ @@ -64,7 +63,13 @@ async def test_non_admin_role_cannot_access_monitor_or_credits(monkeypatch, auth 'api_key': 'x', 'api_key_header': 'x-api-key', 'credit_tiers': [ - {'tier_name': 'basic', 'credits': 10, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly'} + { + 'tier_name': 'basic', + 'credits': 10, + 'input_limit': 0, + 'output_limit': 0, + 'reset_frequency': 'monthly', + } ], }, ) @@ -73,16 +78,12 @@ async def test_non_admin_role_cannot_access_monitor_or_credits(monkeypatch, auth cc = await viewer.delete('/api/caches') assert cc.status_code == 403 + @pytest.mark.asyncio async def test_endpoint_validation_management_requires_permission(authed_client): - await authed_client.post( '/platform/role', - json={ - 'role_name': 'noend', - 'role_description': 'No endpoints', - 'manage_endpoints': False, - }, + json={'role_name': 'noend', 'role_description': 'No endpoints', 'manage_endpoints': False}, ) await authed_client.post( '/platform/user', @@ -96,6 +97,7 @@ async def test_endpoint_validation_management_requires_permission(authed_client) ) from conftest import create_api, create_endpoint + api_name, ver = 'permapi', 'v1' await create_api(authed_client, api_name, ver) await create_endpoint(authed_client, api_name, ver, 'POST', '/foo') @@ -106,6 +108,10 @@ async def test_endpoint_validation_management_requires_permission(authed_client) ev = await client.post( '/platform/endpoint/endpoint/validation', - json={'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': {'validation_schema': {}}}, + json={ + 'endpoint_id': eid, + 'validation_enabled': True, + 'validation_schema': {'validation_schema': {}}, + }, ) assert ev.status_code == 403 diff --git a/backend-services/tests/test_platform_admin_visibility.py b/backend-services/tests/test_platform_admin_visibility.py index 431287a..a7173b1 100644 --- a/backend-services/tests/test_platform_admin_visibility.py +++ b/backend-services/tests/test_platform_admin_visibility.py @@ -1,22 +1,25 @@ -import os import uuid + import pytest import pytest_asyncio from httpx import AsyncClient + @pytest_asyncio.fixture async def login_client(): async def _login(username: str, password: str, email: str = None) -> AsyncClient: from doorman import doorman + client = AsyncClient(app=doorman, base_url='http://testserver') cred = {'email': email or f'{username}@example.com', 'password': password} r = await client.post('/platform/authorization', json=cred) assert r.status_code == 200, r.text return client + return _login -async def _ensure_manager_role(authed_client: AsyncClient): +async def _ensure_manager_role(authed_client: AsyncClient): payload = { 'role_name': 'manager', 'role_description': 'Manager role', @@ -31,11 +34,12 @@ async def _ensure_manager_role(authed_client: AsyncClient): 'manage_credits': True, 'manage_auth': True, 'view_logs': True, - 'export_logs': True + 'export_logs': True, } r = await authed_client.post('/platform/role', json=payload) assert r.status_code in (200, 201, 400), r.text + async def _create_manager_user(authed_client: AsyncClient) -> dict: await _ensure_manager_role(authed_client) uname = f'mgr_{uuid.uuid4().hex[:8]}' @@ -46,12 +50,13 @@ async def _create_manager_user(authed_client: AsyncClient) -> dict: 'role': 'manager', 'groups': ['ALL'], 'active': True, - 'ui_access': True + 'ui_access': True, } r = await authed_client.post('/platform/user', json=payload) assert r.status_code in (200, 201), r.text return payload + @pytest.mark.asyncio async def test_non_admin_cannot_see_admin_role(authed_client, login_client): mgr = await _create_manager_user(authed_client) @@ -68,9 +73,9 @@ async def test_non_admin_cannot_see_admin_role(authed_client, login_client): await client.aclose() + @pytest.mark.asyncio async def test_non_admin_cannot_see_or_modify_admin_users(authed_client, login_client): - mgr = await _create_manager_user(authed_client) client = await login_client(mgr['username'], 'StrongManagerPwd!1234', mgr['email']) @@ -89,28 +94,33 @@ async def test_non_admin_cannot_see_or_modify_admin_users(authed_client, login_c await client.aclose() + @pytest.mark.asyncio async def test_non_admin_cannot_assign_admin_role(authed_client, login_client): mgr = await _create_manager_user(authed_client) client = await login_client(mgr['username'], 'StrongManagerPwd!1234', mgr['email']) newu = f'np_{uuid.uuid4().hex[:8]}' - r_create = await client.post('/platform/user', json={ - 'username': newu, - 'email': f'{newu}@example.com', - 'password': 'StrongPwd!1234XYZ', - 'role': 'admin', - 'groups': ['ALL'], - 'active': True, - 'ui_access': True - }) + r_create = await client.post( + '/platform/user', + json={ + 'username': newu, + 'email': f'{newu}@example.com', + 'password': 'StrongPwd!1234XYZ', + 'role': 'admin', + 'groups': ['ALL'], + 'active': True, + 'ui_access': True, + }, + ) assert r_create.status_code in (403, 404) - r_update = await client.put(f"/platform/user/{mgr['username']}", json={'role': 'admin'}) + r_update = await client.put(f'/platform/user/{mgr["username"]}', json={'role': 'admin'}) assert r_update.status_code in (403, 404) await client.aclose() + @pytest.mark.asyncio async def test_non_admin_auth_admin_ops_hidden(authed_client, login_client): mgr = await _create_manager_user(authed_client) @@ -127,6 +137,8 @@ async def test_non_admin_auth_admin_ops_hidden(authed_client, login_client): r = await client.get(path) else: r = await client.post(path) - assert r.status_code in (404, 403), f'Expected 404/403 for {path}, got {r.status_code}: {r.text}' + assert r.status_code in (404, 403), ( + f'Expected 404/403 for {path}, got {r.status_code}: {r.text}' + ) await client.aclose() diff --git a/backend-services/tests/test_platform_cors_env_edges.py b/backend-services/tests/test_platform_cors_env_edges.py index 6aaec94..0c43891 100644 --- a/backend-services/tests/test_platform_cors_env_edges.py +++ b/backend-services/tests/test_platform_cors_env_edges.py @@ -1,60 +1,78 @@ import pytest + @pytest.mark.asyncio async def test_platform_cors_wildcard_origin_with_credentials_strict_false(monkeypatch, client): monkeypatch.setenv('ALLOWED_ORIGINS', '*') monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') monkeypatch.setenv('CORS_STRICT', 'false') - r = await client.options('/platform/api', headers={ - 'Origin': 'http://evil.example', - 'Access-Control-Request-Method': 'GET' - }) + r = await client.options( + '/platform/api', + headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'}, + ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Origin') == 'http://evil.example' assert r.headers.get('Access-Control-Allow-Credentials') == 'true' assert r.headers.get('Vary') == 'Origin' + @pytest.mark.asyncio -async def test_platform_cors_wildcard_origin_with_credentials_strict_true_restricts(monkeypatch, client): +async def test_platform_cors_wildcard_origin_with_credentials_strict_true_restricts( + monkeypatch, client +): monkeypatch.setenv('ALLOWED_ORIGINS', '*') monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') monkeypatch.setenv('CORS_STRICT', 'true') - r = await client.options('/platform/api', headers={ - 'Origin': 'http://evil.example', - 'Access-Control-Request-Method': 'GET' - }) + r = await client.options( + '/platform/api', + headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'}, + ) assert r.status_code == 204 assert r.headers.get('Access-Control-Allow-Origin') is None + @pytest.mark.asyncio async def test_platform_cors_methods_empty_env_falls_back_default(monkeypatch, client): monkeypatch.setenv('ALLOW_METHODS', '') monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000') - r = await client.options('/platform/api', headers={ - 'Origin': 'http://localhost:3000', - 'Access-Control-Request-Method': 'GET' - }) + r = await client.options( + '/platform/api', + headers={'Origin': 'http://localhost:3000', 'Access-Control-Request-Method': 'GET'}, + ) assert r.status_code == 204 - methods = [m.strip() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()] - expected = {'GET','POST','PUT','DELETE','OPTIONS','PATCH','HEAD'} + methods = [ + m.strip() + for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') + if m.strip() + ] + expected = {'GET', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH', 'HEAD'} assert set(methods) == expected + @pytest.mark.asyncio async def test_platform_cors_headers_asterisk_defaults_to_known_list(monkeypatch, client): monkeypatch.setenv('ALLOW_HEADERS', '*') monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000') - r = await client.options('/platform/api', headers={ - 'Origin': 'http://localhost:3000', - 'Access-Control-Request-Method': 'GET', - 'Access-Control-Request-Headers': 'X-Anything' - }) + r = await client.options( + '/platform/api', + headers={ + 'Origin': 'http://localhost:3000', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'X-Anything', + }, + ) assert r.status_code == 204 - headers = [h.strip() for h in (r.headers.get('Access-Control-Allow-Headers') or '').split(',') if h.strip()] - assert set(headers) == {'Accept','Content-Type','X-CSRF-Token','Authorization'} + headers = [ + h.strip() + for h in (r.headers.get('Access-Control-Allow-Headers') or '').split(',') + if h.strip() + ] + assert set(headers) == {'Accept', 'Content-Type', 'X-CSRF-Token', 'Authorization'} + @pytest.mark.asyncio async def test_platform_cors_sets_vary_origin(monkeypatch, authed_client): @@ -63,4 +81,3 @@ async def test_platform_cors_sets_vary_origin(monkeypatch, authed_client): assert r.status_code == 200 assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example' assert r.headers.get('Vary') == 'Origin' - diff --git a/backend-services/tests/test_platform_cors_no_duplicate_acao.py b/backend-services/tests/test_platform_cors_no_duplicate_acao.py new file mode 100644 index 0000000..1601604 --- /dev/null +++ b/backend-services/tests/test_platform_cors_no_duplicate_acao.py @@ -0,0 +1,13 @@ +import pytest + + +@pytest.mark.asyncio +async def test_platform_cors_no_duplicate_access_control_allow_origin(monkeypatch, authed_client): + monkeypatch.setenv('ALLOWED_ORIGINS', 'http://ok.example') + r = await authed_client.get('/platform/user/me', headers={'Origin': 'http://ok.example'}) + assert r.status_code == 200 + acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get( + 'access-control-allow-origin' + ) + assert acao == 'http://ok.example' + assert ',' not in acao diff --git a/backend-services/tests/test_platform_expanded.py b/backend-services/tests/test_platform_expanded.py index 6d5891c..8117300 100644 --- a/backend-services/tests/test_platform_expanded.py +++ b/backend-services/tests/test_platform_expanded.py @@ -1,7 +1,8 @@ -import os import json + import pytest + @pytest.mark.asyncio async def test_routing_crud(authed_client): create = await authed_client.post( @@ -30,13 +31,10 @@ async def test_routing_crud(authed_client): delete = await authed_client.delete('/platform/routing/client-A') assert delete.status_code == 200 + @pytest.mark.asyncio async def test_security_and_memory_dump_restore(authed_client): - - ur = await authed_client.put( - '/platform/role/admin', - json={'manage_security': True}, - ) + ur = await authed_client.put('/platform/role/admin', json={'manage_security': True}) assert ur.status_code == 200 gs = await authed_client.get('/platform/security/settings') @@ -54,12 +52,11 @@ async def test_security_and_memory_dump_restore(authed_client): restore = await authed_client.post('/platform/memory/restore', json={'path': path}) assert restore.status_code == 200 + @pytest.mark.asyncio async def test_logging_endpoints(authed_client): - r1 = await authed_client.put( - '/platform/role/admin', - json={'view_logs': True, 'export_logs': True}, + '/platform/role/admin', json={'view_logs': True, 'export_logs': True} ) assert r1.status_code == 200 @@ -78,16 +75,12 @@ async def test_logging_endpoints(authed_client): download = await authed_client.get('/platform/logging/logs/download?format=json') assert download.status_code == 200 + @pytest.mark.asyncio async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_client): - rest_apis = [ - ('jsonplaceholder', 'v1', ['https://jsonplaceholder.typicode.com'], [ - ('GET', '/posts/1') - ]), - ('httpbin', 'v1', ['https://httpbin.org'], [ - ('GET', '/get') - ]), + ('jsonplaceholder', 'v1', ['https://jsonplaceholder.typicode.com'], [('GET', '/posts/1')]), + ('httpbin', 'v1', ['https://httpbin.org'], [('GET', '/get')]), ] for name, ver, servers, endpoints in rest_apis: c = await authed_client.post( @@ -149,7 +142,11 @@ async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_cli assert s.status_code in (200, 201) soap_apis = [ - ('soap-number', 'v1', ['https://www.dataaccess.com/webservicesserver/NumberConversion.wso']), + ( + 'soap-number', + 'v1', + ['https://www.dataaccess.com/webservicesserver/NumberConversion.wso'], + ), ('soap-tempconvert', 'v1', ['https://www.w3schools.com/xml/tempconvert.asmx']), ] for name, ver, servers in soap_apis: @@ -199,21 +196,30 @@ async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_cli assert s.status_code in (200, 201) import services.gateway_service as gs + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None, text_body=None, headers=None): self.status_code = status_code self._json_body = json_body - self.text = text_body if text_body is not None else ('' if json_body is not None else 'OK') - self.headers = headers or {'Content-Type': 'application/json' if json_body is not None else 'text/plain'} + self.text = ( + text_body if text_body is not None else ('' if json_body is not None else 'OK') + ) + self.headers = headers or { + 'Content-Type': 'application/json' if json_body is not None else 'text/plain' + } + def json(self): if self._json_body is None: return json.loads(self.text or '{}') return self._json_body + class _FakeAsyncClient: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -231,23 +237,32 @@ async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_cli return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) + async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200, json_body={'ping': 'pong'}) + async def post(self, url, **kwargs): return _FakeHTTPResponse(200, json_body={'ping': 'pong'}) + async def put(self, url, **kwargs): return _FakeHTTPResponse(200, json_body={'ping': 'pong'}) + async def delete(self, url, **kwargs): return _FakeHTTPResponse(200, json_body={'ping': 'pong'}) + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) import routes.gateway_routes as gr + async def _no_limit(req): return None + async def _pass_sub(req): return {'sub': 'admin'} + async def _pass_group(req, full_path: str = None, user_to_subscribe=None): return {'sub': 'admin'} + monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit) monkeypatch.setattr(gr, 'subscription_required', _pass_sub) monkeypatch.setattr(gr, 'group_required', _pass_group) diff --git a/backend-services/tests/test_production_https_guard.py b/backend-services/tests/test_production_https_guard.py index 100c933..13a9111 100644 --- a/backend-services/tests/test_production_https_guard.py +++ b/backend-services/tests/test_production_https_guard.py @@ -1,23 +1,29 @@ import pytest + @pytest.mark.asyncio async def test_production_without_https_flags_fails_startup(monkeypatch): monkeypatch.setenv('ENV', 'production') monkeypatch.setenv('HTTPS_ONLY', 'false') - from doorman import app_lifespan, doorman import pytest as _pytest + + from doorman import app_lifespan, doorman + with _pytest.raises(RuntimeError): async with app_lifespan(doorman): pass + @pytest.mark.asyncio async def test_production_with_https_only_succeeds(monkeypatch): monkeypatch.setenv('ENV', 'production') monkeypatch.setenv('HTTPS_ONLY', 'true') from httpx import AsyncClient + from doorman import doorman + client = AsyncClient(app=doorman, base_url='http://testserver') r = await client.get('/platform/monitor/liveness') assert r.status_code == 200 diff --git a/backend-services/tests/test_proto_extension.py b/backend-services/tests/test_proto_extension.py index d3ee079..d4f4d5d 100644 --- a/backend-services/tests/test_proto_extension.py +++ b/backend-services/tests/test_proto_extension.py @@ -1,21 +1,21 @@ -import io import pytest + @pytest.mark.asyncio async def test_proto_upload_rejects_non_proto(authed_client): - files = {'file': ('bad.txt', b'syntax = \"proto3\";', 'text/plain')} + files = {'file': ('bad.txt', b'syntax = "proto3";', 'text/plain')} r = await authed_client.post('/platform/proto/sample/v1', files=files) assert r.status_code == 400 body = r.json() assert body.get('error_code') == 'REQ003' + @pytest.mark.asyncio async def test_proto_upload_accepts_proto(authed_client): - content = b'syntax = \"proto3\";\npackage sample_v1;\nmessage Ping { string msg = 1; }' + content = b'syntax = "proto3";\npackage sample_v1;\nmessage Ping { string msg = 1; }' files = {'file': ('ok.proto', content, 'application/octet-stream')} r = await authed_client.post('/platform/proto/sample/v1', files=files) assert r.status_code in (200, 500) if r.status_code == 200: body = r.json() assert body.get('message', '').lower().startswith('proto file uploaded') - diff --git a/backend-services/tests/test_proto_routes.py b/backend-services/tests/test_proto_routes.py index e46ca7e..7e41bac 100644 --- a/backend-services/tests/test_proto_routes.py +++ b/backend-services/tests/test_proto_routes.py @@ -1,14 +1,18 @@ import io + import pytest + @pytest.mark.asyncio async def test_proto_upload_and_get(monkeypatch, authed_client): import routes.proto_routes as pr class _FakeCompleted: pass + def _fake_run(*args, **kwargs): return _FakeCompleted() + monkeypatch.setattr(pr.subprocess, 'run', _fake_run) proto_content = b""" @@ -22,5 +26,4 @@ async def test_proto_upload_and_get(monkeypatch, authed_client): gp = await authed_client.get('/platform/proto/myapi/v1') assert gp.status_code == 200 content = gp.json().get('content') or gp.json().get('response', {}).get('content') - assert 'syntax = \"proto3\";' in content - + assert 'syntax = "proto3";' in content diff --git a/backend-services/tests/test_proto_routes_extended.py b/backend-services/tests/test_proto_routes_extended.py index 427d184..33d7193 100644 --- a/backend-services/tests/test_proto_routes_extended.py +++ b/backend-services/tests/test_proto_routes_extended.py @@ -1,28 +1,33 @@ import pytest + @pytest.mark.asyncio async def test_proto_update_and_delete_flow(monkeypatch, authed_client): - import routes.proto_routes as pr - class _FakeCompleted: pass + + class _FakeCompleted: + pass + def _fake_run(*args, **kwargs): return _FakeCompleted() + monkeypatch.setattr(pr.subprocess, 'run', _fake_run) - r = await authed_client.put( - '/platform/role/admin', - json={'manage_apis': True}, - ) + r = await authed_client.put('/platform/role/admin', json={'manage_apis': True}) assert r.status_code in (200, 201) files = { - 'file': ('sample.proto', b'syntax = \"proto3\"; message Ping { string x = 1; }', 'text/plain'), + 'file': ('sample.proto', b'syntax = "proto3"; message Ping { string x = 1; }', 'text/plain') } up = await authed_client.post('/platform/proto/myapi2/v1', files=files) assert up.status_code in (200, 201), up.text files2 = { - 'proto_file': ('sample.proto', b'syntax = \"proto3\"; message Pong { string y = 1; }', 'text/plain'), + 'proto_file': ( + 'sample.proto', + b'syntax = "proto3"; message Pong { string y = 1; }', + 'text/plain', + ) } put = await authed_client.put('/platform/proto/myapi2/v1', files=files2) assert put.status_code in (200, 201), put.text @@ -39,6 +44,7 @@ async def test_proto_update_and_delete_flow(monkeypatch, authed_client): gp2 = await authed_client.get('/platform/proto/myapi2/v1') assert gp2.status_code == 404 + @pytest.mark.asyncio async def test_proto_get_nonexistent_returns_404(authed_client): resp = await authed_client.get('/platform/proto/doesnotexist/v9') diff --git a/backend-services/tests/test_proto_upload_security_and_import.py b/backend-services/tests/test_proto_upload_security_and_import.py index 2ec02e8..6819c21 100644 --- a/backend-services/tests/test_proto_upload_security_and_import.py +++ b/backend-services/tests/test_proto_upload_security_and_import.py @@ -1,9 +1,10 @@ import pytest -from pathlib import Path + @pytest.mark.asyncio async def test_proto_upload_rejects_invalid_filename(monkeypatch, authed_client): import routes.proto_routes as pr + monkeypatch.setattr(pr, 'sanitize_filename', lambda s: (_ for _ in ()).throw(ValueError('bad'))) files = {'file': ('svc.proto', b'syntax = "proto3"; package x;')} r = await authed_client.post('/platform/proto/bad/v1', files=files) @@ -11,69 +12,82 @@ async def test_proto_upload_rejects_invalid_filename(monkeypatch, authed_client) body = r.json() assert body.get('error_code') + @pytest.mark.asyncio async def test_proto_upload_validates_within_base_path(): import routes.proto_routes as pr + base = (pr.PROJECT_ROOT / 'proto').resolve() good = (base / 'ok.proto').resolve() bad = (pr.PROJECT_ROOT.parent / 'outside.proto').resolve() assert pr.validate_path(pr.PROJECT_ROOT, good) is True assert pr.validate_path(pr.PROJECT_ROOT, bad) is False + @pytest.mark.asyncio async def test_proto_upload_generates_stubs_success(monkeypatch, authed_client): name, ver = 'psvc1', 'v1' proto = b'syntax = "proto3"; package foo; service S { rpc M (R) returns (Q) {} } message R { string a = 1; } message Q { string b = 1; }' files = {'file': ('svc.proto', proto)} import routes.proto_routes as pr + safe = f'{name}_{ver}' gen = (pr.PROJECT_ROOT / 'generated').resolve() + def _fake_run(cmd, check): (gen / f'{safe}_pb2.py').write_text('# pb2') (gen / f'{safe}_pb2_grpc.py').write_text( - f'import {safe}_pb2 as {name}__{ver}__pb2\n' - 'class S: pass\n' + f'import {safe}_pb2 as {name}__{ver}__pb2\nclass S: pass\n' ) return 0 + monkeypatch.setattr(pr.subprocess, 'run', _fake_run) r = await authed_client.post(f'/platform/proto/{name}/{ver}', files=files) assert r.status_code == 200 import routes.proto_routes as pr + safe = f'{name}_{ver}' gen = (pr.PROJECT_ROOT / 'generated').resolve() assert (gen / f'{safe}_pb2.py').exists() assert (gen / f'{safe}_pb2_grpc.py').exists() + @pytest.mark.asyncio async def test_proto_upload_rewrite_pb2_imports_for_generated_namespace(monkeypatch, authed_client): name, ver = 'psvc2', 'v1' proto = b'syntax = "proto3"; package foo; service S { rpc M (R) returns (Q) {} } message R { string a = 1; } message Q { string b = 1; }' files = {'file': ('svc.proto', proto)} import routes.proto_routes as pr + safe = f'{name}_{ver}' gen = (pr.PROJECT_ROOT / 'generated').resolve() + def _fake_run(cmd, check): (gen / f'{safe}_pb2.py').write_text('# pb2') (gen / f'{safe}_pb2_grpc.py').write_text( - f'import {safe}_pb2 as {name}__{ver}__pb2\n' - 'class S: pass\n' + f'import {safe}_pb2 as {name}__{ver}__pb2\nclass S: pass\n' ) return 0 + monkeypatch.setattr(pr.subprocess, 'run', _fake_run) r = await authed_client.post(f'/platform/proto/{name}/{ver}', files=files) assert r.status_code == 200 import routes.proto_routes as pr + safe = f'{name}_{ver}' gen = (pr.PROJECT_ROOT / 'generated').resolve() - pb2g = (gen / f'{safe}_pb2_grpc.py') + pb2g = gen / f'{safe}_pb2_grpc.py' txt = pb2g.read_text() assert f'from generated import {safe}_pb2 as {name}__{ver}__pb2' in txt + @pytest.mark.asyncio async def test_proto_get_requires_permission(monkeypatch, authed_client): import routes.proto_routes as pr + async def _no_perm(*args, **kwargs): return False + monkeypatch.setattr(pr, 'platform_role_required_bool', _no_perm) r = await authed_client.get('/platform/proto/x/v1') assert r.status_code == 403 diff --git a/backend-services/tests/test_public_apis.py b/backend-services/tests/test_public_apis.py index 7e03c2e..52f137c 100644 --- a/backend-services/tests/test_public_apis.py +++ b/backend-services/tests/test_public_apis.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_rest_public_api_allows_unauthenticated(client, authed_client): - name, ver = 'pubrest', 'v1' cr = await authed_client.post( '/platform/api', @@ -32,6 +32,7 @@ async def test_rest_public_api_allows_unauthenticated(client, authed_client): assert r.status_code in (200, 400, 404, 429, 500) + @pytest.mark.asyncio async def test_graphql_public_api_allows_unauthenticated(client, authed_client): name, ver = 'pubgql', 'v1' @@ -55,6 +56,7 @@ async def test_graphql_public_api_allows_unauthenticated(client, authed_client): ) assert r.status_code in (200, 400, 404, 429, 500) + @pytest.mark.asyncio async def test_public_api_bypasses_credits_check(client, authed_client): name, ver = 'pubcredits', 'v1' @@ -102,6 +104,7 @@ async def test_public_api_bypasses_credits_check(client, authed_client): r = await client.get(f'/api/rest/{name}/{ver}/ping') assert r.status_code != 401 + @pytest.mark.asyncio async def test_auth_not_required_but_not_public(client, authed_client): name, ver = 'noauthsub', 'v1' diff --git a/backend-services/tests/test_rate_limit_and_throttle_combination.py b/backend-services/tests/test_rate_limit_and_throttle_combination.py index 48d92fd..04aee46 100644 --- a/backend-services/tests/test_rate_limit_and_throttle_combination.py +++ b/backend-services/tests/test_rate_limit_and_throttle_combination.py @@ -1,21 +1,28 @@ -import pytest import time +import pytest from tests.test_gateway_routing_limits import _FakeAsyncClient + @pytest.mark.asyncio async def test_rate_limit_blocks_second_request_in_window(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'rlblock', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/p') await subscribe_self(authed_client, name, ver) from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}}) + + user_collection.update_one( + {'username': 'admin'}, + {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}}, + ) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) ok = await authed_client.get(f'/api/rest/{name}/{ver}/p') @@ -23,55 +30,77 @@ async def test_rate_limit_blocks_second_request_in_window(monkeypatch, authed_cl blocked = await authed_client.get(f'/api/rest/{name}/{ver}/p') assert blocked.status_code == 429 + @pytest.mark.asyncio async def test_throttle_queue_limit_exceeded_returns_429(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'tqex', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/t') await subscribe_self(authed_client, name, ver) from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': {'throttle_duration': 1, 'throttle_queue_limit': 1}}) + + user_collection.update_one( + {'username': 'admin'}, {'$set': {'throttle_duration': 1, 'throttle_queue_limit': 1}} + ) await authed_client.delete('/api/caches') from utils.limit_throttle_util import reset_counters + reset_counters() now_ms = int(time.time() * 1000) wait_ms = 1000 - (now_ms % 1000) + 350 time.sleep(wait_ms / 1000.0) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r1 = await authed_client.get(f'/api/rest/{name}/{ver}/t') - assert r1.status_code == 200, f"First request failed with {r1.status_code}" + assert r1.status_code == 200, f'First request failed with {r1.status_code}' r2 = await authed_client.get(f'/api/rest/{name}/{ver}/t') - assert r2.status_code == 429, f"Second request should have been throttled but got {r2.status_code}" + assert r2.status_code == 429, ( + f'Second request should have been throttled but got {r2.status_code}' + ) + @pytest.mark.asyncio async def test_throttle_dynamic_wait_increases_latency(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'twait', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/w') await subscribe_self(authed_client, name, ver) from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': { - 'throttle_duration': 1, 'throttle_duration_type': 'second', - 'throttle_queue_limit': 10, - 'throttle_wait_duration': 0.1, 'throttle_wait_duration_type': 'second', - 'rate_limit_duration': 1000, 'rate_limit_duration_type': 'second' - }}) + + user_collection.update_one( + {'username': 'admin'}, + { + '$set': { + 'throttle_duration': 1, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 10, + 'throttle_wait_duration': 0.1, + 'throttle_wait_duration_type': 'second', + 'rate_limit_duration': 1000, + 'rate_limit_duration_type': 'second', + } + }, + ) await authed_client.delete('/api/caches') from utils.limit_throttle_util import reset_counters + reset_counters() now_ms = int(time.time() * 1000) wait_ms = 1000 - (now_ms % 1000) + 350 time.sleep(wait_ms / 1000.0) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) t0 = time.perf_counter() @@ -88,19 +117,26 @@ async def test_throttle_dynamic_wait_increases_latency(monkeypatch, authed_clien assert dur2 >= dur1 + 0.08 + @pytest.mark.asyncio async def test_rate_limit_window_rollover_allows_requests(monkeypatch, authed_client): from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'rlroll', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/x') await subscribe_self(authed_client, name, ver) from utils.database import user_collection - user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}}) + + user_collection.update_one( + {'username': 'admin'}, + {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}}, + ) await authed_client.delete('/api/caches') import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r1 = await authed_client.get(f'/api/rest/{name}/{ver}/x') diff --git a/backend-services/tests/test_rate_limit_fallback.py b/backend-services/tests/test_rate_limit_fallback.py index 98fc581..b73d2db 100644 --- a/backend-services/tests/test_rate_limit_fallback.py +++ b/backend-services/tests/test_rate_limit_fallback.py @@ -2,6 +2,7 @@ import asyncio from utils.limit_throttle_util import InMemoryWindowCounter + async def _inc(counter: InMemoryWindowCounter, key: str, times: int, ttl: int): counts = [] for _ in range(times): @@ -10,6 +11,7 @@ async def _inc(counter: InMemoryWindowCounter, key: str, times: int, ttl: int): await counter.expire(key, ttl) return counts + def test_inmemory_counter_increments_and_expires(event_loop): c = InMemoryWindowCounter() counts = event_loop.run_until_complete(_inc(c, 'k1', 3, 1)) @@ -19,4 +21,3 @@ def test_inmemory_counter_increments_and_expires(event_loop): counts2 = event_loop.run_until_complete(_inc(c, 'k1', 2, 1)) assert counts2[0] == 1 assert counts2[1] == 2 - diff --git a/backend-services/tests/test_rate_limiter.py b/backend-services/tests/test_rate_limiter.py index c8dc670..5a279d5 100644 --- a/backend-services/tests/test_rate_limiter.py +++ b/backend-services/tests/test_rate_limiter.py @@ -8,24 +8,22 @@ Comprehensive tests for rate limiting functionality including: - Load tests for distributed scenarios """ -import pytest import asyncio import time -from datetime import datetime, timedelta -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import MagicMock, Mock -from utils.rate_limiter import RateLimiter, RateLimitResult -from utils.quota_tracker import QuotaTracker +import pytest + +from models.rate_limit_models import QuotaType, RateLimitRule, RuleType, TimeWindow from utils.ip_rate_limiter import IPRateLimiter -from models.rate_limit_models import ( - RateLimitRule, RuleType, TimeWindow, QuotaType -) - +from utils.quota_tracker import QuotaTracker +from utils.rate_limiter import RateLimiter # ============================================================================ # FIXTURES # ============================================================================ + @pytest.fixture def mock_redis(): """Mock Redis client for testing""" @@ -60,12 +58,12 @@ def ip_limiter(mock_redis): def sample_rule(): """Sample rate limit rule""" return RateLimitRule( - rule_id="test_rule", + rule_id='test_rule', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, limit=100, burst_allowance=20, - target_identifier="test_user" + target_identifier='test_user', ) @@ -73,60 +71,61 @@ def sample_rule(): # UNIT TESTS - RATE LIMITER # ============================================================================ + class TestRateLimiter: """Unit tests for RateLimiter class""" - + def test_sliding_window_within_limit(self, rate_limiter, sample_rule, mock_redis): """Test sliding window allows requests within limit""" - mock_redis.get.return_value = "50" # Current count - - result = rate_limiter._check_sliding_window(sample_rule, "test_user") - + mock_redis.get.return_value = '50' # Current count + + result = rate_limiter._check_sliding_window(sample_rule, 'test_user') + assert result.allowed is True assert result.limit == 100 assert result.remaining == 50 - + def test_sliding_window_exceeds_limit(self, rate_limiter, sample_rule, mock_redis): """Test sliding window blocks requests exceeding limit""" - mock_redis.get.return_value = "100" # At limit - - result = rate_limiter._check_sliding_window(sample_rule, "test_user") - + mock_redis.get.return_value = '100' # At limit + + result = rate_limiter._check_sliding_window(sample_rule, 'test_user') + assert result.allowed is False assert result.remaining == 0 - + def test_token_bucket_allows_burst(self, rate_limiter, sample_rule, mock_redis): """Test token bucket allows burst requests""" - mock_redis.get.side_effect = ["100", "5"] # Normal at limit, burst available - - result = rate_limiter.check_token_bucket(sample_rule, "test_user") - + mock_redis.get.side_effect = ['100', '5'] # Normal at limit, burst available + + result = rate_limiter.check_token_bucket(sample_rule, 'test_user') + assert result.allowed is True - + def test_burst_tokens_exhausted(self, rate_limiter, sample_rule, mock_redis): """Test burst tokens can be exhausted""" - mock_redis.get.side_effect = ["100", "20"] # Normal at limit, burst exhausted - - result = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock()) - + mock_redis.get.side_effect = ['100', '20'] # Normal at limit, burst exhausted + + result = rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock()) + assert result.allowed is False - + def test_hybrid_check_normal_flow(self, rate_limiter, sample_rule, mock_redis): """Test hybrid check in normal flow""" - mock_redis.get.return_value = "50" - - result = rate_limiter.check_hybrid(sample_rule, "test_user") - + mock_redis.get.return_value = '50' + + result = rate_limiter.check_hybrid(sample_rule, 'test_user') + assert result.allowed is True - + def test_hybrid_check_uses_burst(self, rate_limiter, sample_rule, mock_redis): """Test hybrid check falls back to burst tokens""" # First call: normal limit reached # Second call: burst available - mock_redis.get.side_effect = ["100", "5"] - - result = rate_limiter.check_hybrid(sample_rule, "test_user") - + mock_redis.get.side_effect = ['100', '5'] + + rate_limiter.check_hybrid(sample_rule, 'test_user') + # Should attempt to use burst tokens assert mock_redis.get.call_count >= 1 @@ -135,75 +134,52 @@ class TestRateLimiter: # UNIT TESTS - QUOTA TRACKER # ============================================================================ + class TestQuotaTracker: """Unit tests for QuotaTracker class""" - + def test_quota_within_limit(self, quota_tracker, mock_redis): """Test quota check within limit""" - mock_redis.get.return_value = "5000" - - result = quota_tracker.check_quota( - "test_user", - QuotaType.REQUESTS, - 10000, - "month" - ) - + mock_redis.get.return_value = '5000' + + result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month') + assert result.is_exhausted is False assert result.is_warning is False assert result.remaining == 5000 - + def test_quota_warning_threshold(self, quota_tracker, mock_redis): """Test quota warning at 80%""" - mock_redis.get.return_value = "8500" # 85% used - - result = quota_tracker.check_quota( - "test_user", - QuotaType.REQUESTS, - 10000, - "month" - ) - + mock_redis.get.return_value = '8500' # 85% used + + result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month') + assert result.is_warning is True assert result.is_critical is False - + def test_quota_critical_threshold(self, quota_tracker, mock_redis): """Test quota critical at 95%""" - mock_redis.get.return_value = "9600" # 96% used - - result = quota_tracker.check_quota( - "test_user", - QuotaType.REQUESTS, - 10000, - "month" - ) - + mock_redis.get.return_value = '9600' # 96% used + + result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month') + assert result.is_critical is True - + def test_quota_exhausted(self, quota_tracker, mock_redis): """Test quota exhausted at 100%""" - mock_redis.get.return_value = "10000" - - result = quota_tracker.check_quota( - "test_user", - QuotaType.REQUESTS, - 10000, - "month" - ) - + mock_redis.get.return_value = '10000' + + result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month') + assert result.is_exhausted is True assert result.remaining == 0 - + def test_quota_increment(self, quota_tracker, mock_redis): """Test quota increment""" mock_redis.incr.return_value = 101 - - new_usage = quota_tracker.increment_quota( - "test_user", - QuotaType.REQUESTS, - "month" - ) - + + new_usage = quota_tracker.increment_quota('test_user', QuotaType.REQUESTS, 'month') + assert new_usage == 101 mock_redis.incr.assert_called_once() @@ -212,60 +188,61 @@ class TestQuotaTracker: # UNIT TESTS - IP RATE LIMITER # ============================================================================ + class TestIPRateLimiter: """Unit tests for IPRateLimiter class""" - + def test_extract_ip_from_forwarded_for(self, ip_limiter): """Test IP extraction from X-Forwarded-For""" request = Mock() - request.headers.get.return_value = "192.168.1.1, 10.0.0.1" - + request.headers.get.return_value = '192.168.1.1, 10.0.0.1' + ip = ip_limiter.extract_client_ip(request) - - assert ip == "192.168.1.1" - + + assert ip == '192.168.1.1' + def test_extract_ip_from_real_ip(self, ip_limiter): """Test IP extraction from X-Real-IP""" request = Mock() - request.headers.get.side_effect = [None, "192.168.1.1"] - + request.headers.get.side_effect = [None, '192.168.1.1'] + ip = ip_limiter.extract_client_ip(request) - - assert ip == "192.168.1.1" - + + assert ip == '192.168.1.1' + def test_whitelist_bypasses_limit(self, ip_limiter, mock_redis): """Test whitelisted IP bypasses rate limit""" mock_redis.sismember.return_value = True - - result = ip_limiter.check_ip_rate_limit("192.168.1.1") - + + result = ip_limiter.check_ip_rate_limit('192.168.1.1') + assert result.allowed is True assert result.limit == 999999 - + def test_blacklist_blocks_request(self, ip_limiter, mock_redis): """Test blacklisted IP is blocked""" mock_redis.sismember.side_effect = [False, True] # Not whitelisted, is blacklisted - - result = ip_limiter.check_ip_rate_limit("10.0.0.1") - + + result = ip_limiter.check_ip_rate_limit('10.0.0.1') + assert result.allowed is False - + def test_reputation_reduces_limits(self, ip_limiter, mock_redis): """Test low reputation reduces rate limits""" mock_redis.sismember.return_value = False - mock_redis.get.side_effect = ["30", "0", "0"] # Low reputation, no requests yet - - result = ip_limiter.check_ip_rate_limit("10.0.0.1") - + mock_redis.get.side_effect = ['30', '0', '0'] # Low reputation, no requests yet + + result = ip_limiter.check_ip_rate_limit('10.0.0.1') + # Limits should be reduced due to low reputation assert result.allowed is True - + def test_reputation_score_update(self, ip_limiter, mock_redis): """Test reputation score update""" - mock_redis.get.return_value = "100" - - new_score = ip_limiter.update_reputation_score("192.168.1.1", -10) - + mock_redis.get.return_value = '100' + + new_score = ip_limiter.update_reputation_score('192.168.1.1', -10) + assert new_score == 90 mock_redis.setex.assert_called_once() @@ -274,70 +251,65 @@ class TestIPRateLimiter: # INTEGRATION TESTS # ============================================================================ + class TestRateLimitIntegration: """Integration tests for rate limiting system""" - + @pytest.mark.asyncio async def test_concurrent_requests(self, rate_limiter, sample_rule, mock_redis): """Test concurrent requests handling""" - mock_redis.get.return_value = "0" - + mock_redis.get.return_value = '0' + # Simulate 10 concurrent requests tasks = [] for i in range(10): task = asyncio.create_task( - asyncio.to_thread( - rate_limiter.check_hybrid, - sample_rule, - f"user_{i}" - ) + asyncio.to_thread(rate_limiter.check_hybrid, sample_rule, f'user_{i}') ) tasks.append(task) - + results = await asyncio.gather(*tasks) - + # All should be allowed (different users) assert all(r.allowed for r in results) - + def test_burst_handling_sequence(self, rate_limiter, sample_rule, mock_redis): """Test burst handling in sequence""" # Simulate reaching normal limit, then using burst mock_redis.get.side_effect = [ - "99", "100", "0", # Normal flow - "100", "5" # Burst flow + '99', + '100', + '0', # Normal flow + '100', + '5', # Burst flow ] - + # First request: within normal limit - result1 = rate_limiter.check_hybrid(sample_rule, "test_user") + result1 = rate_limiter.check_hybrid(sample_rule, 'test_user') assert result1.allowed is True - + # Second request: at normal limit, use burst - result2 = rate_limiter.check_hybrid(sample_rule, "test_user") + rate_limiter.check_hybrid(sample_rule, 'test_user') # Should attempt burst tokens - + def test_quota_and_rate_limit_interaction(self, quota_tracker, rate_limiter, mock_redis): """Test interaction between quota and rate limits""" - mock_redis.get.return_value = "50" - + mock_redis.get.return_value = '50' + # Check rate limit rate_result = rate_limiter._check_sliding_window( RateLimitRule( - rule_id="test", + rule_id='test', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, - limit=100 + limit=100, ), - "test_user" + 'test_user', ) - + # Check quota - quota_result = quota_tracker.check_quota( - "test_user", - QuotaType.REQUESTS, - 10000, - "month" - ) - + quota_result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month') + # Both should allow assert rate_result.allowed is True assert quota_result.is_exhausted is False @@ -347,45 +319,43 @@ class TestRateLimitIntegration: # LOAD TESTS # ============================================================================ + class TestRateLimitLoad: """Load tests for rate limiting system""" - + def test_high_volume_requests(self, rate_limiter, sample_rule, mock_redis): """Test handling high volume of requests""" - mock_redis.get.return_value = "0" - + mock_redis.get.return_value = '0' + start_time = time.time() - + # Simulate 1000 requests for i in range(1000): - result = rate_limiter.check_hybrid(sample_rule, f"user_{i % 100}") - + rate_limiter.check_hybrid(sample_rule, f'user_{i % 100}') + elapsed = time.time() - start_time - + # Should complete in reasonable time (< 1 second for 1000 requests) assert elapsed < 1.0 - + def test_distributed_scenario(self, rate_limiter, mock_redis): """Test distributed rate limiting scenario""" # Simulate multiple servers checking same user rules = [ RateLimitRule( - rule_id=f"rule_{i}", + rule_id=f'rule_{i}', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, - limit=100 + limit=100, ) for i in range(5) ] - - mock_redis.get.return_value = "50" - + + mock_redis.get.return_value = '50' + # All servers should get consistent results - results = [ - rate_limiter.check_hybrid(rule, "test_user") - for rule in rules - ] - + results = [rate_limiter.check_hybrid(rule, 'test_user') for rule in rules] + # All should have same limit assert all(r.limit == 100 for r in results) @@ -394,38 +364,39 @@ class TestRateLimitLoad: # BURST HANDLING TESTS # ============================================================================ + class TestBurstHandling: """Specific tests for burst handling""" - + def test_burst_allows_spike(self, rate_limiter, sample_rule, mock_redis): """Test burst allows temporary spike""" # Normal limit reached - mock_redis.get.side_effect = ["100", "0"] - - result = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock()) - + mock_redis.get.side_effect = ['100', '0'] + + result = rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock()) + assert result.allowed is True - + def test_burst_refills_over_time(self, rate_limiter, sample_rule, mock_redis): """Test burst tokens refill over time""" # Burst used, then check again after time window - mock_redis.get.side_effect = ["20", "0"] # Burst exhausted, then refilled - + mock_redis.get.side_effect = ['20', '0'] # Burst exhausted, then refilled + # First check: exhausted - result1 = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock()) - + rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock()) + # Simulate time passing (would reset in Redis) - mock_redis.get.side_effect = ["0"] - + mock_redis.get.side_effect = ['0'] + # Second check: refilled - result2 = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock()) - + rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock()) + def test_burst_tracking(self, rate_limiter, sample_rule, mock_redis): """Test burst usage is tracked separately""" - mock_redis.get.return_value = "5" - - result = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock()) - + mock_redis.get.return_value = '5' + + rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock()) + # Should track burst usage mock_redis.incr.assert_called() @@ -434,39 +405,33 @@ class TestBurstHandling: # PERFORMANCE TESTS # ============================================================================ + class TestPerformance: """Performance optimization tests""" - + def test_redis_pipeline_usage(self, rate_limiter, mock_redis): """Test Redis pipeline is used for batch operations""" rule = RateLimitRule( - rule_id="test", - rule_type=RuleType.PER_USER, - time_window=TimeWindow.MINUTE, - limit=100 + rule_id='test', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, limit=100 ) - - mock_redis.get.return_value = "50" - - rate_limiter.check_hybrid(rule, "test_user") - + + mock_redis.get.return_value = '50' + + rate_limiter.check_hybrid(rule, 'test_user') + # Pipeline should be used for atomic operations # (Implementation dependent) - + def test_counter_increment_efficiency(self, quota_tracker, mock_redis): """Test counter increment is efficient""" start_time = time.time() - + # Increment 100 times for i in range(100): - quota_tracker.increment_quota( - f"user_{i}", - QuotaType.REQUESTS, - "month" - ) - + quota_tracker.increment_quota(f'user_{i}', QuotaType.REQUESTS, 'month') + elapsed = time.time() - start_time - + # Should be fast (< 0.1 seconds) assert elapsed < 0.1 @@ -475,33 +440,34 @@ class TestPerformance: # ERROR HANDLING TESTS # ============================================================================ + class TestErrorHandling: """Test error handling and graceful degradation""" - + def test_redis_connection_failure(self, rate_limiter, sample_rule): """Test graceful handling of Redis connection failure""" redis = Mock() - redis.get.side_effect = Exception("Connection failed") - + redis.get.side_effect = Exception('Connection failed') + limiter = RateLimiter(redis_client=redis) - + # Should not crash, should allow by default (fail open) - result = limiter.check_hybrid(sample_rule, "test_user") - + result = limiter.check_hybrid(sample_rule, 'test_user') + # Graceful degradation: allow request assert result is not None - + def test_invalid_rule_parameters(self, rate_limiter): """Test handling of invalid rule parameters""" invalid_rule = RateLimitRule( - rule_id="invalid", + rule_id='invalid', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, - limit=0 # Invalid limit + limit=0, # Invalid limit ) - + # Should handle gracefully - result = rate_limiter.check_hybrid(invalid_rule, "test_user") + result = rate_limiter.check_hybrid(invalid_rule, 'test_user') assert result is not None @@ -509,5 +475,5 @@ class TestErrorHandling: # RUN TESTS # ============================================================================ -if __name__ == "__main__": - pytest.main([__file__, "-v", "--tb=short"]) +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/backend-services/tests/test_real_compression_benchmarks.py b/backend-services/tests/test_real_compression_benchmarks.py index 3b2d151..41cc0eb 100644 --- a/backend-services/tests/test_real_compression_benchmarks.py +++ b/backend-services/tests/test_real_compression_benchmarks.py @@ -8,13 +8,14 @@ This test creates realistic API payloads and measures: 4. Realistic throughput estimates """ -import pytest +import asyncio import gzip -import json import io +import json import os import time -import asyncio + +import pytest @pytest.mark.asyncio @@ -23,7 +24,7 @@ async def test_realistic_rest_api_response_compression(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r = await client.post('/platform/authorization', json=login_payload) @@ -40,7 +41,7 @@ async def test_realistic_rest_api_response_compression(client): 'api_servers': [ 'https://api-primary.example.com', 'https://api-secondary.example.com', - 'https://api-backup.example.com' + 'https://api-backup.example.com', ], 'api_type': 'REST', 'api_allowed_retry_count': 3, @@ -49,7 +50,7 @@ async def test_realistic_rest_api_response_compression(client): 'api_credits_enabled': False, 'api_rate_limit_enabled': True, 'api_rate_limit_requests': 1000, - 'api_rate_limit_window': 60 + 'api_rate_limit_window': 60, } r = await client.post('/platform/api', json=api_payload) @@ -74,24 +75,26 @@ async def test_realistic_rest_api_response_compression(client): results[level] = { 'compressed_size': compressed_size, 'ratio': ratio, - 'time_ms': compression_time + 'time_ms': compression_time, } - print(f"\n{'='*70}") - print(f"REALISTIC REST API RESPONSE COMPRESSION BENCHMARK") - print(f"{'='*70}") - print(f"Uncompressed size: {uncompressed_size:,} bytes") - print(f"\nCompression Results:") + print(f'\n{"=" * 70}') + print('REALISTIC REST API RESPONSE COMPRESSION BENCHMARK') + print(f'{"=" * 70}') + print(f'Uncompressed size: {uncompressed_size:,} bytes') + print('\nCompression Results:') for level, result in results.items(): - print(f" Level {level}: {result['compressed_size']:,} bytes " - f"({result['ratio']:.1f}% reduction) " - f"in {result['time_ms']:.3f}ms") + print( + f' Level {level}: {result["compressed_size"]:,} bytes ' + f'({result["ratio"]:.1f}% reduction) ' + f'in {result["time_ms"]:.3f}ms' + ) # Cleanup await client.delete(f'/platform/api/{api_name}/v1') # Assertions - assert results[6]['ratio'] > 0, "Should achieve some compression" + assert results[6]['ratio'] > 0, 'Should achieve some compression' @pytest.mark.asyncio @@ -100,7 +103,7 @@ async def test_typical_api_gateway_request_flow(client): # Login login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } # Measure login response @@ -127,20 +130,22 @@ async def test_typical_api_gateway_request_flow(client): gz.write(health_json.encode('utf-8')) health_compressed = len(compressed_buffer.getvalue()) - print(f"\n{'='*70}") - print(f"TYPICAL API GATEWAY REQUEST FLOW") - print(f"{'='*70}") - print(f"\nLogin (POST /platform/authorization):") - print(f" Uncompressed: {login_uncompressed:,} bytes") - print(f" Compressed: {login_compressed:,} bytes") - print(f" Ratio: {(1 - login_compressed/login_uncompressed)*100:.1f}% reduction") - print(f"\nHealth Check (GET /api/health):") - print(f" Uncompressed: {health_uncompressed:,} bytes") - print(f" Compressed: {health_compressed:,} bytes") + print(f'\n{"=" * 70}') + print('TYPICAL API GATEWAY REQUEST FLOW') + print(f'{"=" * 70}') + print('\nLogin (POST /platform/authorization):') + print(f' Uncompressed: {login_uncompressed:,} bytes') + print(f' Compressed: {login_compressed:,} bytes') + print(f' Ratio: {(1 - login_compressed / login_uncompressed) * 100:.1f}% reduction') + print('\nHealth Check (GET /api/health):') + print(f' Uncompressed: {health_uncompressed:,} bytes') + print(f' Compressed: {health_compressed:,} bytes') if health_uncompressed > 500: - print(f" Ratio: {(1 - health_compressed/health_uncompressed)*100:.1f}% reduction") + print( + f' Ratio: {(1 - health_compressed / health_uncompressed) * 100:.1f}% reduction' + ) else: - print(f" Note: Below 500 byte minimum - not compressed in production") + print(' Note: Below 500 byte minimum - not compressed in production') @pytest.mark.asyncio @@ -148,7 +153,7 @@ async def test_large_list_response_compression(client): """Test compression on large list responses (multiple APIs)""" login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r = await client.post('/platform/authorization', json=login_payload) @@ -166,11 +171,11 @@ async def test_large_list_response_compression(client): 'api_allowed_groups': ['ALL', 'TEAM_A', 'TEAM_B'], 'api_servers': [ f'https://api-{i}-primary.example.com', - f'https://api-{i}-secondary.example.com' + f'https://api-{i}-secondary.example.com', ], 'api_type': 'REST', 'api_allowed_retry_count': 3, - 'active': True + 'active': True, } r = await client.post('/platform/api', json=api_payload) if r.status_code in (200, 201): @@ -193,10 +198,10 @@ async def test_large_list_response_compression(client): 'https://api-prod-1.example.com', 'https://api-prod-2.example.com', 'https://api-prod-3.example.com', - 'https://api-dr.example.com' + 'https://api-dr.example.com', ], 'api_type': 'REST', - 'active': True + 'active': True, } r = await client.post('/platform/api', json=api_payload) @@ -209,14 +214,14 @@ async def test_large_list_response_compression(client): gz.write(json_str.encode('utf-8')) compressed = len(compressed_buffer.getvalue()) - ratio = (1 - compressed/uncompressed) * 100 + ratio = (1 - compressed / uncompressed) * 100 - print(f"\n{'='*70}") - print(f"LARGE API CONFIGURATION RESPONSE") - print(f"{'='*70}") - print(f"Uncompressed: {uncompressed:,} bytes") - print(f"Compressed: {compressed:,} bytes") - print(f"Ratio: {ratio:.1f}% reduction") + print(f'\n{"=" * 70}') + print('LARGE API CONFIGURATION RESPONSE') + print(f'{"=" * 70}') + print(f'Uncompressed: {uncompressed:,} bytes') + print(f'Compressed: {compressed:,} bytes') + print(f'Ratio: {ratio:.1f}% reduction') api_names.append(api_payload['api_name']) @@ -236,14 +241,11 @@ async def test_worst_case_already_compressed_data(client): # Simulate a JWT token (random-looking base64) import base64 + random_bytes = os.urandom(256) token_like = base64.b64encode(random_bytes).decode('utf-8') - response_data = { - 'status': 'success', - 'token': token_like, - 'expires_in': 3600 - } + response_data = {'status': 'success', 'token': token_like, 'expires_in': 3600} json_str = json.dumps(response_data, separators=(',', ':')) uncompressed = len(json_str.encode('utf-8')) @@ -253,15 +255,15 @@ async def test_worst_case_already_compressed_data(client): gz.write(json_str.encode('utf-8')) compressed = len(compressed_buffer.getvalue()) - ratio = (1 - compressed/uncompressed) * 100 + ratio = (1 - compressed / uncompressed) * 100 - print(f"\n{'='*70}") - print(f"WORST CASE: RANDOM DATA (JWT-like tokens)") - print(f"{'='*70}") - print(f"Uncompressed: {uncompressed:,} bytes") - print(f"Compressed: {compressed:,} bytes") - print(f"Ratio: {ratio:.1f}% reduction") - print(f"\nNote: Even random data achieves some compression due to JSON structure") + print(f'\n{"=" * 70}') + print('WORST CASE: RANDOM DATA (JWT-like tokens)') + print(f'{"=" * 70}') + print(f'Uncompressed: {uncompressed:,} bytes') + print(f'Compressed: {compressed:,} bytes') + print(f'Ratio: {ratio:.1f}% reduction') + print('\nNote: Even random data achieves some compression due to JSON structure') @pytest.mark.asyncio @@ -282,27 +284,22 @@ async def test_compression_cpu_overhead_estimate(client): 'created_at': '2025-01-15T12:00:00Z', 'updated_at': '2025-01-18T15:30:00Z', 'views': 1234, - 'likes': 567 - } + 'likes': 567, + }, } for i in range(50) # 50 products ], - 'pagination': { - 'page': 1, - 'per_page': 50, - 'total': 500, - 'total_pages': 10 - } + 'pagination': {'page': 1, 'per_page': 50, 'total': 500, 'total_pages': 10}, } json_str = json.dumps(large_response, separators=(',', ':')) data = json_str.encode('utf-8') - print(f"\n{'='*70}") - print(f"COMPRESSION CPU OVERHEAD BENCHMARK") - print(f"{'='*70}") - print(f"Test payload: {len(data):,} bytes") - print(f"\nCompression performance:") + print(f'\n{"=" * 70}') + print('COMPRESSION CPU OVERHEAD BENCHMARK') + print(f'{"=" * 70}') + print(f'Test payload: {len(data):,} bytes') + print('\nCompression performance:') for level in [1, 4, 6, 9]: # Warm-up @@ -324,7 +321,7 @@ async def test_compression_cpu_overhead_estimate(client): elapsed = time.perf_counter() - start avg_time_ms = (elapsed / iterations) * 1000 avg_size = total_compressed // iterations - ratio = (1 - avg_size/len(data)) * 100 + ratio = (1 - avg_size / len(data)) * 100 # Estimate RPS capacity (assuming 50ms total request time) # Compression adds overhead, reducing available CPU time @@ -332,11 +329,11 @@ async def test_compression_cpu_overhead_estimate(client): with_compression_time = base_request_time + avg_time_ms rps_impact = (avg_time_ms / with_compression_time) * 100 - print(f"\n Level {level}:") - print(f" Time: {avg_time_ms:.3f} ms/request") - print(f" Compressed size: {avg_size:,} bytes ({ratio:.1f}% reduction)") - print(f" CPU overhead: {rps_impact:.1f}% of total request time") - print(f" Throughput: ~{1000/avg_time_ms:.0f} compressions/sec (single core)") + print(f'\n Level {level}:') + print(f' Time: {avg_time_ms:.3f} ms/request') + print(f' Compressed size: {avg_size:,} bytes ({ratio:.1f}% reduction)') + print(f' CPU overhead: {rps_impact:.1f}% of total request time') + print(f' Throughput: ~{1000 / avg_time_ms:.0f} compressions/sec (single core)') pytestmark = [pytest.mark.benchmark] diff --git a/backend-services/tests/test_redis_token_revocation_ha.py b/backend-services/tests/test_redis_token_revocation_ha.py index 7be1bdd..4ad1465 100644 --- a/backend-services/tests/test_redis_token_revocation_ha.py +++ b/backend-services/tests/test_redis_token_revocation_ha.py @@ -7,9 +7,11 @@ Simulates multi-node scenario: - Token validation on "Node B" (different process) should fail """ -import pytest import os +import pytest + + @pytest.mark.asyncio async def test_redis_token_revocation_shared_across_processes(monkeypatch, authed_client): """Test that token revocation via Redis is visible across simulated nodes. @@ -22,6 +24,7 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS') from utils import auth_blacklist + auth_blacklist._redis_client = None auth_blacklist._redis_enabled = False auth_blacklist._init_redis_if_possible() @@ -31,7 +34,10 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe login_response = await authed_client.post( '/platform/authorization', - json={'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD')} + json={ + 'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), + 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD'), + }, ) assert login_response.status_code == 200 token_data = login_response.json() @@ -39,11 +45,8 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe assert access_token is not None from jose import jwt - payload = jwt.decode( - access_token, - os.environ.get('JWT_SECRET_KEY'), - algorithms=['HS256'] - ) + + payload = jwt.decode(access_token, os.environ.get('JWT_SECRET_KEY'), algorithms=['HS256']) jti = payload.get('jti') username = payload.get('sub') exp = payload.get('exp') @@ -52,6 +55,7 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe assert username is not None import time + ttl = max(1, int(exp - time.time())) if exp else 3600 auth_blacklist.add_revoked_jti(username, jti, ttl) @@ -65,12 +69,14 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe if auth_blacklist._redis_client: auth_blacklist._redis_client.delete(auth_blacklist._revoked_jti_key(username, jti)) + @pytest.mark.asyncio async def test_redis_revoke_all_for_user_shared_across_processes(monkeypatch): """Test that user-level revocation via Redis is visible across nodes.""" monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS') from utils import auth_blacklist + auth_blacklist._redis_client = None auth_blacklist._redis_enabled = False auth_blacklist._init_redis_if_possible() @@ -93,14 +99,16 @@ async def test_redis_revoke_all_for_user_shared_across_processes(monkeypatch): is_revoked_after_cleanup = auth_blacklist.is_user_revoked(test_username) assert is_revoked_after_cleanup is False + @pytest.mark.asyncio async def test_redis_token_revocation_ttl_expiry(monkeypatch): """Test that revoked tokens auto-expire in Redis based on TTL.""" monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS') - from utils import auth_blacklist import time + from utils import auth_blacklist + auth_blacklist._redis_client = None auth_blacklist._redis_enabled = False auth_blacklist._init_redis_if_possible() @@ -119,6 +127,7 @@ async def test_redis_token_revocation_ttl_expiry(monkeypatch): assert auth_blacklist.is_jti_revoked(test_username, test_jti) is False + @pytest.mark.asyncio async def test_memory_fallback_when_redis_unavailable(monkeypatch): """Test that system falls back to in-memory revocation when Redis is unavailable.""" @@ -138,4 +147,3 @@ async def test_memory_fallback_when_redis_unavailable(monkeypatch): auth_blacklist.add_revoked_jti(test_username, test_jti, ttl_seconds=60) assert auth_blacklist.is_jti_revoked(test_username, test_jti) is True - diff --git a/backend-services/tests/test_request_id_and_logging_redaction.py b/backend-services/tests/test_request_id_and_logging_redaction.py index 77daa2f..5b3554f 100644 --- a/backend-services/tests/test_request_id_and_logging_redaction.py +++ b/backend-services/tests/test_request_id_and_logging_redaction.py @@ -1,19 +1,23 @@ -import pytest import logging from io import StringIO +import pytest + + @pytest.mark.asyncio async def test_request_id_middleware_injects_header_when_missing(authed_client): r = await authed_client.get('/api/status') assert r.status_code == 200 assert r.headers.get('X-Request-ID') + @pytest.mark.asyncio async def test_request_id_middleware_preserves_existing_header(authed_client): r = await authed_client.get('/api/status', headers={'X-Request-ID': 'req-123'}) assert r.status_code == 200 assert r.headers.get('X-Request-ID') == 'req-123' + def _capture_logs(logger_name: str, message: str) -> str: logger = logging.getLogger(logger_name) stream = StringIO() @@ -26,18 +30,20 @@ def _capture_logs(logger_name: str, message: str) -> str: logger.removeHandler(handler) return stream.getvalue() + def test_logging_redacts_authorization_headers(): msg = 'Authorization: Bearer secret-token' out = _capture_logs('doorman.gateway', msg) assert 'Authorization: [REDACTED]' in out + def test_logging_redacts_access_refresh_tokens(): msg = 'access_token="abc123" refresh_token="def456"' out = _capture_logs('doorman.gateway', msg) assert 'access_token' in out and '[REDACTED]' in out + def test_logging_redacts_cookie_values(): msg = 'cookie: sessionid=abcdef; csrftoken=xyz' out = _capture_logs('doorman.gateway', msg) assert 'cookie: [REDACTED]' in out - diff --git a/backend-services/tests/test_request_id_propagation.py b/backend-services/tests/test_request_id_propagation.py index a575231..d3160ab 100644 --- a/backend-services/tests/test_request_id_propagation.py +++ b/backend-services/tests/test_request_id_propagation.py @@ -1,13 +1,16 @@ import httpx import pytest + @pytest.mark.asyncio async def test_request_id_propagates_to_upstream_and_response(monkeypatch, authed_client): captured = {'xrid': None} def handler(req: httpx.Request) -> httpx.Response: captured['xrid'] = req.headers.get('X-Request-ID') - return httpx.Response(200, json={'ok': True}, headers={'X-Upstream-Request-ID': captured['xrid'] or ''}) + return httpx.Response( + 200, json={'ok': True}, headers={'X-Upstream-Request-ID': captured['xrid'] or ''} + ) transport = httpx.MockTransport(handler) mock_client = httpx.AsyncClient(transport=transport) @@ -17,7 +20,9 @@ async def test_request_id_propagates_to_upstream_and_response(monkeypatch, authe async def _get_client(): return mock_client - monkeypatch.setattr(gateway_service.GatewayService, 'get_http_client', classmethod(lambda cls: mock_client)) + monkeypatch.setattr( + gateway_service.GatewayService, 'get_http_client', classmethod(lambda cls: mock_client) + ) api_name, api_version = 'ridtest', 'v1' payload = { @@ -34,16 +39,22 @@ async def test_request_id_propagates_to_upstream_and_response(monkeypatch, authe r = await authed_client.post('/platform/api', json=payload) assert r.status_code in (200, 201), r.text - r2 = await authed_client.post('/platform/endpoint', json={ - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_method': 'GET', - 'endpoint_uri': '/echo', - 'endpoint_description': 'echo' - }) + r2 = await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_method': 'GET', + 'endpoint_uri': '/echo', + 'endpoint_description': 'echo', + }, + ) assert r2.status_code in (200, 201), r2.text - sub = await authed_client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': api_name, 'api_version': api_version}) + sub = await authed_client.post( + '/platform/subscription/subscribe', + json={'username': 'admin', 'api_name': api_name, 'api_version': api_version}, + ) assert sub.status_code in (200, 201), sub.text resp = await authed_client.get(f'/api/rest/{api_name}/{api_version}/echo') diff --git a/backend-services/tests/test_response_compression.py b/backend-services/tests/test_response_compression.py index d39662c..d6c3892 100644 --- a/backend-services/tests/test_response_compression.py +++ b/backend-services/tests/test_response_compression.py @@ -10,21 +10,17 @@ Verifies: - Compression level affects size and performance """ -import pytest -import gzip import json import os -from httpx import AsyncClient + +import pytest @pytest.mark.asyncio async def test_compression_enabled_for_json_response(client): """Verify JSON responses are compressed when client accepts gzip""" # Request with Accept-Encoding: gzip header - r = await client.get( - '/api/health', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'}) assert r.status_code == 200 # httpx automatically decompresses responses, but the middleware @@ -38,15 +34,12 @@ async def test_compression_reduces_response_size(client): # Get a response without compression r_uncompressed = await client.get( '/api/health', - headers={'Accept-Encoding': 'identity'} # No compression + headers={'Accept-Encoding': 'identity'}, # No compression ) assert r_uncompressed.status_code == 200 # Get same response with compression - r_compressed = await client.get( - '/api/health', - headers={'Accept-Encoding': 'gzip'} - ) + r_compressed = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'}) assert r_compressed.status_code == 200 # Both should have same decompressed content @@ -62,17 +55,14 @@ async def test_compression_with_large_json_list(client): # First authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) assert r_auth.status_code == 200 # Get list of APIs (potentially large response) - r = await client.get( - '/platform/api', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/platform/api', headers={'Accept-Encoding': 'gzip'}) assert r.status_code == 200 # Should have content-encoding header if response is large enough @@ -89,10 +79,7 @@ async def test_compression_with_large_json_list(client): async def test_small_response_not_compressed(client): """Verify small responses below minimum_size are not compressed""" # Health endpoint returns small response - r = await client.get( - '/api/health', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'}) assert r.status_code == 200 response_size = len(r.content) @@ -100,7 +87,7 @@ async def test_small_response_not_compressed(client): # If response is smaller than minimum_size (500 bytes default), # it should not be compressed if response_size < 500: - headers_lower = {k.lower(): v for k, v in r.headers.items()} + {k.lower(): v for k, v in r.headers.items()} # May or may not have content-encoding based on actual size # This is expected behavior - small responses aren't worth compressing @@ -111,7 +98,7 @@ async def test_compression_with_different_content_types(client): # Authenticate first login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -123,10 +110,7 @@ async def test_compression_with_different_content_types(client): ] for endpoint, expected_content_type in endpoints: - r = await client.get( - endpoint, - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get(endpoint, headers={'Accept-Encoding': 'gzip'}) # Should succeed assert r.status_code == 200 @@ -139,10 +123,7 @@ async def test_compression_with_different_content_types(client): @pytest.mark.asyncio async def test_no_compression_when_not_requested(client): """Verify compression is not applied when client doesn't accept it""" - r = await client.get( - '/api/health', - headers={'Accept-Encoding': 'identity'} - ) + r = await client.get('/api/health', headers={'Accept-Encoding': 'identity'}) assert r.status_code == 200 # Should not have gzip encoding @@ -157,23 +138,17 @@ async def test_compression_preserves_response_body(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) assert r_auth.status_code == 200 # Get response without compression - r1 = await client.get( - '/platform/user', - headers={'Accept-Encoding': 'identity'} - ) + r1 = await client.get('/platform/user', headers={'Accept-Encoding': 'identity'}) # Get response with compression - r2 = await client.get( - '/platform/user', - headers={'Accept-Encoding': 'gzip'} - ) + r2 = await client.get('/platform/user', headers={'Accept-Encoding': 'gzip'}) # Both should succeed assert r1.status_code == 200 @@ -189,13 +164,11 @@ async def test_compression_with_post_request(client): # Login with compression login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r = await client.post( - '/platform/authorization', - json=login_payload, - headers={'Accept-Encoding': 'gzip'} + '/platform/authorization', json=login_payload, headers={'Accept-Encoding': 'gzip'} ) assert r.status_code == 200 @@ -216,10 +189,7 @@ async def test_compression_works_with_errors(client): except Exception: pass - r = await client.get( - '/platform/user', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/platform/user', headers={'Accept-Encoding': 'gzip'}) # Should be unauthorized assert r.status_code in (401, 403) @@ -236,10 +206,7 @@ async def test_compression_works_with_errors(client): @pytest.mark.asyncio async def test_compression_with_cache_headers(client): """Verify compression doesn't interfere with cache headers""" - r = await client.get( - '/api/health', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'}) assert r.status_code == 200 # Response should still have normal headers @@ -265,7 +232,7 @@ async def test_compression_with_large_payload(client): # Authenticate login_payload = { 'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'), - 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars') + 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'), } r_auth = await client.post('/platform/authorization', json=login_payload) @@ -275,10 +242,7 @@ async def test_compression_with_large_payload(client): # (This is mainly to test that compression handles large payloads) # Get list of APIs (can be large) - r = await client.get( - '/platform/api', - headers={'Accept-Encoding': 'gzip'} - ) + r = await client.get('/platform/api', headers={'Accept-Encoding': 'gzip'}) # Should succeed regardless of size assert r.status_code == 200 diff --git a/backend-services/tests/test_response_compression_edges.py b/backend-services/tests/test_response_compression_edges.py new file mode 100644 index 0000000..357e0a3 --- /dev/null +++ b/backend-services/tests/test_response_compression_edges.py @@ -0,0 +1,10 @@ +import pytest + + +@pytest.mark.asyncio +async def test_large_response_is_compressed(authed_client): + # Export all config typically returns a payload above compression threshold. + r = await authed_client.get('/platform/config/export/all') + assert r.status_code == 200 + ce = (r.headers.get('content-encoding') or '').lower() + assert ce == 'gzip' diff --git a/backend-services/tests/test_response_envelope_and_headers.py b/backend-services/tests/test_response_envelope_and_headers.py index 5557b0c..55e69e2 100644 --- a/backend-services/tests/test_response_envelope_and_headers.py +++ b/backend-services/tests/test_response_envelope_and_headers.py @@ -1,28 +1,37 @@ import pytest + @pytest.mark.asyncio async def test_rest_loose_envelope_returns_raw_message(monkeypatch, authed_client): name, ver = 'envrest', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/e', - 'endpoint_description': 'e' - }) - import services.gateway_service as gs + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/e', + 'endpoint_description': 'e', + }, + ) from tests.test_gateway_routing_limits import _FakeAsyncClient + + import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False) r = await authed_client.get(f'/api/rest/{name}/{ver}/e') @@ -30,30 +39,39 @@ async def test_rest_loose_envelope_returns_raw_message(monkeypatch, authed_clien body = r.json() assert 'status_code' not in body and body.get('method') == 'GET' + @pytest.mark.asyncio async def test_rest_strict_envelope_wraps_message(monkeypatch, authed_client): monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') name, ver = 'envrest2', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/e', - 'endpoint_description': 'e' - }) - import services.gateway_service as gs + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/e', + 'endpoint_description': 'e', + }, + ) from tests.test_gateway_routing_limits import _FakeAsyncClient + + import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/e') assert r.status_code == 200 @@ -62,6 +80,7 @@ async def test_rest_strict_envelope_wraps_message(monkeypatch, authed_client): assert isinstance(body.get('response'), dict) assert body['response'].get('method') == 'GET' + async def _setup_graphql(client, name='envgql', ver='v1'): payload = { 'api_name': name, @@ -74,160 +93,228 @@ async def _setup_graphql(client, name='envgql', ver='v1'): 'api_allowed_retry_count': 0, } await client.post('/platform/api', json=payload) - await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/graphql', - 'endpoint_description': 'gql', - }) + await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'gql', + }, + ) from conftest import subscribe_self + await subscribe_self(client, name, ver) return name, ver + @pytest.mark.asyncio async def test_graphql_strict_and_loose_envelopes(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'envgql', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://gql.up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/graphql', - 'endpoint_description': 'gql' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://gql.up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/graphql', + 'endpoint_description': 'gql', + }, + ) class FakeHTTPResp: def __init__(self, payload): self._p = payload + def json(self): return self._p + class FakeHTTPX: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def post(self, url, json=None, headers=None): return FakeHTTPResp({'data': {'pong': True}}) + # Force HTTPX path class Dummy: pass + monkeypatch.setattr(gs, 'Client', Dummy) monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX) monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False) - r1 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}}) + r1 = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ q }', 'variables': {}}, + ) assert r1.status_code == 200 assert 'status_code' not in r1.json() monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') - r2 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}}) + r2 = await authed_client.post( + f'/api/graphql/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'query': '{ q }', 'variables': {}}, + ) assert r2.status_code == 200 assert r2.json().get('status_code') == 200 and isinstance(r2.json().get('response'), dict) + @pytest.mark.asyncio async def test_grpc_strict_and_loose_envelopes(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'envgrpc', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': 'g', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['grpc://127.0.0.1:9'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/grpc', - 'endpoint_description': 'grpc' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'g', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['grpc://127.0.0.1:9'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/grpc', + 'endpoint_description': 'grpc', + }, + ) from conftest import subscribe_self + await subscribe_self(authed_client, name, ver) - async def fake_grpc_gateway(username, request, request_id, start_time, path, api_name=None, url=None, retry=0): + async def fake_grpc_gateway( + username, request, request_id, start_time, path, api_name=None, url=None, retry=0 + ): from models.response_model import ResponseModel - return ResponseModel(status_code=200, response_headers={'request_id': request_id}, response={'ok': True}).dict() + + return ResponseModel( + status_code=200, response_headers={'request_id': request_id}, response={'ok': True} + ).dict() + monkeypatch.setattr(gs.GatewayService, 'grpc_gateway', staticmethod(fake_grpc_gateway)) monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False) - r1 = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}) + r1 = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, + ) assert r1.status_code == 200 assert r1.json() == {'ok': True} monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') - r2 = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}) + r2 = await authed_client.post( + f'/api/grpc/{name}', + headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, + json={'method': 'Svc.M', 'message': {}}, + ) assert r2.status_code == 200 assert r2.json().get('status_code') == 200 and r2.json().get('response', {}).get('ok') is True + @pytest.mark.asyncio async def test_header_normalization_sets_x_request_id_from_request_id(monkeypatch, authed_client): name, ver = 'hdrnorm', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/h', - 'endpoint_description': 'h' - }) - import services.gateway_service as gs + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/h', + 'endpoint_description': 'h', + }, + ) from tests.test_gateway_routing_limits import _FakeAsyncClient + + import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/h') assert r.status_code == 200 assert r.headers.get('X-Request-ID') + @pytest.mark.asyncio async def test_header_normalization_preserves_existing_x_request_id(monkeypatch, authed_client): name, ver = 'hdrnorm2', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - 'api_public': True, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/h', - 'endpoint_description': 'h' - }) - import services.gateway_service as gs + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + 'api_public': True, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/h', + 'endpoint_description': 'h', + }, + ) from tests.test_gateway_routing_limits import _FakeAsyncClient + + import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.get(f'/api/rest/{name}/{ver}/h', headers={'X-Request-ID': 'my-req-id'}) assert r.status_code == 200 diff --git a/backend-services/tests/test_rest_authorization_field_swap.py b/backend-services/tests/test_rest_authorization_field_swap.py index 10637e6..d3e19a9 100644 --- a/backend-services/tests/test_rest_authorization_field_swap.py +++ b/backend-services/tests/test_rest_authorization_field_swap.py @@ -1,5 +1,6 @@ import pytest + class _Resp: def __init__(self, status_code=200, json_body=None, headers=None): self.status_code = status_code @@ -9,17 +10,22 @@ class _Resp: base_headers.update(headers) self.headers = base_headers self.text = '' + def json(self): return self._json_body + def _mk_client_capture(seen): class _Client: def __init__(self, timeout=None, limits=None, http2=False): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -37,24 +43,48 @@ def _mk_client_capture(seen): return await self.put(url, **kwargs) else: return _Resp(405) + async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - payload = {'method': 'POST', 'url': url, 'params': dict(params or {}), 'body': json, 'headers': headers or {}} - seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'json': json}) + payload = { + 'method': 'POST', + 'url': url, + 'params': dict(params or {}), + 'body': json, + 'headers': headers or {}, + } + seen.append( + { + 'url': url, + 'params': dict(params or {}), + 'headers': dict(headers or {}), + 'json': json, + } + ) return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'}) + async def get(self, url, params=None, headers=None, **kwargs): - payload = {'method': 'GET', 'url': url, 'params': dict(params or {}), 'headers': headers or {}} + payload = { + 'method': 'GET', + 'url': url, + 'params': dict(params or {}), + 'headers': headers or {}, + } seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {})}) return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'}) + async def put(self, url, **kwargs): payload = {'method': 'PUT', 'url': url, 'params': {}, 'headers': {}} seen.append({'url': url, 'params': {}, 'headers': {}}) return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'}) + async def delete(self, url, **kwargs): payload = {'method': 'DELETE', 'url': url, 'params': {}, 'headers': {}} seen.append({'url': url, 'params': {}, 'headers': {}}) return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'}) + return _Client + async def _setup_api(client, name, ver, swap_header, allowed_headers=None): payload = { 'api_name': name, @@ -71,28 +101,35 @@ async def _setup_api(client, name, ver, swap_header, allowed_headers=None): payload['api_allowed_headers'] = allowed_headers r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/p', - 'endpoint_description': 'p' - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/p', + 'endpoint_description': 'p', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio async def test_auth_swap_injects_authorization_from_custom_header(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'authswap1', 'v1' swap_from = 'x-token' - await _setup_api(authed_client, name, ver, swap_from, allowed_headers=['authorization', swap_from]) + await _setup_api( + authed_client, name, ver, swap_from, allowed_headers=['authorization', swap_from] + ) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_client_capture(seen)) r = await authed_client.get( - f'/api/rest/{name}/{ver}/p', - headers={swap_from: 'Bearer backend-token'} + f'/api/rest/{name}/{ver}/p', headers={swap_from: 'Bearer backend-token'} ) assert r.status_code == 200 assert len(seen) == 1 @@ -100,33 +137,41 @@ async def test_auth_swap_injects_authorization_from_custom_header(monkeypatch, a auth_val = forwarded.get('Authorization') or forwarded.get('authorization') assert auth_val == 'Bearer backend-token' + @pytest.mark.asyncio async def test_auth_swap_missing_source_header_no_crash(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'authswap2', 'v1' swap_from = 'X-Backend-Auth' - await _setup_api(authed_client, name, ver, swap_from, allowed_headers=['Content-Type', swap_from]) + await _setup_api( + authed_client, name, ver, swap_from, allowed_headers=['Content-Type', swap_from] + ) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_client_capture(seen)) - r = await authed_client.get( - f'/api/rest/{name}/{ver}/p', - headers={} - ) + r = await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={}) assert r.status_code == 200 fwd = (r.json() or {}).get('headers') or {} assert not (('Authorization' in fwd) or ('authorization' in fwd)) + @pytest.mark.asyncio async def test_auth_swap_with_empty_value_does_not_override(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'authswap3', 'v1' swap_from = 'X-Backend-Auth' - await _setup_api(authed_client, name, ver, swap_from, allowed_headers=['Content-Type', swap_from, 'Authorization']) + await _setup_api( + authed_client, + name, + ver, + swap_from, + allowed_headers=['Content-Type', swap_from, 'Authorization'], + ) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_client_capture(seen)) r = await authed_client.get( - f'/api/rest/{name}/{ver}/p', - headers={swap_from: '', 'Authorization': 'Bearer existing'} + f'/api/rest/{name}/{ver}/p', headers={swap_from: '', 'Authorization': 'Bearer existing'} ) assert r.status_code == 200 fwd = (r.json() or {}).get('headers') or {} diff --git a/backend-services/tests/test_rest_gateway_retries.py b/backend-services/tests/test_rest_gateway_retries.py index a75fb12..07dbcf1 100644 --- a/backend-services/tests/test_rest_gateway_retries.py +++ b/backend-services/tests/test_rest_gateway_retries.py @@ -1,5 +1,6 @@ import pytest + class _Resp: def __init__(self, status_code=200, body=b'{"ok":true}', headers=None): self.status_code = status_code @@ -12,8 +13,10 @@ class _Resp: def json(self): import json + return json.loads(self.text) + def _mk_retry_client(sequence, seen): """Factory for a fake AsyncClient that returns statuses from `sequence`. Records each call's (url, headers, params) into `seen` list. @@ -50,7 +53,14 @@ def _mk_retry_client(sequence, seen): return _Resp(405) async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'json': json}) + seen.append( + { + 'url': url, + 'params': dict(params or {}), + 'headers': dict(headers or {}), + 'json': json, + } + ) idx = min(counter['i'], len(sequence) - 1) code = sequence[idx] counter['i'] = counter['i'] + 1 @@ -79,6 +89,7 @@ def _mk_retry_client(sequence, seen): return _Client + async def _setup_api(client, name, ver, retry_count=0, allowed_headers=None): payload = { 'api_name': name, @@ -94,20 +105,26 @@ async def _setup_api(client, name, ver, retry_count=0, allowed_headers=None): payload['api_allowed_headers'] = allowed_headers r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/p', - 'endpoint_description': 'p' - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/p', + 'endpoint_description': 'p', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio async def test_rest_retry_on_500_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retry500', 'v1' await _setup_api(authed_client, name, ver, retry_count=2) seen = [] @@ -116,9 +133,11 @@ async def test_rest_retry_on_500_then_success(monkeypatch, authed_client): assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_rest_retry_on_502_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retry502', 'v1' await _setup_api(authed_client, name, ver, retry_count=2) seen = [] @@ -127,9 +146,11 @@ async def test_rest_retry_on_502_then_success(monkeypatch, authed_client): assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_rest_retry_on_503_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retry503', 'v1' await _setup_api(authed_client, name, ver, retry_count=2) seen = [] @@ -138,9 +159,11 @@ async def test_rest_retry_on_503_then_success(monkeypatch, authed_client): assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_rest_retry_on_504_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retry504', 'v1' await _setup_api(authed_client, name, ver, retry_count=2) seen = [] @@ -149,9 +172,11 @@ async def test_rest_retry_on_504_then_success(monkeypatch, authed_client): assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retry0', 'v1' await _setup_api(authed_client, name, ver, retry_count=0) seen = [] @@ -160,9 +185,11 @@ async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client): assert r.status_code == 500 assert len(seen) == 1 + @pytest.mark.asyncio async def test_rest_retry_stops_after_limit(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retryLimit', 'v1' await _setup_api(authed_client, name, ver, retry_count=1) seen = [] @@ -171,9 +198,11 @@ async def test_rest_retry_stops_after_limit(monkeypatch, authed_client): assert r.status_code == 500 assert len(seen) == 2 + @pytest.mark.asyncio async def test_rest_retry_preserves_headers_and_params(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'retryHdr', 'v1' await _setup_api(authed_client, name, ver, retry_count=1, allowed_headers=['X-Custom']) seen = [] @@ -181,11 +210,13 @@ async def test_rest_retry_preserves_headers_and_params(monkeypatch, authed_clien r = await authed_client.post( f'/api/rest/{name}/{ver}/p?foo=bar', headers={'X-Custom': 'abc', 'Content-Type': 'application/json'}, - json={'a': 1} + json={'a': 1}, ) assert r.status_code == 200 assert len(seen) == 2 assert all(call['params'].get('foo') == 'bar' for call in seen) + def _hdr(call): return call['headers'].get('X-Custom') or call['headers'].get('x-custom') + assert all(_hdr(call) == 'abc' for call in seen) diff --git a/backend-services/tests/test_rest_header_and_response_parsing.py b/backend-services/tests/test_rest_header_and_response_parsing.py index 18cf795..a1945a2 100644 --- a/backend-services/tests/test_rest_header_and_response_parsing.py +++ b/backend-services/tests/test_rest_header_and_response_parsing.py @@ -1,5 +1,6 @@ import pytest + class _Resp: def __init__(self, status_code=200, body=b'{"ok":true}', headers=None): self.status_code = status_code @@ -22,16 +23,21 @@ class _Resp: def json(self): import json + return json.loads(self.text) + def _mk_client_capture(seen, resp_status=200, resp_headers=None, resp_body=b'{"ok":true}'): class _Client: def __init__(self, timeout=None, limits=None, http2=False): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -49,20 +55,33 @@ def _mk_client_capture(seen, resp_status=200, resp_headers=None, resp_body=b'{"o return await self.put(url, **kwargs) else: return _Resp(405) + async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'json': json}) + seen.append( + { + 'url': url, + 'params': dict(params or {}), + 'headers': dict(headers or {}), + 'json': json, + } + ) return _Resp(resp_status, body=resp_body, headers=resp_headers) + async def get(self, url, **kwargs): seen.append({'url': url, 'params': {}, 'headers': {}}) return _Resp(resp_status, body=resp_body, headers=resp_headers) + async def put(self, url, **kwargs): seen.append({'url': url, 'params': {}, 'headers': {}}) return _Resp(resp_status, body=resp_body, headers=resp_headers) + async def delete(self, url, **kwargs): seen.append({'url': url, 'params': {}, 'headers': {}}) return _Resp(resp_status, body=resp_body, headers=resp_headers) + return _Client + async def _setup_api(client, name, ver, allowed_headers=None): payload = { 'api_name': name, @@ -78,20 +97,28 @@ async def _setup_api(client, name, ver, allowed_headers=None): payload['api_allowed_headers'] = allowed_headers r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/p', - 'endpoint_description': 'p' - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/p', + 'endpoint_description': 'p', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio -async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(monkeypatch, authed_client): +async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive( + monkeypatch, authed_client +): import services.gateway_service as gs + name, ver = 'hdrallow', 'v1' await _setup_api(authed_client, name, ver, allowed_headers=['X-Custom', 'Content-Type']) seen = [] @@ -99,7 +126,7 @@ async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(m r = await authed_client.post( f'/api/rest/{name}/{ver}/p?foo=bar', headers={'x-custom': 'abc', 'X-Blocked': 'nope', 'Content-Type': 'application/json'}, - json={'a': 1} + json={'a': 1}, ) assert r.status_code == 200 assert len(seen) == 1 @@ -108,9 +135,11 @@ async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(m assert keys_lower.get('x-custom') == 'abc' assert 'x-blocked' not in keys_lower + @pytest.mark.asyncio async def test_header_block_non_allowlisted_headers(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'hdrblock', 'v1' await _setup_api(authed_client, name, ver, allowed_headers=['Content-Type']) seen = [] @@ -118,59 +147,75 @@ async def test_header_block_non_allowlisted_headers(monkeypatch, authed_client): r = await authed_client.post( f'/api/rest/{name}/{ver}/p', headers={'X-NotAllowed': '123', 'Content-Type': 'application/json'}, - json={'a': 1} + json={'a': 1}, ) assert r.status_code == 200 forwarded = seen[0]['headers'] assert 'X-NotAllowed' not in forwarded and 'x-notallowed' not in {k.lower() for k in forwarded} + def test_response_parse_application_json(): import services.gateway_service as gs + body = b'{"x": 1}' resp = _Resp(headers={'Content-Type': 'application/json'}, body=body) out = gs.GatewayService.parse_response(resp) assert isinstance(out, dict) and out.get('x') == 1 + def test_response_parse_text_plain_fallback(): import services.gateway_service as gs + body = b'hello world' resp = _Resp(headers={'Content-Type': 'text/plain'}, body=body) out = gs.GatewayService.parse_response(resp) assert out == body + def test_response_parse_application_xml(): import services.gateway_service as gs + body = b'1' resp = _Resp(headers={'Content-Type': 'application/xml'}, body=body) out = gs.GatewayService.parse_response(resp) from xml.etree.ElementTree import Element + assert isinstance(out, Element) and out.tag == 'root' + def test_response_parse_malformed_json_as_text(): import services.gateway_service as gs + body = b'{"x": 1' resp = _Resp(headers={'Content-Type': 'text/plain'}, body=body) out = gs.GatewayService.parse_response(resp) assert out == body + def test_response_binary_passthrough_no_decode(): import services.gateway_service as gs - binary = b'\x00\xFF\x10\x80' + + binary = b'\x00\xff\x10\x80' resp = _Resp(headers={'Content-Type': 'application/octet-stream'}, body=binary) out = gs.GatewayService.parse_response(resp) assert out == binary + def test_response_malformed_json_with_application_json_raises(): import services.gateway_service as gs + body = b'{"x": 1' resp = _Resp(headers={'Content-Type': 'application/json'}, body=body) import pytest + with pytest.raises(Exception): gs.GatewayService.parse_response(resp) + @pytest.mark.asyncio async def test_rest_gateway_returns_500_on_malformed_json_upstream(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'jsonfail', 'v1' await _setup_api(authed_client, name, ver) @@ -182,14 +227,22 @@ async def test_rest_gateway_returns_500_on_malformed_json_upstream(monkeypatch, self.headers = {'Content-Type': 'application/json'} self.content = bad_body self.text = bad_body.decode('utf-8', errors='ignore') + def json(self): import json + return json.loads(self.text) class _Client2: - def __init__(self, *a, **k): pass - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -207,16 +260,28 @@ async def test_rest_gateway_returns_500_on_malformed_json_upstream(monkeypatch, return await self.put(url, **kwargs) else: return _Resp2() - async def get(self, url, params=None, headers=None, **kwargs): return _Resp2() - async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): return _Resp2() - async def head(self, url, params=None, headers=None, **kwargs): return _Resp2() - async def put(self, url, **kwargs): return _Resp2() - async def delete(self, url, **kwargs): return _Resp2() + + async def get(self, url, params=None, headers=None, **kwargs): + return _Resp2() + + async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): + return _Resp2() + + async def head(self, url, params=None, headers=None, **kwargs): + return _Resp2() + + async def put(self, url, **kwargs): + return _Resp2() + + async def delete(self, url, **kwargs): + return _Resp2() monkeypatch.setattr(gs.httpx, 'AsyncClient', _Client2) - r = await authed_client.post(f'/api/rest/{name}/{ver}/p', headers={'Content-Type': 'application/json'}, json={'k': 'v'}) + r = await authed_client.post( + f'/api/rest/{name}/{ver}/p', headers={'Content-Type': 'application/json'}, json={'k': 'v'} + ) assert r.status_code == 500 body = r.json() payload = body.get('response', body) - assert (payload.get('error_code') or payload.get('error_message')) + assert payload.get('error_code') or payload.get('error_message') diff --git a/backend-services/tests/test_rest_methods_and_405.py b/backend-services/tests/test_rest_methods_and_405.py index 99634d4..2c3092c 100644 --- a/backend-services/tests/test_rest_methods_and_405.py +++ b/backend-services/tests/test_rest_methods_and_405.py @@ -1,5 +1,6 @@ import pytest + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None, text_body=None, headers=None): self.status_code = status_code @@ -12,10 +13,12 @@ class _FakeHTTPResponse: def json(self): import json as _json + if self._json_body is None: return _json.loads(self.text or '{}') return self._json_body + class _FakeAsyncClient: def __init__(self, *args, **kwargs): pass @@ -45,63 +48,114 @@ class _FakeAsyncClient: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) async def get(self, url, params=None, headers=None, **kwargs): - return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={ + 'method': 'GET', + 'url': url, + 'params': dict(params or {}), + 'headers': headers or {}, + }, + headers={'X-Upstream': 'yes'}, + ) async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, + json_body={'method': 'POST', 'url': url, 'body': body, 'headers': headers or {}}, + headers={'X-Upstream': 'yes'}, + ) async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, + json_body={'method': 'PUT', 'url': url, 'body': body, 'headers': headers or {}}, + headers={'X-Upstream': 'yes'}, + ) async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs): - return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + return _FakeHTTPResponse( + 200, + json_body={'method': 'DELETE', 'url': url, 'headers': headers or {}}, + headers={'X-Upstream': 'yes'}, + ) async def patch(self, url, json=None, params=None, headers=None, content=None, **kwargs): - body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) - return _FakeHTTPResponse(200, json_body={'method': 'PATCH', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'}) + body = ( + json + if json is not None + else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content) + ) + return _FakeHTTPResponse( + 200, + json_body={'method': 'PATCH', 'url': url, 'body': body, 'headers': headers or {}}, + headers={'X-Upstream': 'yes'}, + ) async def head(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200, json_body=None, headers={'X-Upstream': 'yes'}) + async def _setup_api(client, name, ver, endpoint_method='GET', endpoint_uri='/p'): - r = await client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up.methods'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) + r = await client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.methods'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': endpoint_method, - 'endpoint_uri': endpoint_uri, - 'endpoint_description': endpoint_method.lower(), - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': endpoint_method, + 'endpoint_uri': endpoint_uri, + 'endpoint_description': endpoint_method.lower(), + }, + ) assert r2.status_code in (200, 201) rme = await client.get('/platform/user/me') - username = (rme.json().get('username') if rme.status_code == 200 else 'admin') - sr = await client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': name, 'api_version': ver}) + username = rme.json().get('username') if rme.status_code == 200 else 'admin' + sr = await client.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': name, 'api_version': ver}, + ) assert sr.status_code in (200, 201) + @pytest.mark.asyncio async def test_rest_head_supported_when_upstream_allows(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'headok', 'v1' await _setup_api(authed_client, name, ver, endpoint_method='GET', endpoint_uri='/p') monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.request('HEAD', f'/api/rest/{name}/{ver}/p') assert r.status_code == 200 + @pytest.mark.asyncio async def test_rest_patch_supported_when_registered(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'patchok', 'v1' await _setup_api(authed_client, name, ver, endpoint_method='PATCH', endpoint_uri='/edit') monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) @@ -110,24 +164,29 @@ async def test_rest_patch_supported_when_registered(monkeypatch, authed_client): j = r.json().get('response', r.json()) assert j.get('method') == 'PATCH' + @pytest.mark.asyncio async def test_rest_options_unregistered_endpoint_returns_405(monkeypatch, authed_client): monkeypatch.setenv('STRICT_OPTIONS_405', 'true') name, ver = 'optunreg', 'v1' - r = await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://up.methods'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) + r = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.methods'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) assert r.status_code in (200, 201) resp = await authed_client.options(f'/api/rest/{name}/{ver}/not-made') assert resp.status_code == 405 + @pytest.mark.asyncio async def test_rest_unsupported_method_returns_405(authed_client): name, ver = 'unsup', 'v1' diff --git a/backend-services/tests/test_rest_preflight_positive.py b/backend-services/tests/test_rest_preflight_positive.py new file mode 100644 index 0000000..065ce67 --- /dev/null +++ b/backend-services/tests/test_rest_preflight_positive.py @@ -0,0 +1,54 @@ +import pytest + + +@pytest.mark.asyncio +async def test_rest_preflight_positive_allows(authed_client): + name, ver = 'restpos', 'v1' + c = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'REST preflight positive', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'REST', + 'api_cors_allow_origins': ['http://ok.example'], + 'api_cors_allow_methods': ['GET'], + 'api_cors_allow_headers': ['Content-Type'], + 'api_allowed_retry_count': 0, + }, + ) + assert c.status_code in (200, 201) + ce = await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/p', + 'endpoint_description': 'p', + }, + ) + assert ce.status_code in (200, 201) + + r = await authed_client.options( + f'/api/rest/{name}/{ver}/p', + headers={ + 'Origin': 'http://ok.example', + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'Content-Type', + }, + ) + assert r.status_code == 204 + acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get( + 'access-control-allow-origin' + ) + assert acao == 'http://ok.example' + ach = ( + r.headers.get('Access-Control-Allow-Headers') + or r.headers.get('access-control-allow-headers') + or '' + ) + assert 'Content-Type' in ach diff --git a/backend-services/tests/test_role_admin_edge_cases.py b/backend-services/tests/test_role_admin_edge_cases.py new file mode 100644 index 0000000..c7855a8 --- /dev/null +++ b/backend-services/tests/test_role_admin_edge_cases.py @@ -0,0 +1,45 @@ +import time + +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200 + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_non_admin_cannot_create_admin_role(authed_client): + # Create a user with manage_roles but not admin + rname = f'mgr_roles_{int(time.time())}' + cr = await authed_client.post('/platform/role', json={'role_name': rname, 'manage_roles': True}) + assert cr.status_code in (200, 201) + uname = f'role_mgr_{int(time.time())}' + pwd = 'RoleMgrStrongPass1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': rname, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + mgr = await _login(f'{uname}@example.com', pwd) + # Attempt to create 'admin' role (should be forbidden for non-admin) + admin_create = await mgr.post( + '/platform/role', json={'role_name': 'admin', 'manage_roles': True} + ) + assert admin_create.status_code == 403 diff --git a/backend-services/tests/test_role_update.py b/backend-services/tests/test_role_update.py index 982c23f..aa7adb1 100644 --- a/backend-services/tests/test_role_update.py +++ b/backend-services/tests/test_role_update.py @@ -2,35 +2,35 @@ """Quick test to validate role model""" from models.update_role_model import UpdateRoleModel -import json # Test data with new permissions test_data = { - "role_name": "admin", - "role_description": "Administrator role", - "manage_users": True, - "manage_apis": True, - "manage_endpoints": True, - "manage_groups": True, - "manage_roles": True, - "manage_routings": True, - "manage_gateway": True, - "manage_subscriptions": True, - "manage_security": True, - "manage_tiers": True, - "manage_rate_limits": True, - "manage_credits": True, - "manage_auth": True, - "view_analytics": True, - "view_logs": True, - "export_logs": True + 'role_name': 'admin', + 'role_description': 'Administrator role', + 'manage_users': True, + 'manage_apis': True, + 'manage_endpoints': True, + 'manage_groups': True, + 'manage_roles': True, + 'manage_routings': True, + 'manage_gateway': True, + 'manage_subscriptions': True, + 'manage_security': True, + 'manage_tiers': True, + 'manage_rate_limits': True, + 'manage_credits': True, + 'manage_auth': True, + 'view_analytics': True, + 'view_logs': True, + 'export_logs': True, } try: model = UpdateRoleModel(**test_data) - print("✅ Model validation successful!") - print(f"Model: {model.model_dump()}") + print('✅ Model validation successful!') + print(f'Model: {model.model_dump()}') except Exception as e: - print(f"❌ Model validation failed: {e}") + print(f'❌ Model validation failed: {e}') import traceback + traceback.print_exc() diff --git a/backend-services/tests/test_roles_groups_subscriptions.py b/backend-services/tests/test_roles_groups_subscriptions.py index 9318e4d..afc6608 100644 --- a/backend-services/tests/test_roles_groups_subscriptions.py +++ b/backend-services/tests/test_roles_groups_subscriptions.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_roles_crud(authed_client): - r = await authed_client.post( '/platform/role', json={ @@ -35,9 +35,9 @@ async def test_roles_crud(authed_client): d = await authed_client.delete('/platform/role/qa') assert d.status_code == 200 + @pytest.mark.asyncio async def test_groups_crud(authed_client): - cg = await authed_client.post( '/platform/group', json={'group_name': 'qa-group', 'group_description': 'QA', 'api_access': []}, @@ -58,9 +58,9 @@ async def test_groups_crud(authed_client): dg = await authed_client.delete('/platform/group/qa-group') assert dg.status_code == 200 + @pytest.mark.asyncio async def test_subscriptions_flow(authed_client): - api_payload = { 'api_name': 'orders', 'api_version': 'v1', @@ -92,9 +92,9 @@ async def test_subscriptions_flow(authed_client): ) assert us.status_code in (200, 400) + @pytest.mark.asyncio async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client): - credit_group = 'ai-group' cd = await authed_client.post( '/platform/credit', @@ -103,7 +103,13 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client): 'api_key': 'sk-test-123', 'api_key_header': 'x-api-key', 'credit_tiers': [ - {'tier_name': 'basic', 'credits': 100, 'input_limit': 150, 'output_limit': 150, 'reset_frequency': 'monthly'} + { + 'tier_name': 'basic', + 'credits': 100, + 'input_limit': 150, + 'output_limit': 150, + 'reset_frequency': 'monthly', + } ], }, ) @@ -144,12 +150,10 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client): assert s.status_code in (200, 201) uc = await authed_client.post( - f'/platform/credit/admin', + '/platform/credit/admin', json={ 'username': 'admin', - 'users_credits': { - credit_group: {'tier_name': 'basic', 'available_credits': 2} - }, + 'users_credits': {credit_group: {'tier_name': 'basic', 'available_credits': 2}}, }, ) assert uc.status_code in (200, 201), uc.text @@ -158,7 +162,9 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client): r = await authed_client.get('/platform/credit/admin') assert r.status_code == 200, r.text body = r.json() - users_credits = body.get('users_credits') or body.get('response', {}).get('users_credits', {}) + users_credits = body.get('users_credits') or body.get('response', {}).get( + 'users_credits', {} + ) return int(users_credits.get(credit_group, {}).get('available_credits', 0)) import services.gateway_service as gs @@ -177,6 +183,7 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client): class _FakeAsyncClient: def __init__(self, timeout=None, limits=None, http2=False): self._timeout = timeout + async def __aenter__(self): return self diff --git a/backend-services/tests/test_routing_crud_permissions.py b/backend-services/tests/test_routing_crud_permissions.py new file mode 100644 index 0000000..51eb7d2 --- /dev/null +++ b/backend-services/tests/test_routing_crud_permissions.py @@ -0,0 +1,81 @@ +import time + +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200 + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_routing_crud_requires_manage_routings(authed_client): + client_key = f'client_{int(time.time())}' + # Limited user + uname = f'route_limited_{int(time.time())}' + pwd = 'RouteLimitStrong1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + limited = await _login(f'{uname}@example.com', pwd) + # Create denied + cr = await limited.post( + '/platform/routing', + json={'client_key': client_key, 'routing_name': 'r1', 'routing_servers': ['http://up']}, + ) + assert cr.status_code == 403 + + # Grant manage_routings + rname = f'route_mgr_{int(time.time())}' + rr = await authed_client.post( + '/platform/role', json={'role_name': rname, 'manage_routings': True} + ) + assert rr.status_code in (200, 201) + uname2 = f'route_mgr_user_{int(time.time())}' + cu2 = await authed_client.post( + '/platform/user', + json={ + 'username': uname2, + 'email': f'{uname2}@example.com', + 'password': pwd, + 'role': rname, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu2.status_code in (200, 201) + mgr = await _login(f'{uname2}@example.com', pwd) + + # Create allowed + ok = await mgr.post( + '/platform/routing', + json={'client_key': client_key, 'routing_name': 'r1', 'routing_servers': ['http://up']}, + ) + assert ok.status_code in (200, 201) + # Get allowed + gl = await mgr.get('/platform/routing/all') + assert gl.status_code == 200 + # Update allowed + up = await mgr.put(f'/platform/routing/{client_key}', json={'routing_description': 'updated'}) + assert up.status_code == 200 + # Delete allowed + de = await mgr.delete(f'/platform/routing/{client_key}') + assert de.status_code == 200 diff --git a/backend-services/tests/test_routing_precedence_and_round_robin.py b/backend-services/tests/test_routing_precedence_and_round_robin.py index 6531940..98c47c7 100644 --- a/backend-services/tests/test_routing_precedence_and_round_robin.py +++ b/backend-services/tests/test_routing_precedence_and_round_robin.py @@ -1,31 +1,38 @@ import pytest + @pytest.mark.asyncio async def test_routing_endpoint_servers_take_precedence_over_api_servers(authed_client): - from utils.database import api_collection, endpoint_collection - from utils.doorman_cache_util import doorman_cache from utils import routing_util + from utils.database import api_collection + from utils.doorman_cache_util import doorman_cache name, ver = 'route1', 'v1' - r = await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://api1', 'http://api2'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) + r = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://api1', 'http://api2'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) assert r.status_code in (200, 201) - r2 = await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/echo', - 'endpoint_description': 'echo', - 'endpoint_servers': ['http://ep1', 'http://ep2'] - }) + r2 = await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/echo', + 'endpoint_description': 'echo', + 'endpoint_servers': ['http://ep1', 'http://ep2'], + }, + ) assert r2.status_code in (200, 201) api = api_collection.find_one({'api_name': name, 'api_version': ver}) @@ -36,39 +43,44 @@ async def test_routing_endpoint_servers_take_precedence_over_api_servers(authed_ picked = await routing_util.pick_upstream_server(api, 'POST', '/echo', client_key=None) assert picked == 'http://ep1' + @pytest.mark.asyncio async def test_routing_client_specific_routing_over_endpoint_and_api(authed_client): + from utils import routing_util from utils.database import api_collection, routing_collection from utils.doorman_cache_util import doorman_cache - from utils import routing_util name, ver = 'route2', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://api1', 'http://api2'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/echo', - 'endpoint_description': 'echo', - 'endpoint_servers': ['http://ep1', 'http://ep2'] - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://api1', 'http://api2'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/echo', + 'endpoint_description': 'echo', + 'endpoint_servers': ['http://ep1', 'http://ep2'], + }, + ) api = api_collection.find_one({'api_name': name, 'api_version': ver}) api.pop('_id', None) - routing_collection.insert_one({ - 'client_key': 'ck1', - 'routing_servers': ['http://r1', 'http://r2'], - 'server_index': 0, - }) + routing_collection.insert_one( + {'client_key': 'ck1', 'routing_servers': ['http://r1', 'http://r2'], 'server_index': 0} + ) doorman_cache.clear_cache('client_routing_cache') doorman_cache.clear_cache('endpoint_server_cache') @@ -77,30 +89,37 @@ async def test_routing_client_specific_routing_over_endpoint_and_api(authed_clie assert s1 == 'http://r1' assert s2 == 'http://r2' + @pytest.mark.asyncio async def test_routing_round_robin_api_servers_rotates(authed_client): + from utils import routing_util from utils.database import api_collection from utils.doorman_cache_util import doorman_cache - from utils import routing_util name, ver = 'route3', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://a1', 'http://a2'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/status', - 'endpoint_description': 'status' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://a1', 'http://a2'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/status', + 'endpoint_description': 'status', + }, + ) api = api_collection.find_one({'api_name': name, 'api_version': ver}) api.pop('_id', None) @@ -110,31 +129,38 @@ async def test_routing_round_robin_api_servers_rotates(authed_client): s3 = await routing_util.pick_upstream_server(api, 'GET', '/status', client_key=None) assert [s1, s2, s3] == ['http://a1', 'http://a2', 'http://a1'] + @pytest.mark.asyncio async def test_routing_round_robin_endpoint_servers_rotates(authed_client): + from utils import routing_util from utils.database import api_collection from utils.doorman_cache_util import doorman_cache - from utils import routing_util name, ver = 'route4', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://a1', 'http://a2'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/echo', - 'endpoint_description': 'echo', - 'endpoint_servers': ['http://e1', 'http://e2'] - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://a1', 'http://a2'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/echo', + 'endpoint_description': 'echo', + 'endpoint_servers': ['http://e1', 'http://e2'], + }, + ) api = api_collection.find_one({'api_name': name, 'api_version': ver}) api.pop('_id', None) @@ -144,30 +170,37 @@ async def test_routing_round_robin_endpoint_servers_rotates(authed_client): s3 = await routing_util.pick_upstream_server(api, 'POST', '/echo', client_key=None) assert [s1, s2, s3] == ['http://e1', 'http://e2', 'http://e1'] + @pytest.mark.asyncio async def test_routing_round_robin_index_persists_in_cache(authed_client): + from utils import routing_util from utils.database import api_collection from utils.doorman_cache_util import doorman_cache - from utils import routing_util name, ver = 'route5', 'v1' - await authed_client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://a1', 'http://a2', 'http://a3'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) - await authed_client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'GET', - 'endpoint_uri': '/status', - 'endpoint_description': 'status' - }) + await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://a1', 'http://a2', 'http://a3'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) + await authed_client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'GET', + 'endpoint_uri': '/status', + 'endpoint_description': 'status', + }, + ) api = api_collection.find_one({'api_name': name, 'api_version': ver}) api.pop('_id', None) @@ -181,4 +214,3 @@ async def test_routing_round_robin_index_persists_in_cache(authed_client): assert s3 == 'http://a3' idx_after = doorman_cache.get_cache('endpoint_server_cache', api['api_id']) assert idx_after == 0 - diff --git a/backend-services/tests/test_security.py b/backend-services/tests/test_security.py index 70b2345..6771b03 100644 --- a/backend-services/tests/test_security.py +++ b/backend-services/tests/test_security.py @@ -1,25 +1,24 @@ -import os import pytest from jose import jwt + @pytest.mark.asyncio async def test_jwt_tamper_rejected(client): - bad_token = jwt.encode({'sub': 'admin'}, 'wrong-secret', algorithm='HS256') client.cookies.set('access_token_cookie', bad_token) r = await client.get('/platform/user/me') assert r.status_code in (401, 500) + @pytest.mark.asyncio async def test_csrf_required_when_https(monkeypatch, authed_client): - monkeypatch.setenv('HTTPS_ONLY', 'true') r = await authed_client.get('/platform/user/me') assert r.status_code in (401, 500) + @pytest.mark.asyncio async def test_header_injection_is_sanitized(monkeypatch, authed_client): - api_name, version = 'hdr', 'v1' c = await authed_client.post( '/platform/api', @@ -54,6 +53,7 @@ async def test_header_injection_is_sanitized(monkeypatch, authed_client): assert s.status_code in (200, 201) import services.gateway_service as gs + captured = {} class _FakeHTTPResponse: @@ -63,14 +63,17 @@ async def test_header_injection_is_sanitized(monkeypatch, authed_client): self.headers = {'Content-Type': 'application/json'} self.content = b'{}' self.text = '{}' + def json(self): return self._json_body class _FakeAsyncClient: async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -88,61 +91,82 @@ async def test_header_injection_is_sanitized(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) + async def get(self, url, params=None, headers=None, **kwargs): captured['headers'] = headers or {} return _FakeHTTPResponse(200) + async def post(self, url, **kwargs): return _FakeHTTPResponse(200) + async def put(self, url, **kwargs): return _FakeHTTPResponse(200) + async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) import routes.gateway_routes as gr - async def _no_limit(req): return None - async def _pass_sub(req): return {'sub': 'admin'} - async def _pass_group(req, full_path: str = None, user_to_subscribe=None): return {'sub': 'admin'} + + async def _no_limit(req): + return None + + async def _pass_sub(req): + return {'sub': 'admin'} + + async def _pass_group(req, full_path: str = None, user_to_subscribe=None): + return {'sub': 'admin'} + monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit) monkeypatch.setattr(gr, 'subscription_required', _pass_sub) monkeypatch.setattr(gr, 'group_required', _pass_group) inj_value = 'abc\r\nInjected: 1' - r = await authed_client.get( - '/api/rest/hdr/v1/x', - headers={'X-Allowed': inj_value}, - ) + r = await authed_client.get('/api/rest/hdr/v1/x', headers={'X-Allowed': inj_value}) assert r.status_code in (200, 500) forwarded = captured.get('headers', {}).get('X-Allowed', '') assert '\r' not in forwarded and '\n' not in forwarded assert '<' not in forwarded and '>' not in forwarded + @pytest.mark.asyncio async def test_rate_limit_enforced(monkeypatch, authed_client): - from utils.database import user_collection + user_collection.update_one( {'username': 'admin'}, - {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second', - 'throttle_duration': 1, 'throttle_duration_type': 'second', - 'throttle_queue_limit': 1, 'throttle_wait_duration': 0.0, 'throttle_wait_duration_type': 'second'}} + { + '$set': { + 'rate_limit_duration': 1, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 1, + 'throttle_duration_type': 'second', + 'throttle_queue_limit': 1, + 'throttle_wait_duration': 0.0, + 'throttle_wait_duration_type': 'second', + } + }, ) class _FakeRedis: def __init__(self): self.store = {} + async def incr(self, key): self.store[key] = self.store.get(key, 0) + 1 return self.store[key] + async def expire(self, key, ttl): return True from doorman import doorman as app + app.state.redis = _FakeRedis() from utils.doorman_cache_util import doorman_cache + try: doorman_cache.delete_cache('user_cache', 'admin') except Exception: @@ -180,6 +204,7 @@ async def test_rate_limit_enforced(monkeypatch, authed_client): assert s.status_code in (200, 201) import services.gateway_service as gs + class _FakeHTTPResponse: def __init__(self, status_code=200, json_body=None): self.status_code = status_code @@ -187,11 +212,17 @@ async def test_rate_limit_enforced(monkeypatch, authed_client): self.headers = {'Content-Type': 'application/json'} self.content = b'{}' self.text = '{}' + def json(self): return self._json_body + class _FakeAsyncClient: - async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): return False + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -209,15 +240,29 @@ async def test_rate_limit_enforced(monkeypatch, authed_client): return await self.put(url, **kwargs) else: return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'}) - async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200) - async def post(self, url, **kwargs): return _FakeHTTPResponse(200) - async def put(self, url, **kwargs): return _FakeHTTPResponse(200) - async def delete(self, url, **kwargs): return _FakeHTTPResponse(200) + + async def get(self, url, params=None, headers=None, **kwargs): + return _FakeHTTPResponse(200) + + async def post(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def put(self, url, **kwargs): + return _FakeHTTPResponse(200) + + async def delete(self, url, **kwargs): + return _FakeHTTPResponse(200) + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) import routes.gateway_routes as gr - async def _pass_sub(req): return {'sub': 'admin'} - async def _pass_group(req, full_path: str = None, user_to_subscribe=None): return {'sub': 'admin'} + + async def _pass_sub(req): + return {'sub': 'admin'} + + async def _pass_group(req, full_path: str = None, user_to_subscribe=None): + return {'sub': 'admin'} + monkeypatch.setattr(gr, 'subscription_required', _pass_sub) monkeypatch.setattr(gr, 'group_required', _pass_group) diff --git a/backend-services/tests/test_security_and_metrics.py b/backend-services/tests/test_security_and_metrics.py index 5f3bfaf..2dc9e90 100644 --- a/backend-services/tests/test_security_and_metrics.py +++ b/backend-services/tests/test_security_and_metrics.py @@ -1,11 +1,10 @@ import os -import json -import time + import pytest + @pytest.mark.asyncio async def test_security_headers_and_hsts(monkeypatch, client): - r = await client.get('/platform/monitor/liveness') assert r.status_code == 200 assert r.headers.get('X-Content-Type-Options') == 'nosniff' @@ -18,13 +17,17 @@ async def test_security_headers_and_hsts(monkeypatch, client): assert r.status_code == 200 assert 'Strict-Transport-Security' in r.headers + @pytest.mark.asyncio async def test_body_size_limit_returns_413(monkeypatch, client): monkeypatch.setenv('MAX_BODY_SIZE_BYTES', '10') payload = 'x' * 100 - r = await client.post('/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'}) + r = await client.post( + '/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'} + ) assert r.status_code == 413 + @pytest.mark.asyncio async def test_strict_response_envelope(monkeypatch, authed_client): monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true') @@ -36,10 +39,11 @@ async def test_strict_response_envelope(monkeypatch, authed_client): assert isinstance(data, dict) assert 'status_code' in data and data['status_code'] == 200 + @pytest.mark.asyncio async def test_metrics_recording_snapshot(authed_client): - from conftest import create_api, create_endpoint, subscribe_self + name, ver = 'metapi', 'v1' await create_api(authed_client, name, ver) await create_endpoint(authed_client, name, ver, 'GET', '/status') @@ -56,19 +60,22 @@ async def test_metrics_recording_snapshot(authed_client): assert isinstance(series, list) assert body.get('response', {}).get('total_requests') or body.get('total_requests') >= 1 + @pytest.mark.asyncio async def test_cors_strict_allows_localhost(monkeypatch, client): - monkeypatch.setenv('CORS_STRICT', 'true') monkeypatch.setenv('ALLOWED_ORIGINS', '*') monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') r = await client.get('/platform/monitor/liveness', headers={'Origin': 'http://localhost:3000'}) - assert r.headers.get('access-control-allow-origin') in ('http://localhost:3000', 'http://localhost') + assert r.headers.get('access-control-allow-origin') in ( + 'http://localhost:3000', + 'http://localhost', + ) + @pytest.mark.asyncio async def test_csp_header_default_and_override(monkeypatch, client): - monkeypatch.delenv('CONTENT_SECURITY_POLICY', raising=False) r = await client.get('/platform/monitor/liveness') assert r.status_code == 200 @@ -79,9 +86,9 @@ async def test_csp_header_default_and_override(monkeypatch, client): r2 = await client.get('/platform/monitor/liveness') assert r2.headers.get('Content-Security-Policy') == "default-src 'self'" + @pytest.mark.asyncio async def test_request_id_header_generation_and_echo(client): - r = await client.get('/platform/monitor/liveness') assert r.status_code == 200 assert r.headers.get('X-Request-ID') @@ -92,9 +99,9 @@ async def test_request_id_header_generation_and_echo(client): assert r2.headers.get('X-Request-ID') == incoming assert r2.headers.get('request_id') == incoming + @pytest.mark.asyncio async def test_memory_dump_and_restore(tmp_path, monkeypatch): - monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-secret') @@ -102,8 +109,12 @@ async def test_memory_dump_and_restore(tmp_path, monkeypatch): dump_dir.mkdir(parents=True, exist_ok=True) monkeypatch.setenv('MEM_DUMP_PATH', str(dump_dir / 'memory_dump.bin')) - from utils.memory_dump_util import dump_memory_to_file, find_latest_dump_path, restore_memory_from_file from utils.database import database + from utils.memory_dump_util import ( + dump_memory_to_file, + find_latest_dump_path, + restore_memory_from_file, + ) database.db.users.insert_one({'username': 'tmp', 'email': 't@t.t', 'password': 'x'}) path = dump_memory_to_file(None) @@ -113,5 +124,5 @@ async def test_memory_dump_and_restore(tmp_path, monkeypatch): database.db.users._docs.clear() assert database.db.users.count_documents({}) == 0 - info = restore_memory_from_file(latest) + restore_memory_from_file(latest) assert database.db.users.count_documents({}) >= 1 diff --git a/backend-services/tests/test_security_permissions.py b/backend-services/tests/test_security_permissions.py index f78d124..85797af 100644 --- a/backend-services/tests/test_security_permissions.py +++ b/backend-services/tests/test_security_permissions.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_security_settings_requires_permission(authed_client): - r = await authed_client.put('/platform/role/admin', json={'manage_security': False}) assert r.status_code in (200, 201) @@ -17,4 +17,3 @@ async def test_security_settings_requires_permission(authed_client): gs2 = await authed_client.get('/platform/security/settings') assert gs2.status_code == 200 - diff --git a/backend-services/tests/test_security_settings_permissions.py b/backend-services/tests/test_security_settings_permissions.py new file mode 100644 index 0000000..4c7cdc4 --- /dev/null +++ b/backend-services/tests/test_security_settings_permissions.py @@ -0,0 +1,65 @@ +import time + +import pytest + + +@pytest.mark.asyncio +async def test_security_settings_get_put_permissions(authed_client): + # Limited user: 403 on get/put + uname = f'sec_limited_{int(time.time())}' + pwd = 'SecLimitStrongPass1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + + from httpx import AsyncClient + + from doorman import doorman + + limited = AsyncClient(app=doorman, base_url='http://testserver') + r = await limited.post( + '/platform/authorization', json={'email': f'{uname}@example.com', 'password': pwd} + ) + assert r.status_code == 200 + get403 = await limited.get('/platform/security/settings') + assert get403.status_code == 403 + put403 = await limited.put('/platform/security/settings', json={'trust_x_forwarded_for': True}) + assert put403.status_code == 403 + + # Role manage_security: 200 on get/put + rname = f'sec_mgr_{int(time.time())}' + cr = await authed_client.post( + '/platform/role', json={'role_name': rname, 'manage_security': True} + ) + assert cr.status_code in (200, 201) + uname2 = f'sec_mgr_user_{int(time.time())}' + cu2 = await authed_client.post( + '/platform/user', + json={ + 'username': uname2, + 'email': f'{uname2}@example.com', + 'password': pwd, + 'role': rname, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu2.status_code in (200, 201) + mgr = AsyncClient(app=doorman, base_url='http://testserver') + r2 = await mgr.post( + '/platform/authorization', json={'email': f'{uname2}@example.com', 'password': pwd} + ) + assert r2.status_code == 200 + g = await mgr.get('/platform/security/settings') + assert g.status_code == 200 + u = await mgr.put('/platform/security/settings', json={'trust_x_forwarded_for': True}) + assert u.status_code == 200 diff --git a/backend-services/tests/test_security_settings_persistence.py b/backend-services/tests/test_security_settings_persistence.py index 4c2bcb1..884a838 100644 --- a/backend-services/tests/test_security_settings_persistence.py +++ b/backend-services/tests/test_security_settings_persistence.py @@ -1,15 +1,17 @@ -import json import asyncio +import json + import pytest + @pytest.mark.asyncio async def test_load_settings_from_file_in_memory_mode(tmp_path, monkeypatch): monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') from utils import security_settings_util as ssu + settings_path = tmp_path / 'sec_settings.json' monkeypatch.setattr(ssu, 'SETTINGS_FILE', str(settings_path), raising=False) - from utils.security_settings_util import load_settings, get_cached_settings - from utils.security_settings_util import _get_collection + from utils.security_settings_util import _get_collection, get_cached_settings, load_settings coll = _get_collection() try: @@ -41,6 +43,7 @@ async def test_load_settings_from_file_in_memory_mode(tmp_path, monkeypatch): assert cached.get('ip_whitelist') == ['203.0.113.1'] assert cached.get('allow_localhost_bypass') is True + @pytest.mark.asyncio async def test_save_settings_writes_file_and_autosave_triggers_dump(tmp_path, monkeypatch): monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM') @@ -50,20 +53,24 @@ async def test_save_settings_writes_file_and_autosave_triggers_dump(tmp_path, mo monkeypatch.setattr(ssu, 'SETTINGS_FILE', str(sec_file), raising=False) calls = {'count': 0, 'last_path': None} + def _fake_dump(path_hint): calls['count'] += 1 calls['last_path'] = path_hint return str(tmp_path / 'dump.bin') + monkeypatch.setattr(ssu, 'dump_memory_to_file', _fake_dump) await ssu.start_auto_save_task() prev_task = getattr(ssu, '_AUTO_TASK', None) - result = await ssu.save_settings({ - 'enable_auto_save': True, - 'auto_save_frequency_seconds': 90, - 'dump_path': str(tmp_path / 'wanted_dump.bin'), - }) + result = await ssu.save_settings( + { + 'enable_auto_save': True, + 'auto_save_frequency_seconds': 90, + 'dump_path': str(tmp_path / 'wanted_dump.bin'), + } + ) await asyncio.sleep(0) diff --git a/backend-services/tests/test_soap_gateway_content_types.py b/backend-services/tests/test_soap_gateway_content_types.py index 731626f..f3e1092 100644 --- a/backend-services/tests/test_soap_gateway_content_types.py +++ b/backend-services/tests/test_soap_gateway_content_types.py @@ -1,5 +1,6 @@ import pytest + class _FakeXMLResponse: def __init__(self, status_code=200, text='', headers=None): self.status_code = status_code @@ -10,14 +11,18 @@ class _FakeXMLResponse: self.headers = base self.content = self.text.encode('utf-8') + def _mk_xml_client(captured): class _FakeXMLClient: def __init__(self, *args, **kwargs): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -35,45 +40,62 @@ def _mk_xml_client(captured): return await self.put(url, **kwargs) else: return _FakeXMLResponse(405, 'Method not allowed') + async def get(self, url, **kwargs): return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'}) + async def post(self, url, content=None, params=None, headers=None, **kwargs): captured.append({'url': url, 'headers': dict(headers or {}), 'content': content}) return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'}) + async def put(self, url, **kwargs): return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'}) + async def delete(self, url, **kwargs): return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'}) + return _FakeXMLClient + async def _setup_api(client, name, ver): - r = await client.post('/platform/api', json={ - 'api_name': name, - 'api_version': ver, - 'api_description': f'{name} {ver}', - 'api_allowed_roles': ['admin'], - 'api_allowed_groups': ['ALL'], - 'api_servers': ['http://soap.up'], - 'api_type': 'REST', - 'api_allowed_retry_count': 0, - }) + r = await client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': f'{name} {ver}', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://soap.up'], + 'api_type': 'REST', + 'api_allowed_retry_count': 0, + }, + ) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/call', - 'endpoint_description': 'soap call', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/call', + 'endpoint_description': 'soap call', + }, + ) assert r2.status_code in (200, 201) rme = await client.get('/platform/user/me') - username = (rme.json().get('username') if rme.status_code == 200 else 'admin') - rs = await client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': name, 'api_version': ver}) + username = rme.json().get('username') if rme.status_code == 200 else 'admin' + rs = await client.post( + '/platform/subscription/subscribe', + json={'username': username, 'api_name': name, 'api_version': ver}, + ) assert rs.status_code in (200, 201) + @pytest.mark.asyncio async def test_soap_incoming_application_xml_sets_text_xml_outgoing(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapct1', 'v1' await _setup_api(authed_client, name, ver) captured = [] @@ -89,26 +111,28 @@ async def test_soap_incoming_application_xml_sets_text_xml_outgoing(monkeypatch, h = captured[0]['headers'] assert h.get('Content-Type') == 'text/xml; charset=utf-8' + @pytest.mark.asyncio async def test_soap_incoming_text_xml_passes_through(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapct2', 'v1' await _setup_api(authed_client, name, ver) captured = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_xml_client(captured)) envelope = '' r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', - headers={'Content-Type': 'text/xml'}, - content=envelope, + f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'text/xml'}, content=envelope ) assert r.status_code == 200 h = captured[0]['headers'] assert h.get('Content-Type') == 'text/xml' + @pytest.mark.asyncio async def test_soap_incoming_application_soap_xml_passes_through(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapct3', 'v1' await _setup_api(authed_client, name, ver) captured = [] @@ -123,9 +147,11 @@ async def test_soap_incoming_application_soap_xml_passes_through(monkeypatch, au h = captured[0]['headers'] assert h.get('Content-Type') == 'application/soap+xml' + @pytest.mark.asyncio async def test_soap_adds_default_soapaction_when_missing(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapct4', 'v1' await _setup_api(authed_client, name, ver) captured = [] @@ -140,9 +166,11 @@ async def test_soap_adds_default_soapaction_when_missing(monkeypatch, authed_cli h = captured[0]['headers'] assert 'SOAPAction' in h and h['SOAPAction'] == '""' + @pytest.mark.asyncio async def test_soap_parses_xml_response_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapct5', 'v1' await _setup_api(authed_client, name, ver) captured = [] diff --git a/backend-services/tests/test_soap_gateway_retries.py b/backend-services/tests/test_soap_gateway_retries.py index 0cd1d6e..b7d33f8 100644 --- a/backend-services/tests/test_soap_gateway_retries.py +++ b/backend-services/tests/test_soap_gateway_retries.py @@ -1,5 +1,6 @@ import pytest + class _Resp: def __init__(self, status_code=200, body='', headers=None): self.status_code = status_code @@ -10,16 +11,20 @@ class _Resp: self.headers = base self.content = (self.text or '').encode('utf-8') + def _mk_retry_xml_client(sequence, seen): counter = {'i': 0} class _Client: def __init__(self, timeout=None, limits=None, http2=False): pass + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): return False + async def request(self, method, url, **kwargs): """Generic request method used by http_client.request_with_resilience""" method = method.upper() @@ -37,32 +42,45 @@ def _mk_retry_xml_client(sequence, seen): return await self.put(url, **kwargs) else: return _Resp(405) + async def post(self, url, content=None, params=None, headers=None, **kwargs): - seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'content': content}) + seen.append( + { + 'url': url, + 'params': dict(params or {}), + 'headers': dict(headers or {}), + 'content': content, + } + ) idx = min(counter['i'], len(sequence) - 1) code = sequence[idx] counter['i'] = counter['i'] + 1 return _Resp(code) + async def get(self, url, **kwargs): seen.append({'url': url, 'params': {}, 'headers': {}}) idx = min(counter['i'], len(sequence) - 1) code = sequence[idx] counter['i'] = counter['i'] + 1 return _Resp(code) + async def put(self, url, **kwargs): seen.append({'url': url, 'params': {}, 'headers': {}}) idx = min(counter['i'], len(sequence) - 1) code = sequence[idx] counter['i'] = counter['i'] + 1 return _Resp(code) + async def delete(self, url, **kwargs): seen.append({'url': url, 'params': {}, 'headers': {}}) idx = min(counter['i'], len(sequence) - 1) code = sequence[idx] counter['i'] = counter['i'] + 1 return _Resp(code) + return _Client + async def _setup_soap(client, name, ver, retry_count=0): payload = { 'api_name': name, @@ -76,79 +94,102 @@ async def _setup_soap(client, name, ver, retry_count=0): } r = await client.post('/platform/api', json=payload) assert r.status_code in (200, 201) - r2 = await client.post('/platform/endpoint', json={ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': 'POST', - 'endpoint_uri': '/call', - 'endpoint_description': 'soap call', - }) + r2 = await client.post( + '/platform/endpoint', + json={ + 'api_name': name, + 'api_version': ver, + 'endpoint_method': 'POST', + 'endpoint_uri': '/call', + 'endpoint_description': 'soap call', + }, + ) assert r2.status_code in (200, 201) from conftest import subscribe_self + await subscribe_self(client, name, ver) + @pytest.mark.asyncio async def test_soap_retry_on_500_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapretry500', 'v1' await _setup_soap(authed_client, name, ver, retry_count=2) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([500, 200], seen)) r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content='' + f'/api/soap/{name}/{ver}/call', + headers={'Content-Type': 'application/xml'}, + content='', ) assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_soap_retry_on_502_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapretry502', 'v1' await _setup_soap(authed_client, name, ver, retry_count=2) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([502, 200], seen)) r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content='' + f'/api/soap/{name}/{ver}/call', + headers={'Content-Type': 'application/xml'}, + content='', ) assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_soap_retry_on_503_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapretry503', 'v1' await _setup_soap(authed_client, name, ver, retry_count=2) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([503, 200], seen)) r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content='' + f'/api/soap/{name}/{ver}/call', + headers={'Content-Type': 'application/xml'}, + content='', ) assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_soap_retry_on_504_then_success(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapretry504', 'v1' await _setup_soap(authed_client, name, ver, retry_count=2) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([504, 200], seen)) r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content='' + f'/api/soap/{name}/{ver}/call', + headers={'Content-Type': 'application/xml'}, + content='', ) assert r.status_code == 200 assert len(seen) == 2 + @pytest.mark.asyncio async def test_soap_no_retry_when_retry_count_zero(monkeypatch, authed_client): import services.gateway_service as gs + name, ver = 'soapretry0', 'v1' await _setup_soap(authed_client, name, ver, retry_count=0) seen = [] monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([500, 200], seen)) r = await authed_client.post( - f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content='' + f'/api/soap/{name}/{ver}/call', + headers={'Content-Type': 'application/xml'}, + content='', ) assert r.status_code == 500 assert len(seen) == 1 - diff --git a/backend-services/tests/test_soap_validation_no_wsdl.py b/backend-services/tests/test_soap_validation_no_wsdl.py index 81a7035..65f9669 100644 --- a/backend-services/tests/test_soap_validation_no_wsdl.py +++ b/backend-services/tests/test_soap_validation_no_wsdl.py @@ -1,6 +1,7 @@ import pytest from fastapi import HTTPException + @pytest.mark.asyncio async def test_soap_structural_validation_passes_without_wsdl(): from utils.database import endpoint_validation_collection @@ -8,38 +9,32 @@ async def test_soap_structural_validation_passes_without_wsdl(): endpoint_id = 'soap-ep-struct-1' endpoint_validation_collection.delete_one({'endpoint_id': endpoint_id}) - endpoint_validation_collection.insert_one({ - 'endpoint_id': endpoint_id, - 'validation_enabled': True, - 'validation_schema': { - 'username': { - 'required': True, - 'type': 'string', - 'min': 3, - 'max': 50, - }, - 'email': { - 'required': True, - 'type': 'string', - 'format': 'email', + endpoint_validation_collection.insert_one( + { + 'endpoint_id': endpoint_id, + 'validation_enabled': True, + 'validation_schema': { + 'username': {'required': True, 'type': 'string', 'min': 3, 'max': 50}, + 'email': {'required': True, 'type': 'string', 'format': 'email'}, }, } - }) + ) envelope = ( "" "" - " " - " " - " alice" - " alice@example.com" - " " - " " - "" + ' ' + ' ' + ' alice' + ' alice@example.com' + ' ' + ' ' + '' ) await validation_util.validate_soap_request(endpoint_id, envelope) + @pytest.mark.asyncio async def test_soap_structural_validation_fails_without_wsdl(): from utils.database import endpoint_validation_collection @@ -47,30 +42,25 @@ async def test_soap_structural_validation_fails_without_wsdl(): endpoint_id = 'soap-ep-struct-2' endpoint_validation_collection.delete_one({'endpoint_id': endpoint_id}) - endpoint_validation_collection.insert_one({ - 'endpoint_id': endpoint_id, - 'validation_enabled': True, - 'validation_schema': { - 'username': { - 'required': True, - 'type': 'string', - 'min': 3, - } + endpoint_validation_collection.insert_one( + { + 'endpoint_id': endpoint_id, + 'validation_enabled': True, + 'validation_schema': {'username': {'required': True, 'type': 'string', 'min': 3}}, } - }) + ) bad_envelope = ( "" "" - " " - " " - " no-user@example.com" - " " - " " - "" + ' ' + ' ' + ' no-user@example.com' + ' ' + ' ' + '' ) with pytest.raises(HTTPException) as ex: await validation_util.validate_soap_request(endpoint_id, bad_envelope) assert ex.value.status_code == 400 - diff --git a/backend-services/tests/test_subscription_flows.py b/backend-services/tests/test_subscription_flows.py new file mode 100644 index 0000000..667526c --- /dev/null +++ b/backend-services/tests/test_subscription_flows.py @@ -0,0 +1,38 @@ +import time + +import pytest + + +@pytest.mark.asyncio +async def test_subscription_lifecycle(authed_client): + name, ver = f'subapi_{int(time.time())}', 'v1' + # Create API and endpoint + c = await authed_client.post( + '/platform/api', + json={ + 'api_name': name, + 'api_version': ver, + 'api_description': 'sub api', + 'api_allowed_roles': ['admin'], + 'api_allowed_groups': ['ALL'], + 'api_servers': ['http://up.invalid'], + 'api_type': 'REST', + 'active': True, + }, + ) + assert c.status_code in (200, 201) + # Subscribe current user (admin) + s = await authed_client.post( + '/platform/subscription/subscribe', + json={'api_name': name, 'api_version': ver, 'username': 'admin'}, + ) + assert s.status_code in (200, 201) + # List current user subscriptions + ls = await authed_client.get('/platform/subscription/subscriptions') + assert ls.status_code == 200 + # Unsubscribe + u = await authed_client.post( + '/platform/subscription/unsubscribe', + json={'api_name': name, 'api_version': ver, 'username': 'admin'}, + ) + assert u.status_code == 200 diff --git a/backend-services/tests/test_subscription_required_parsing.py b/backend-services/tests/test_subscription_required_parsing.py index abaab0d..99f14db 100644 --- a/backend-services/tests/test_subscription_required_parsing.py +++ b/backend-services/tests/test_subscription_required_parsing.py @@ -1,7 +1,9 @@ import pytest + def _make_request(path: str, headers: dict | None = None): from starlette.requests import Request + hdrs = [] for k, v in (headers or {}).items(): hdrs.append((k.lower().encode('latin-1'), str(v).encode('latin-1'))) @@ -17,65 +19,108 @@ def _make_request(path: str, headers: dict | None = None): } return Request(scope) + @pytest.mark.asyncio async def test_subscription_required_rest_path_parsing(monkeypatch): import utils.subscription_util as su async def fake_auth(req): return {'sub': 'alice'} + monkeypatch.setattr(su, 'auth_required', fake_auth) - monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc1/v1']} if (name, key) == ('user_subscription_cache', 'alice') else None) + monkeypatch.setattr( + su.doorman_cache, + 'get_cache', + lambda name, key: {'apis': ['svc1/v1']} + if (name, key) == ('user_subscription_cache', 'alice') + else None, + ) req = _make_request('/api/rest/svc1/v1/resource') payload = await su.subscription_required(req) assert payload.get('sub') == 'alice' + @pytest.mark.asyncio async def test_subscription_required_soap_path_parsing(monkeypatch): import utils.subscription_util as su + async def fake_auth(req): return {'sub': 'alice'} + monkeypatch.setattr(su, 'auth_required', fake_auth) - monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc2/v2']} if (name, key) == ('user_subscription_cache', 'alice') else None) + monkeypatch.setattr( + su.doorman_cache, + 'get_cache', + lambda name, key: {'apis': ['svc2/v2']} + if (name, key) == ('user_subscription_cache', 'alice') + else None, + ) req = _make_request('/api/soap/svc2/v2/do') payload = await su.subscription_required(req) assert payload.get('sub') == 'alice' + @pytest.mark.asyncio async def test_subscription_required_graphql_header_parsing(monkeypatch): import utils.subscription_util as su + async def fake_auth(req): return {'sub': 'alice'} + monkeypatch.setattr(su, 'auth_required', fake_auth) - monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc3/v3']} if (name, key) == ('user_subscription_cache', 'alice') else None) + monkeypatch.setattr( + su.doorman_cache, + 'get_cache', + lambda name, key: {'apis': ['svc3/v3']} + if (name, key) == ('user_subscription_cache', 'alice') + else None, + ) req = _make_request('/api/graphql/svc3', headers={'X-API-Version': 'v3'}) payload = await su.subscription_required(req) assert payload.get('sub') == 'alice' + @pytest.mark.asyncio async def test_subscription_required_grpc_path_parsing(monkeypatch): import utils.subscription_util as su + async def fake_auth(req): return {'sub': 'alice'} + monkeypatch.setattr(su, 'auth_required', fake_auth) - monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc4/v4']} if (name, key) == ('user_subscription_cache', 'alice') else None) + monkeypatch.setattr( + su.doorman_cache, + 'get_cache', + lambda name, key: {'apis': ['svc4/v4']} + if (name, key) == ('user_subscription_cache', 'alice') + else None, + ) req = _make_request('/api/grpc/svc4', headers={'X-API-Version': 'v4'}) payload = await su.subscription_required(req) assert payload.get('sub') == 'alice' + @pytest.mark.asyncio async def test_subscription_required_unknown_prefix_fallback(monkeypatch): import utils.subscription_util as su + async def fake_auth(req): return {'sub': 'alice'} + monkeypatch.setenv('PYTHONASYNCIODEBUG', '0') monkeypatch.setattr(su, 'auth_required', fake_auth) - monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc5/v5']} if (name, key) == ('user_subscription_cache', 'alice') else None) + monkeypatch.setattr( + su.doorman_cache, + 'get_cache', + lambda name, key: {'apis': ['svc5/v5']} + if (name, key) == ('user_subscription_cache', 'alice') + else None, + ) req = _make_request('/api/other/svc5/v5/op') payload = await su.subscription_required(req) assert payload.get('sub') == 'alice' - diff --git a/backend-services/tests/test_subscription_routes_extended.py b/backend-services/tests/test_subscription_routes_extended.py index 1a1b1f3..7fbc709 100644 --- a/backend-services/tests/test_subscription_routes_extended.py +++ b/backend-services/tests/test_subscription_routes_extended.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_subscriptions_happy_and_invalid_payload(authed_client): - c = await authed_client.post( '/platform/api', json={ @@ -35,4 +35,3 @@ async def test_subscriptions_happy_and_invalid_payload(authed_client): bad = await authed_client.post('/platform/subscription/subscribe', json={'username': 'admin'}) assert bad.status_code in (400, 422) - diff --git a/backend-services/tests/test_tools_chaos_toggles.py b/backend-services/tests/test_tools_chaos_toggles.py new file mode 100644 index 0000000..cd9a4a0 --- /dev/null +++ b/backend-services/tests/test_tools_chaos_toggles.py @@ -0,0 +1,34 @@ +import pytest + + +@pytest.mark.asyncio +async def test_tools_chaos_toggle_and_stats(authed_client): + # Baseline stats + st0 = await authed_client.get('/platform/tools/chaos/stats') + assert st0.status_code == 200 + + # Enable redis outage (response reflects enabled state) + r1 = await authed_client.post( + '/platform/tools/chaos/toggle', json={'backend': 'redis', 'enabled': True} + ) + assert r1.status_code == 200 + en = r1.json().get('response', r1.json()) + assert en.get('enabled') is True + + # Immediately disable to avoid auth failures on subsequent calls + # Disable using internal util to avoid auth during outage + from utils import chaos_util as _cu + + _cu.enable('redis', False) + + # Stats should reflect disabled + st2 = await authed_client.get('/platform/tools/chaos/stats') + assert st2.status_code == 200 + body2 = st2.json().get('response', st2.json()) + assert body2.get('redis_outage') is False + + # Invalid backend -> 400 + bad = await authed_client.post( + '/platform/tools/chaos/toggle', json={'backend': 'notabackend', 'enabled': True} + ) + assert bad.status_code == 400 diff --git a/backend-services/tests/test_tools_cors_checker.py b/backend-services/tests/test_tools_cors_checker.py index b9dee02..ac87d25 100644 --- a/backend-services/tests/test_tools_cors_checker.py +++ b/backend-services/tests/test_tools_cors_checker.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_cors_checker_allows_matching_origin(monkeypatch, authed_client): - monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000') monkeypatch.setenv('ALLOW_METHODS', 'GET,POST') monkeypatch.setenv('ALLOW_HEADERS', 'Content-Type,X-CSRF-Token') @@ -13,14 +13,18 @@ async def test_cors_checker_allows_matching_origin(monkeypatch, authed_client): 'origin': 'http://localhost:3000', 'method': 'GET', 'request_headers': ['Content-Type'], - 'with_credentials': True + 'with_credentials': True, } r = await authed_client.post('/platform/tools/cors/check', json=body) assert r.status_code == 200, r.text data = r.json() assert data.get('preflight', {}).get('allowed') is True assert data.get('actual', {}).get('allowed') is True - assert data.get('preflight', {}).get('response_headers', {}).get('Access-Control-Allow-Origin') == body['origin'] + assert ( + data.get('preflight', {}).get('response_headers', {}).get('Access-Control-Allow-Origin') + == body['origin'] + ) + @pytest.mark.asyncio async def test_cors_checker_denies_disallowed_header(monkeypatch, authed_client): @@ -34,7 +38,7 @@ async def test_cors_checker_denies_disallowed_header(monkeypatch, authed_client) 'origin': 'http://localhost:3000', 'method': 'GET', 'request_headers': ['X-Custom-Header'], - 'with_credentials': True + 'with_credentials': True, } r = await authed_client.post('/platform/tools/cors/check', json=body) assert r.status_code == 200, r.text diff --git a/backend-services/tests/test_tools_cors_checker_edge_cases.py b/backend-services/tests/test_tools_cors_checker_edge_cases.py index 5bfb5c8..e111260 100644 --- a/backend-services/tests/test_tools_cors_checker_edge_cases.py +++ b/backend-services/tests/test_tools_cors_checker_edge_cases.py @@ -1,8 +1,10 @@ import pytest + async def _allow_tools(client): await client.put('/platform/user/admin', json={'manage_security': True}) + @pytest.mark.asyncio async def test_tools_cors_checker_allows_when_method_and_headers_match(monkeypatch, authed_client): await _allow_tools(authed_client) @@ -10,12 +12,18 @@ async def test_tools_cors_checker_allows_when_method_and_headers_match(monkeypat monkeypatch.setenv('ALLOW_METHODS', 'GET,POST') monkeypatch.setenv('ALLOW_HEADERS', 'Content-Type,X-CSRF-Token') monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') - body = {'origin': 'http://ok.example', 'method': 'GET', 'request_headers': ['X-CSRF-Token'], 'with_credentials': True} + body = { + 'origin': 'http://ok.example', + 'method': 'GET', + 'request_headers': ['X-CSRF-Token'], + 'with_credentials': True, + } r = await authed_client.post('/platform/tools/cors/check', json=body) assert r.status_code == 200 data = r.json() assert data.get('preflight', {}).get('allowed') is True + @pytest.mark.asyncio async def test_tools_cors_checker_denies_when_method_not_allowed(monkeypatch, authed_client): await _allow_tools(authed_client) @@ -28,6 +36,7 @@ async def test_tools_cors_checker_denies_when_method_not_allowed(monkeypatch, au assert data.get('preflight', {}).get('allowed') is False assert data.get('preflight', {}).get('method_allowed') is False + @pytest.mark.asyncio async def test_tools_cors_checker_denies_when_headers_not_allowed(monkeypatch, authed_client): await _allow_tools(authed_client) @@ -41,6 +50,7 @@ async def test_tools_cors_checker_denies_when_headers_not_allowed(monkeypatch, a assert data.get('preflight', {}).get('allowed') is False assert 'X-CSRF-Token' in (data.get('preflight', {}).get('not_allowed_headers') or []) + @pytest.mark.asyncio async def test_tools_cors_checker_credentials_and_wildcard_interaction(monkeypatch, authed_client): await _allow_tools(authed_client) @@ -53,4 +63,3 @@ async def test_tools_cors_checker_credentials_and_wildcard_interaction(monkeypat data = r.json() assert data.get('preflight', {}).get('allow_origin') is True assert any('Wildcard origins' in n for n in data.get('notes') or []) - diff --git a/backend-services/tests/test_tools_cors_checker_env_edges.py b/backend-services/tests/test_tools_cors_checker_env_edges.py new file mode 100644 index 0000000..a8ac49e --- /dev/null +++ b/backend-services/tests/test_tools_cors_checker_env_edges.py @@ -0,0 +1,51 @@ +import pytest + + +@pytest.mark.asyncio +async def test_tools_cors_checker_strict_with_credentials_blocks_wildcard( + monkeypatch, authed_client +): + monkeypatch.setenv('ALLOWED_ORIGINS', '*') + monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') + monkeypatch.setenv('CORS_STRICT', 'true') + + r = await authed_client.post( + '/platform/tools/cors/check', + json={ + 'origin': 'http://evil.example', + 'with_credentials': True, + 'method': 'GET', + 'request_headers': ['Content-Type'], + }, + ) + assert r.status_code == 200 + body = r.json().get('response', r.json()) + pre = body.get('preflight', {}) + hdrs = pre.get('response_headers', {}) + assert hdrs.get('Access-Control-Allow-Origin') in (None, '') + assert hdrs.get('Access-Control-Allow-Credentials') == 'true' + + +@pytest.mark.asyncio +async def test_tools_cors_checker_non_strict_wildcard_with_credentials_allows( + monkeypatch, authed_client +): + monkeypatch.setenv('ALLOWED_ORIGINS', '*') + monkeypatch.setenv('ALLOW_CREDENTIALS', 'true') + monkeypatch.setenv('CORS_STRICT', 'false') + + r = await authed_client.post( + '/platform/tools/cors/check', + json={ + 'origin': 'http://ok.example', + 'with_credentials': True, + 'method': 'GET', + 'request_headers': ['Content-Type'], + }, + ) + assert r.status_code == 200 + body = r.json().get('response', r.json()) + pre = body.get('preflight', {}) + hdrs = pre.get('response_headers', {}) + assert hdrs.get('Access-Control-Allow-Origin') == 'http://ok.example' + assert hdrs.get('Access-Control-Allow-Credentials') == 'true' diff --git a/backend-services/tests/test_tools_cors_checker_permissions.py b/backend-services/tests/test_tools_cors_checker_permissions.py new file mode 100644 index 0000000..c4ef991 --- /dev/null +++ b/backend-services/tests/test_tools_cors_checker_permissions.py @@ -0,0 +1,66 @@ +import time + +import pytest +from httpx import AsyncClient + + +async def _login(email: str, password: str) -> AsyncClient: + from doorman import doorman + + c = AsyncClient(app=doorman, base_url='http://testserver') + r = await c.post('/platform/authorization', json={'email': email, 'password': password}) + assert r.status_code == 200, r.text + body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {} + token = body.get('access_token') + if token: + c.cookies.set('access_token_cookie', token, domain='testserver', path='/') + return c + + +@pytest.mark.asyncio +async def test_tools_cors_checker_requires_manage_security(authed_client): + uname = f'sec_check_{int(time.time())}' + pwd = 'SecCheckStrongPass!!' + # No manage_security + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + client = await _login(f'{uname}@example.com', pwd) + r = await client.post( + '/platform/tools/cors/check', json={'origin': 'http://x', 'method': 'GET'} + ) + assert r.status_code == 403 + + # Grant manage_security via role and new user + role = f'sec_mgr_{int(time.time())}' + cr = await authed_client.post( + '/platform/role', json={'role_name': role, 'manage_security': True} + ) + assert cr.status_code in (200, 201) + uname2 = f'sec_check2_{int(time.time())}' + cu2 = await authed_client.post( + '/platform/user', + json={ + 'username': uname2, + 'email': f'{uname2}@example.com', + 'password': pwd, + 'role': role, + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu2.status_code in (200, 201) + client2 = await _login(f'{uname2}@example.com', pwd) + r2 = await client2.post( + '/platform/tools/cors/check', json={'origin': 'http://x', 'method': 'GET'} + ) + assert r2.status_code == 200 diff --git a/backend-services/tests/test_user_endpoints.py b/backend-services/tests/test_user_endpoints.py index 99bea55..3781e5a 100644 --- a/backend-services/tests/test_user_endpoints.py +++ b/backend-services/tests/test_user_endpoints.py @@ -1,8 +1,8 @@ import pytest + @pytest.mark.asyncio async def test_user_me_and_crud(authed_client): - me = await authed_client.get('/platform/user/me') assert me.status_code == 200 assert me.json().get('username') == 'admin' @@ -25,8 +25,7 @@ async def test_user_me_and_crud(authed_client): assert uu.status_code == 200 up = await authed_client.put( - '/platform/user/testuser1/update-password', - json={'new_password': 'ThisIsANewPwd!456'}, + '/platform/user/testuser1/update-password', json={'new_password': 'ThisIsANewPwd!456'} ) assert up.status_code == 200 diff --git a/backend-services/tests/test_user_permissions_negative.py b/backend-services/tests/test_user_permissions_negative.py index 28fcb46..0118df1 100644 --- a/backend-services/tests/test_user_permissions_negative.py +++ b/backend-services/tests/test_user_permissions_negative.py @@ -1,32 +1,30 @@ import pytest + @pytest.mark.asyncio async def test_update_other_user_denied_without_permission(authed_client): - await authed_client.post( - '/platform/role', - json={'role_name': 'user', 'role_description': 'Standard user'}, + '/platform/role', json={'role_name': 'user', 'role_description': 'Standard user'} ) cu = await authed_client.post( '/platform/user', - json={'username': 'qa_user', 'email': 'qa@doorman.dev', 'password': 'QaPass123_ValidLen!!', 'role': 'user'}, + json={ + 'username': 'qa_user', + 'email': 'qa@doorman.dev', + 'password': 'QaPass123_ValidLen!!', + 'role': 'user', + }, ) assert cu.status_code in (200, 201), cu.text r = await authed_client.put('/platform/role/admin', json={'manage_users': False}) assert r.status_code in (200, 201) - up = await authed_client.put( - '/platform/user/qa_user', - json={'email': 'qa2@doorman.dev'}, - ) + up = await authed_client.put('/platform/user/qa_user', json={'email': 'qa2@doorman.dev'}) assert up.status_code == 403 r2 = await authed_client.put('/platform/role/admin', json={'manage_users': True}) assert r2.status_code in (200, 201) - up2 = await authed_client.put( - '/platform/user/qa_user', - json={'email': 'qa3@doorman.dev'}, - ) + up2 = await authed_client.put('/platform/user/qa_user', json={'email': 'qa3@doorman.dev'}) assert up2.status_code in (200, 201) diff --git a/backend-services/tests/test_validation_audit.py b/backend-services/tests/test_validation_audit.py index 68fb945..4d37c01 100644 --- a/backend-services/tests/test_validation_audit.py +++ b/backend-services/tests/test_validation_audit.py @@ -4,6 +4,7 @@ import pytest from utils.database import endpoint_collection, endpoint_validation_collection + def _mk_endpoint(api_name: str, api_version: str, method: str, uri: str) -> dict: eid = str(uuid.uuid4()) doc = { @@ -19,6 +20,7 @@ def _mk_endpoint(api_name: str, api_version: str, method: str, uri: str) -> dict endpoint_collection.insert_one(doc) return doc + def _run_audit() -> list[str]: failures: list[str] = [] for vdoc in endpoint_validation_collection.find({'validation_enabled': True}): @@ -29,59 +31,50 @@ def _run_audit() -> list[str]: continue schema = vdoc.get('validation_schema') if not isinstance(schema, dict) or not schema: - failures.append(f'Enabled validation missing schema for endpoint {ep.get("endpoint_method")} {ep.get("api_name")}/{ep.get("api_version")} {ep.get("endpoint_uri")} (id={eid})') + failures.append( + f'Enabled validation missing schema for endpoint {ep.get("endpoint_method")} {ep.get("api_name")}/{ep.get("api_version")} {ep.get("endpoint_uri")} (id={eid})' + ) return failures + @pytest.mark.asyncio async def test_validator_activation_audit_passes(): e_rest = _mk_endpoint('customers', 'v1', 'POST', '/create') e_graphql = _mk_endpoint('graphqlsvc', 'v1', 'POST', '/graphql') e_grpc = _mk_endpoint('grpcsvc', 'v1', 'POST', '/grpc') - e_soap = _mk_endpoint('soapsvc', 'v1', 'POST', '/soap') + _mk_endpoint('soapsvc', 'v1', 'POST', '/soap') - endpoint_validation_collection.insert_one({ - 'endpoint_id': e_rest['endpoint_id'], - 'validation_enabled': True, - 'validation_schema': { - 'payload.name': { - 'required': True, - 'type': 'string', - 'min': 1 - } + endpoint_validation_collection.insert_one( + { + 'endpoint_id': e_rest['endpoint_id'], + 'validation_enabled': True, + 'validation_schema': {'payload.name': {'required': True, 'type': 'string', 'min': 1}}, } - }) - endpoint_validation_collection.insert_one({ - 'endpoint_id': e_graphql['endpoint_id'], - 'validation_enabled': True, - 'validation_schema': { - 'input.query': { - 'required': True, - 'type': 'string', - 'min': 1 - } + ) + endpoint_validation_collection.insert_one( + { + 'endpoint_id': e_graphql['endpoint_id'], + 'validation_enabled': True, + 'validation_schema': {'input.query': {'required': True, 'type': 'string', 'min': 1}}, } - }) - endpoint_validation_collection.insert_one({ - 'endpoint_id': e_grpc['endpoint_id'], - 'validation_enabled': True, - 'validation_schema': { - 'message.name': { - 'required': True, - 'type': 'string', - 'min': 1 - } + ) + endpoint_validation_collection.insert_one( + { + 'endpoint_id': e_grpc['endpoint_id'], + 'validation_enabled': True, + 'validation_schema': {'message.name': {'required': True, 'type': 'string', 'min': 1}}, } - }) + ) failures = _run_audit() assert not failures, '\n'.join(failures) + @pytest.mark.asyncio async def test_validator_activation_audit_detects_missing_schema(): e = _mk_endpoint('soapsvc2', 'v1', 'POST', '/soap') - endpoint_validation_collection.insert_one({ - 'endpoint_id': e['endpoint_id'], - 'validation_enabled': True, - }) + endpoint_validation_collection.insert_one( + {'endpoint_id': e['endpoint_id'], 'validation_enabled': True} + ) failures = _run_audit() assert failures and any('missing schema' in f for f in failures) diff --git a/backend-services/tests/test_validation_nested_and_schema_errors.py b/backend-services/tests/test_validation_nested_and_schema_errors.py index 7b432bd..638a314 100644 --- a/backend-services/tests/test_validation_nested_and_schema_errors.py +++ b/backend-services/tests/test_validation_nested_and_schema_errors.py @@ -1,9 +1,10 @@ import pytest - from tests.test_gateway_routing_limits import _FakeAsyncClient + async def _setup(client, api='vtest', ver='v1', method='POST', uri='/data'): from conftest import create_api, create_endpoint, subscribe_self + await create_api(client, api, ver) await create_endpoint(client, api, ver, method, uri) await subscribe_self(client, api, ver) @@ -12,11 +13,13 @@ async def _setup(client, api='vtest', ver='v1', method='POST', uri='/data'): eid = g.json().get('endpoint_id') or g.json().get('response', {}).get('endpoint_id') return api, ver, uri, eid + async def _apply_schema(client, eid, schema_dict): payload = {'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': schema_dict} r = await client.post('/platform/endpoint/endpoint/validation', json=payload) assert r.status_code in (200, 201, 400) + @pytest.mark.asyncio async def test_validation_nested_object_paths_valid(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='vnest', uri='/ok1') @@ -25,105 +28,142 @@ async def test_validation_nested_object_paths_valid(monkeypatch, authed_client): 'user': { 'required': True, 'type': 'object', - 'nested_schema': { - 'name': {'required': True, 'type': 'string', 'min': 2} - } + 'nested_schema': {'name': {'required': True, 'type': 'string', 'min': 2}}, } } } await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'user': {'name': 'John'}}) assert r.status_code == 200 + @pytest.mark.asyncio async def test_validation_array_items_valid(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='varr1', uri='/ok2') schema = { 'validation_schema': { - 'tags': {'required': True, 'type': 'array', 'min': 1, 'array_items': {'type': 'string', 'min': 2, 'required': True}} + 'tags': { + 'required': True, + 'type': 'array', + 'min': 1, + 'array_items': {'type': 'string', 'min': 2, 'required': True}, + } } } await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'tags': ['ab', 'cd']}) assert r.status_code == 200 + @pytest.mark.asyncio async def test_validation_array_items_invalid_type(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='varr2', uri='/bad1') schema = { 'validation_schema': { - 'tags': {'required': True, 'type': 'array', 'array_items': {'type': 'string', 'required': True}} + 'tags': { + 'required': True, + 'type': 'array', + 'array_items': {'type': 'string', 'required': True}, + } } } await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'tags': [1, 2]}) assert r.status_code == 400 + @pytest.mark.asyncio async def test_validation_required_field_missing(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='vreq', uri='/bad2') schema = {'validation_schema': {'profile.age': {'required': True, 'type': 'number'}}} await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'profile': {}}) assert r.status_code == 400 + @pytest.mark.asyncio async def test_validation_enum_restrictions(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='venum', uri='/bad3') - schema = {'validation_schema': {'status': {'required': True, 'type': 'string', 'enum': ['NEW', 'OPEN']}}} + schema = { + 'validation_schema': { + 'status': {'required': True, 'type': 'string', 'enum': ['NEW', 'OPEN']} + } + } await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) bad = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'status': 'CLOSED'}) assert bad.status_code == 400 ok = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'status': 'OPEN'}) assert ok.status_code == 200 + @pytest.mark.asyncio async def test_validation_custom_validator_success(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='vcust1', uri='/ok3') - from utils.validation_util import validation_util, ValidationError + from utils.validation_util import ValidationError, validation_util + def is_upper(value, vdef): if not isinstance(value, str) or not value.isupper(): raise ValidationError('Not upper', 'code') + validation_util.register_custom_validator('isUpper', is_upper) - schema = {'validation_schema': {'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper'}}} + schema = { + 'validation_schema': { + 'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper'} + } + } await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'code': 'ABC'}) assert r.status_code == 200 + @pytest.mark.asyncio async def test_validation_custom_validator_failure(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='vcust2', uri='/bad4') - from utils.validation_util import validation_util, ValidationError + from utils.validation_util import ValidationError, validation_util + def is_upper(value, vdef): if not isinstance(value, str) or not value.isupper(): raise ValidationError('Not upper', 'code') + validation_util.register_custom_validator('isUpper2', is_upper) - schema = {'validation_schema': {'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper2'}}} + schema = { + 'validation_schema': { + 'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper2'} + } + } await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'code': 'Abc'}) assert r.status_code == 400 + @pytest.mark.asyncio async def test_validation_invalid_field_path_raises_schema_error(monkeypatch, authed_client): api, ver, uri, eid = await _setup(authed_client, api='vbadpath', uri='/bad5') schema = {'validation_schema': {'user..name': {'required': True, 'type': 'string'}}} await _apply_schema(authed_client, eid, schema) import services.gateway_service as gs + monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient) r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'user': {'name': 'ok'}}) assert r.status_code == 400 - diff --git a/backend-services/tests/test_vault_routes_permissions.py b/backend-services/tests/test_vault_routes_permissions.py new file mode 100644 index 0000000..008cd99 --- /dev/null +++ b/backend-services/tests/test_vault_routes_permissions.py @@ -0,0 +1,36 @@ +import time + +import pytest + + +@pytest.mark.asyncio +async def test_vault_routes_permissions(authed_client): + # Limited user + uname = f'vault_limited_{int(time.time())}' + pwd = 'VaultLimitStrongPass1!!' + cu = await authed_client.post( + '/platform/user', + json={ + 'username': uname, + 'email': f'{uname}@example.com', + 'password': pwd, + 'role': 'user', + 'groups': ['ALL'], + 'ui_access': True, + }, + ) + assert cu.status_code in (200, 201) + + from httpx import AsyncClient + + from doorman import doorman + + limited = AsyncClient(app=doorman, base_url='http://testserver') + r = await limited.post( + '/platform/authorization', json={'email': f'{uname}@example.com', 'password': pwd} + ) + assert r.status_code == 200 + # Vault operations require manage_security + sk = await limited.post('/platform/vault', json={'key_name': 'k', 'value': 'v'}) + # Without manage_security and possibly without VAULT_KEY, expect 403 or 400/404/500 depending on env + assert sk.status_code in (403, 400, 404, 500) diff --git a/backend-services/utils/analytics_aggregator.py b/backend-services/utils/analytics_aggregator.py index 5e9b036..7d31883 100644 --- a/backend-services/utils/analytics_aggregator.py +++ b/backend-services/utils/analytics_aggregator.py @@ -6,144 +6,141 @@ for efficient historical queries without scanning all minute-level data. """ from __future__ import annotations + import time from collections import defaultdict, deque -from typing import Dict, List, Optional, Deque -from dataclasses import dataclass from models.analytics_models import ( - AggregationLevel, AggregatedMetrics, + AggregationLevel, + EnhancedMinuteBucket, PercentileMetrics, - EnhancedMinuteBucket ) class AnalyticsAggregator: """ Multi-level time-series aggregator. - + Aggregates minute-level buckets into: - 5-minute buckets (for 24-hour views) - Hourly buckets (for 7-day views) - Daily buckets (for 30-day+ views) - + Retention policy: - Minute-level: 24 hours - 5-minute level: 7 days - Hourly level: 30 days - Daily level: 90 days """ - + def __init__(self): - self.five_minute_buckets: Deque[AggregatedMetrics] = deque(maxlen=2016) # 7 days * 288 buckets/day - self.hourly_buckets: Deque[AggregatedMetrics] = deque(maxlen=720) # 30 days * 24 hours - self.daily_buckets: Deque[AggregatedMetrics] = deque(maxlen=90) # 90 days - + self.five_minute_buckets: deque[AggregatedMetrics] = deque( + maxlen=2016 + ) # 7 days * 288 buckets/day + self.hourly_buckets: deque[AggregatedMetrics] = deque(maxlen=720) # 30 days * 24 hours + self.daily_buckets: deque[AggregatedMetrics] = deque(maxlen=90) # 90 days + self._last_5min_aggregation = 0 self._last_hourly_aggregation = 0 self._last_daily_aggregation = 0 - + @staticmethod def _floor_timestamp(ts: int, seconds: int) -> int: """Floor timestamp to nearest interval.""" return (ts // seconds) * seconds - - def aggregate_to_5minute(self, minute_buckets: List[EnhancedMinuteBucket]) -> None: + + def aggregate_to_5minute(self, minute_buckets: list[EnhancedMinuteBucket]) -> None: """ Aggregate minute-level buckets into 5-minute buckets. - + Should be called every 5 minutes with the last 5 minutes of data. """ if not minute_buckets: return - + # Group by 5-minute intervals - five_min_groups: Dict[int, List[EnhancedMinuteBucket]] = defaultdict(list) + five_min_groups: dict[int, list[EnhancedMinuteBucket]] = defaultdict(list) for bucket in minute_buckets: five_min_start = self._floor_timestamp(bucket.start_ts, 300) # 300 seconds = 5 minutes five_min_groups[five_min_start].append(bucket) - + # Create aggregated buckets for five_min_start, buckets in five_min_groups.items(): agg = self._aggregate_buckets( buckets, start_ts=five_min_start, end_ts=five_min_start + 300, - level=AggregationLevel.FIVE_MINUTE + level=AggregationLevel.FIVE_MINUTE, ) self.five_minute_buckets.append(agg) - + self._last_5min_aggregation = int(time.time()) - - def aggregate_to_hourly(self, five_minute_buckets: Optional[List[AggregatedMetrics]] = None) -> None: + + def aggregate_to_hourly( + self, five_minute_buckets: list[AggregatedMetrics] | None = None + ) -> None: """ Aggregate 5-minute buckets into hourly buckets. - + Should be called every hour. """ # Use provided buckets or last 12 from deque (1 hour = 12 * 5-minute buckets) if five_minute_buckets is None: five_minute_buckets = list(self.five_minute_buckets)[-12:] - + if not five_minute_buckets: return - + # Group by hour - hourly_groups: Dict[int, List[AggregatedMetrics]] = defaultdict(list) + hourly_groups: dict[int, list[AggregatedMetrics]] = defaultdict(list) for bucket in five_minute_buckets: hour_start = self._floor_timestamp(bucket.start_ts, 3600) # 3600 seconds = 1 hour hourly_groups[hour_start].append(bucket) - + # Create aggregated buckets for hour_start, buckets in hourly_groups.items(): agg = self._aggregate_aggregated_buckets( - buckets, - start_ts=hour_start, - end_ts=hour_start + 3600, - level=AggregationLevel.HOUR + buckets, start_ts=hour_start, end_ts=hour_start + 3600, level=AggregationLevel.HOUR ) self.hourly_buckets.append(agg) - + self._last_hourly_aggregation = int(time.time()) - - def aggregate_to_daily(self, hourly_buckets: Optional[List[AggregatedMetrics]] = None) -> None: + + def aggregate_to_daily(self, hourly_buckets: list[AggregatedMetrics] | None = None) -> None: """ Aggregate hourly buckets into daily buckets. - + Should be called once per day. """ # Use provided buckets or last 24 from deque (1 day = 24 hourly buckets) if hourly_buckets is None: hourly_buckets = list(self.hourly_buckets)[-24:] - + if not hourly_buckets: return - + # Group by day - daily_groups: Dict[int, List[AggregatedMetrics]] = defaultdict(list) + daily_groups: dict[int, list[AggregatedMetrics]] = defaultdict(list) for bucket in hourly_buckets: day_start = self._floor_timestamp(bucket.start_ts, 86400) # 86400 seconds = 1 day daily_groups[day_start].append(bucket) - + # Create aggregated buckets for day_start, buckets in daily_groups.items(): agg = self._aggregate_aggregated_buckets( - buckets, - start_ts=day_start, - end_ts=day_start + 86400, - level=AggregationLevel.DAY + buckets, start_ts=day_start, end_ts=day_start + 86400, level=AggregationLevel.DAY ) self.daily_buckets.append(agg) - + self._last_daily_aggregation = int(time.time()) - + def _aggregate_buckets( self, - buckets: List[EnhancedMinuteBucket], + buckets: list[EnhancedMinuteBucket], start_ts: int, end_ts: int, - level: AggregationLevel + level: AggregationLevel, ) -> AggregatedMetrics: """Aggregate minute-level buckets into a single aggregated bucket.""" total_count = sum(b.count for b in buckets) @@ -151,31 +148,31 @@ class AnalyticsAggregator: total_ms = sum(b.total_ms for b in buckets) total_bytes_in = sum(b.bytes_in for b in buckets) total_bytes_out = sum(b.bytes_out for b in buckets) - + # Merge status counts - status_counts: Dict[int, int] = defaultdict(int) + status_counts: dict[int, int] = defaultdict(int) for bucket in buckets: for status, count in bucket.status_counts.items(): status_counts[status] += count - + # Merge API counts - api_counts: Dict[str, int] = defaultdict(int) + api_counts: dict[str, int] = defaultdict(int) for bucket in buckets: for api, count in bucket.api_counts.items(): api_counts[api] += count - + # Collect all latencies for percentile calculation - all_latencies: List[float] = [] + all_latencies: list[float] = [] for bucket in buckets: all_latencies.extend(list(bucket.latencies)) - + percentiles = PercentileMetrics.calculate(all_latencies) if all_latencies else None - + # Count unique users across all buckets unique_users = set() for bucket in buckets: unique_users.update(bucket.unique_users) - + return AggregatedMetrics( start_ts=start_ts, end_ts=end_ts, @@ -188,15 +185,11 @@ class AnalyticsAggregator: unique_users=len(unique_users), status_counts=dict(status_counts), api_counts=dict(api_counts), - percentiles=percentiles + percentiles=percentiles, ) - + def _aggregate_aggregated_buckets( - self, - buckets: List[AggregatedMetrics], - start_ts: int, - end_ts: int, - level: AggregationLevel + self, buckets: list[AggregatedMetrics], start_ts: int, end_ts: int, level: AggregationLevel ) -> AggregatedMetrics: """Aggregate already-aggregated buckets (5-min → hourly, hourly → daily).""" total_count = sum(b.count for b in buckets) @@ -204,26 +197,26 @@ class AnalyticsAggregator: total_ms = sum(b.total_ms for b in buckets) total_bytes_in = sum(b.bytes_in for b in buckets) total_bytes_out = sum(b.bytes_out for b in buckets) - + # Merge status counts - status_counts: Dict[int, int] = defaultdict(int) + status_counts: dict[int, int] = defaultdict(int) for bucket in buckets: for status, count in bucket.status_counts.items(): status_counts[status] += count - + # Merge API counts - api_counts: Dict[str, int] = defaultdict(int) + api_counts: dict[str, int] = defaultdict(int) for bucket in buckets: for api, count in bucket.api_counts.items(): api_counts[api] += count - + # For percentiles, we'll use weighted average (not perfect, but acceptable) # Ideally, we'd re-calculate from raw latencies, but those aren't stored in aggregated buckets weighted_percentiles = self._weighted_average_percentiles(buckets) - + # Unique users: sum (may overcount, but acceptable for aggregated data) unique_users = sum(b.unique_users for b in buckets) - + return AggregatedMetrics( start_ts=start_ts, end_ts=end_ts, @@ -236,28 +229,30 @@ class AnalyticsAggregator: unique_users=unique_users, status_counts=dict(status_counts), api_counts=dict(api_counts), - percentiles=weighted_percentiles + percentiles=weighted_percentiles, ) - - def _weighted_average_percentiles(self, buckets: List[AggregatedMetrics]) -> Optional[PercentileMetrics]: + + def _weighted_average_percentiles( + self, buckets: list[AggregatedMetrics] + ) -> PercentileMetrics | None: """Calculate weighted average of percentiles from aggregated buckets.""" if not buckets: return None - + total_count = sum(b.count for b in buckets if b.count > 0) if total_count == 0: return None - + # Weighted average of each percentile p50_sum = sum(b.percentiles.p50 * b.count for b in buckets if b.percentiles and b.count > 0) p75_sum = sum(b.percentiles.p75 * b.count for b in buckets if b.percentiles and b.count > 0) p90_sum = sum(b.percentiles.p90 * b.count for b in buckets if b.percentiles and b.count > 0) p95_sum = sum(b.percentiles.p95 * b.count for b in buckets if b.percentiles and b.count > 0) p99_sum = sum(b.percentiles.p99 * b.count for b in buckets if b.percentiles and b.count > 0) - + min_val = min(b.percentiles.min for b in buckets if b.percentiles) max_val = max(b.percentiles.max for b in buckets if b.percentiles) - + return PercentileMetrics( p50=p50_sum / total_count, p75=p75_sum / total_count, @@ -265,25 +260,22 @@ class AnalyticsAggregator: p95=p95_sum / total_count, p99=p99_sum / total_count, min=min_val, - max=max_val + max=max_val, ) - + def get_buckets_for_range( - self, - start_ts: int, - end_ts: int, - preferred_level: Optional[AggregationLevel] = None - ) -> List[AggregatedMetrics]: + self, start_ts: int, end_ts: int, preferred_level: AggregationLevel | None = None + ) -> list[AggregatedMetrics]: """ Get the most appropriate aggregation level for a time range. - + Automatically selects: - 5-minute buckets for ranges < 24 hours - Hourly buckets for ranges < 7 days - Daily buckets for ranges >= 7 days """ range_seconds = end_ts - start_ts - + # Determine best aggregation level if preferred_level: level = preferred_level @@ -293,7 +285,7 @@ class AnalyticsAggregator: level = AggregationLevel.HOUR else: level = AggregationLevel.DAY - + # Select appropriate bucket collection if level == AggregationLevel.FIVE_MINUTE: buckets = list(self.five_minute_buckets) @@ -301,21 +293,21 @@ class AnalyticsAggregator: buckets = list(self.hourly_buckets) else: buckets = list(self.daily_buckets) - + # Filter to time range return [b for b in buckets if b.start_ts >= start_ts and b.end_ts <= end_ts] - - def should_aggregate(self) -> Dict[str, bool]: + + def should_aggregate(self) -> dict[str, bool]: """Check if any aggregation jobs should run.""" now = int(time.time()) - + return { '5minute': (now - self._last_5min_aggregation) >= 300, # Every 5 minutes 'hourly': (now - self._last_hourly_aggregation) >= 3600, # Every hour - 'daily': (now - self._last_daily_aggregation) >= 86400 # Every day + 'daily': (now - self._last_daily_aggregation) >= 86400, # Every day } - - def to_dict(self) -> Dict: + + def to_dict(self) -> dict: """Serialize aggregator state for persistence.""" return { 'five_minute_buckets': [b.to_dict() for b in self.five_minute_buckets], @@ -323,10 +315,10 @@ class AnalyticsAggregator: 'daily_buckets': [b.to_dict() for b in self.daily_buckets], 'last_5min_aggregation': self._last_5min_aggregation, 'last_hourly_aggregation': self._last_hourly_aggregation, - 'last_daily_aggregation': self._last_daily_aggregation + 'last_daily_aggregation': self._last_daily_aggregation, } - - def load_dict(self, data: Dict) -> None: + + def load_dict(self, data: dict) -> None: """Load aggregator state from persistence.""" # Note: This is a simplified version. Full implementation would # reconstruct AggregatedMetrics objects from dictionaries diff --git a/backend-services/utils/analytics_scheduler.py b/backend-services/utils/analytics_scheduler.py index d5d89f9..2a4d7d1 100644 --- a/backend-services/utils/analytics_scheduler.py +++ b/backend-services/utils/analytics_scheduler.py @@ -12,10 +12,9 @@ Runs periodic tasks to: import asyncio import logging import time -from typing import Optional -from utils.enhanced_metrics_util import enhanced_metrics_store from utils.analytics_aggregator import analytics_aggregator +from utils.enhanced_metrics_util import enhanced_metrics_store logger = logging.getLogger('doorman.analytics') @@ -23,7 +22,7 @@ logger = logging.getLogger('doorman.analytics') class AnalyticsScheduler: """ Background task scheduler for analytics aggregation. - + Runs aggregation jobs at appropriate intervals: - 5-minute aggregation: Every 5 minutes - Hourly aggregation: Every hour @@ -31,31 +30,31 @@ class AnalyticsScheduler: - Persistence: Every 5 minutes - Cleanup: Once per day """ - + def __init__(self): self.running = False - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None self._last_5min = 0 self._last_hourly = 0 self._last_daily = 0 self._last_persist = 0 self._last_cleanup = 0 - + async def start(self): """Start the scheduler.""" if self.running: - logger.warning("Analytics scheduler already running") + logger.warning('Analytics scheduler already running') return - + self.running = True self._task = asyncio.create_task(self._run_loop()) - logger.info("Analytics scheduler started") - + logger.info('Analytics scheduler started') + async def stop(self): """Stop the scheduler.""" if not self.running: return - + self.running = False if self._task: self._task.cancel() @@ -63,9 +62,9 @@ class AnalyticsScheduler: await self._task except asyncio.CancelledError: pass - - logger.info("Analytics scheduler stopped") - + + logger.info('Analytics scheduler stopped') + async def _run_loop(self): """Main scheduler loop.""" while self.running: @@ -76,136 +75,139 @@ class AnalyticsScheduler: except asyncio.CancelledError: break except Exception as e: - logger.error(f"Error in analytics scheduler: {str(e)}", exc_info=True) + logger.error(f'Error in analytics scheduler: {str(e)}', exc_info=True) await asyncio.sleep(60) - + async def _check_and_run_jobs(self): """Check if any jobs should run and execute them.""" now = int(time.time()) - + # 5-minute aggregation (every 5 minutes) if now - self._last_5min >= 300: await self._run_5minute_aggregation() self._last_5min = now - + # Hourly aggregation (every hour) if now - self._last_hourly >= 3600: await self._run_hourly_aggregation() self._last_hourly = now - + # Daily aggregation (once per day) if now - self._last_daily >= 86400: await self._run_daily_aggregation() self._last_daily = now - + # Persist metrics (every 5 minutes) if now - self._last_persist >= 300: await self._persist_metrics() self._last_persist = now - + # Cleanup old data (once per day) if now - self._last_cleanup >= 86400: await self._cleanup_old_data() self._last_cleanup = now - + async def _run_5minute_aggregation(self): """Aggregate last 5 minutes of data into 5-minute buckets.""" try: - logger.info("Running 5-minute aggregation") + logger.info('Running 5-minute aggregation') start_time = time.time() - + # Get last 5 minutes of buckets minute_buckets = list(enhanced_metrics_store._buckets)[-5:] - + if minute_buckets: analytics_aggregator.aggregate_to_5minute(minute_buckets) - + duration_ms = (time.time() - start_time) * 1000 - logger.info(f"5-minute aggregation completed in {duration_ms:.2f}ms") + logger.info(f'5-minute aggregation completed in {duration_ms:.2f}ms') else: - logger.debug("No buckets to aggregate (5-minute)") - + logger.debug('No buckets to aggregate (5-minute)') + except Exception as e: - logger.error(f"Failed to run 5-minute aggregation: {str(e)}", exc_info=True) - + logger.error(f'Failed to run 5-minute aggregation: {str(e)}', exc_info=True) + async def _run_hourly_aggregation(self): """Aggregate last hour of 5-minute data into hourly buckets.""" try: - logger.info("Running hourly aggregation") + logger.info('Running hourly aggregation') start_time = time.time() - + analytics_aggregator.aggregate_to_hourly() - + duration_ms = (time.time() - start_time) * 1000 - logger.info(f"Hourly aggregation completed in {duration_ms:.2f}ms") - + logger.info(f'Hourly aggregation completed in {duration_ms:.2f}ms') + except Exception as e: - logger.error(f"Failed to run hourly aggregation: {str(e)}", exc_info=True) - + logger.error(f'Failed to run hourly aggregation: {str(e)}', exc_info=True) + async def _run_daily_aggregation(self): """Aggregate last day of hourly data into daily buckets.""" try: - logger.info("Running daily aggregation") + logger.info('Running daily aggregation') start_time = time.time() - + analytics_aggregator.aggregate_to_daily() - + duration_ms = (time.time() - start_time) * 1000 - logger.info(f"Daily aggregation completed in {duration_ms:.2f}ms") - + logger.info(f'Daily aggregation completed in {duration_ms:.2f}ms') + except Exception as e: - logger.error(f"Failed to run daily aggregation: {str(e)}", exc_info=True) - + logger.error(f'Failed to run daily aggregation: {str(e)}', exc_info=True) + async def _persist_metrics(self): """Save metrics to disk for persistence.""" try: - logger.debug("Persisting metrics to disk") + logger.debug('Persisting metrics to disk') start_time = time.time() - + # Save minute-level metrics enhanced_metrics_store.save_to_file('platform-logs/enhanced_metrics.json') - + # Save aggregated metrics import json import os - + aggregated_data = analytics_aggregator.to_dict() path = 'platform-logs/aggregated_metrics.json' - + os.makedirs(os.path.dirname(path), exist_ok=True) tmp = path + '.tmp' with open(tmp, 'w', encoding='utf-8') as f: json.dump(aggregated_data, f) os.replace(tmp, path) - + duration_ms = (time.time() - start_time) * 1000 - logger.debug(f"Metrics persisted in {duration_ms:.2f}ms") - + logger.debug(f'Metrics persisted in {duration_ms:.2f}ms') + except Exception as e: - logger.error(f"Failed to persist metrics: {str(e)}", exc_info=True) - + logger.error(f'Failed to persist metrics: {str(e)}', exc_info=True) + async def _cleanup_old_data(self): """Remove data beyond retention policy.""" try: - logger.info("Running data cleanup") + logger.info('Running data cleanup') start_time = time.time() - + now = int(time.time()) - + # Minute-level: Keep only last 24 hours cutoff_minute = now - (24 * 3600) - while enhanced_metrics_store._buckets and enhanced_metrics_store._buckets[0].start_ts < cutoff_minute: + while ( + enhanced_metrics_store._buckets + and enhanced_metrics_store._buckets[0].start_ts < cutoff_minute + ): enhanced_metrics_store._buckets.popleft() - + # 5-minute: Keep only last 7 days (handled by deque maxlen) # Hourly: Keep only last 30 days (handled by deque maxlen) # Daily: Keep only last 90 days (handled by deque maxlen) - + duration_ms = (time.time() - start_time) * 1000 - logger.info(f"Data cleanup completed in {duration_ms:.2f}ms") - + logger.info(f'Data cleanup completed in {duration_ms:.2f}ms') + except Exception as e: - logger.error(f"Failed to cleanup old data: {str(e)}", exc_info=True) + logger.error(f'Failed to cleanup old data: {str(e)}', exc_info=True) # Global scheduler instance diff --git a/backend-services/utils/api_resolution_util.py b/backend-services/utils/api_resolution_util.py index 562852d..abe31be 100644 --- a/backend-services/utils/api_resolution_util.py +++ b/backend-services/utils/api_resolution_util.py @@ -5,12 +5,14 @@ Reduces duplicate code for GraphQL/gRPC API name/version parsing. """ import re -from typing import Tuple, Optional -from fastapi import Request, HTTPException -from utils.doorman_cache_util import doorman_cache -from utils import api_util -def parse_graphql_grpc_path(path: str, request: Request) -> Tuple[str, str, str]: +from fastapi import HTTPException, Request + +from utils import api_util +from utils.doorman_cache_util import doorman_cache + + +def parse_graphql_grpc_path(path: str, request: Request) -> tuple[str, str, str]: """Parse GraphQL/gRPC path to extract API name and version. Args: @@ -38,7 +40,8 @@ def parse_graphql_grpc_path(path: str, request: Request) -> Tuple[str, str, str] return api_name, api_version, api_path -async def resolve_api(api_name: str, api_version: str) -> Optional[dict]: + +async def resolve_api(api_name: str, api_version: str) -> dict | None: """Resolve API from cache or database. Args: @@ -52,7 +55,10 @@ async def resolve_api(api_name: str, api_version: str) -> Optional[dict]: api_key = doorman_cache.get_cache('api_id_cache', api_path) return await api_util.get_api(api_key, api_path) -async def resolve_api_from_request(path: str, request: Request) -> Tuple[Optional[dict], str, str, str]: + +async def resolve_api_from_request( + path: str, request: Request +) -> tuple[dict | None, str, str, str]: """Parse path, extract API name/version, and resolve API in one call. Args: diff --git a/backend-services/utils/api_util.py b/backend-services/utils/api_util.py index aab2c7a..5f55f25 100644 --- a/backend-services/utils/api_util.py +++ b/backend-services/utils/api_util.py @@ -1,10 +1,9 @@ -from typing import Optional, Dict - -from utils.doorman_cache_util import doorman_cache +from utils.async_db import db_find_list, db_find_one from utils.database_async import api_collection, endpoint_collection -from utils.async_db import db_find_one, db_find_list +from utils.doorman_cache_util import doorman_cache -async def get_api(api_key: Optional[str], api_name_version: str) -> Optional[Dict]: + +async def get_api(api_key: str | None, api_name_version: str) -> dict | None: """Get API document by key or name/version. Args: @@ -25,7 +24,8 @@ async def get_api(api_key: Optional[str], api_name_version: str) -> Optional[Dic doorman_cache.set_cache('api_id_cache', api_name_version, api_key) return api -async def get_api_endpoints(api_id: str) -> Optional[list]: + +async def get_api_endpoints(api_id: str) -> list | None: """Get list of endpoints for an API. Args: @@ -40,13 +40,14 @@ async def get_api_endpoints(api_id: str) -> Optional[list]: if not endpoints_list: return None endpoints = [ - f"{endpoint.get('endpoint_method')}{endpoint.get('endpoint_uri')}" + f'{endpoint.get("endpoint_method")}{endpoint.get("endpoint_uri")}' for endpoint in endpoints_list ] doorman_cache.set_cache('api_endpoint_cache', api_id, endpoints) return endpoints -async def get_endpoint(api: Dict, method: str, endpoint_uri: str) -> Optional[Dict]: + +async def get_endpoint(api: dict, method: str, endpoint_uri: str) -> dict | None: """Return the endpoint document for a given API, method, and uri. Uses the same cache key pattern as EndpointService to avoid duplicate queries. @@ -57,12 +58,15 @@ async def get_endpoint(api: Dict, method: str, endpoint_uri: str) -> Optional[Di endpoint = doorman_cache.get_cache('endpoint_cache', cache_key) if endpoint: return endpoint - doc = await db_find_one(endpoint_collection, { - 'api_name': api_name, - 'api_version': api_version, - 'endpoint_uri': endpoint_uri, - 'endpoint_method': method - }) + doc = await db_find_one( + endpoint_collection, + { + 'api_name': api_name, + 'api_version': api_version, + 'endpoint_uri': endpoint_uri, + 'endpoint_method': method, + }, + ) if not doc: return None doc.pop('_id', None) diff --git a/backend-services/utils/async_db.py b/backend-services/utils/async_db.py index b4f40f0..835cb3a 100644 --- a/backend-services/utils/async_db.py +++ b/backend-services/utils/async_db.py @@ -9,34 +9,39 @@ from __future__ import annotations import asyncio import inspect -from typing import Any, Dict, List, Optional +from typing import Any -async def db_find_one(collection: Any, query: Dict[str, Any]) -> Optional[Dict[str, Any]]: - fn = getattr(collection, 'find_one') + +async def db_find_one(collection: Any, query: dict[str, Any]) -> dict[str, Any] | None: + fn = collection.find_one if inspect.iscoroutinefunction(fn): return await fn(query) return await asyncio.to_thread(fn, query) -async def db_insert_one(collection: Any, doc: Dict[str, Any]) -> Any: - fn = getattr(collection, 'insert_one') + +async def db_insert_one(collection: Any, doc: dict[str, Any]) -> Any: + fn = collection.insert_one if inspect.iscoroutinefunction(fn): return await fn(doc) return await asyncio.to_thread(fn, doc) -async def db_update_one(collection: Any, query: Dict[str, Any], update: Dict[str, Any]) -> Any: - fn = getattr(collection, 'update_one') + +async def db_update_one(collection: Any, query: dict[str, Any], update: dict[str, Any]) -> Any: + fn = collection.update_one if inspect.iscoroutinefunction(fn): return await fn(query, update) return await asyncio.to_thread(fn, query, update) -async def db_delete_one(collection: Any, query: Dict[str, Any]) -> Any: - fn = getattr(collection, 'delete_one') + +async def db_delete_one(collection: Any, query: dict[str, Any]) -> Any: + fn = collection.delete_one if inspect.iscoroutinefunction(fn): return await fn(query) return await asyncio.to_thread(fn, query) -async def db_find_list(collection: Any, query: Dict[str, Any]) -> List[Dict[str, Any]]: - find = getattr(collection, 'find') + +async def db_find_list(collection: Any, query: dict[str, Any]) -> list[dict[str, Any]]: + find = collection.find cursor = find(query) to_list = getattr(cursor, 'to_list', None) if callable(to_list): @@ -44,4 +49,3 @@ async def db_find_list(collection: Any, query: Dict[str, Any]) -> List[Dict[str, return await to_list(length=None) return await asyncio.to_thread(to_list, None) return await asyncio.to_thread(lambda: list(cursor)) - diff --git a/backend-services/utils/audit_util.py b/backend-services/utils/audit_util.py index be1155b..b721bf8 100644 --- a/backend-services/utils/audit_util.py +++ b/backend-services/utils/audit_util.py @@ -1,38 +1,72 @@ -import logging import json +import logging import re _logger = logging.getLogger('doorman.audit') SENSITIVE_KEYS = { - 'password', 'passwd', 'pwd', - 'token', 'access_token', 'refresh_token', 'bearer_token', 'auth_token', - 'authorization', 'auth', 'bearer', - - 'api_key', 'apikey', 'api-key', - 'user_api_key', 'user-api-key', - 'secret', 'client_secret', 'client-secret', 'api_secret', 'api-secret', - 'private_key', 'private-key', 'privatekey', - - 'session', 'session_id', 'session-id', 'sessionid', - 'csrf_token', 'csrf-token', 'csrftoken', - 'x-csrf-token', 'xsrf_token', 'xsrf-token', - - 'cookie', 'set-cookie', 'set_cookie', - 'access_token_cookie', 'refresh_token_cookie', - - 'connection_string', 'connection-string', 'connectionstring', - 'database_password', 'db_password', 'db_passwd', - 'mongo_password', 'redis_password', - - 'id_token', 'id-token', - 'jwt', 'jwt_token', - 'oauth_token', 'oauth-token', - 'code_verifier', 'code-verifier', - - 'encryption_key', 'encryption-key', - 'signing_key', 'signing-key', - 'key', 'private', 'secret_key', + 'password', + 'passwd', + 'pwd', + 'token', + 'access_token', + 'refresh_token', + 'bearer_token', + 'auth_token', + 'authorization', + 'auth', + 'bearer', + 'api_key', + 'apikey', + 'api-key', + 'user_api_key', + 'user-api-key', + 'secret', + 'client_secret', + 'client-secret', + 'api_secret', + 'api-secret', + 'private_key', + 'private-key', + 'privatekey', + 'session', + 'session_id', + 'session-id', + 'sessionid', + 'csrf_token', + 'csrf-token', + 'csrftoken', + 'x-csrf-token', + 'xsrf_token', + 'xsrf-token', + 'cookie', + 'set-cookie', + 'set_cookie', + 'access_token_cookie', + 'refresh_token_cookie', + 'connection_string', + 'connection-string', + 'connectionstring', + 'database_password', + 'db_password', + 'db_passwd', + 'mongo_password', + 'redis_password', + 'id_token', + 'id-token', + 'jwt', + 'jwt_token', + 'oauth_token', + 'oauth-token', + 'code_verifier', + 'code-verifier', + 'encryption_key', + 'encryption-key', + 'signing_key', + 'signing-key', + 'key', + 'private', + 'secret_key', } SENSITIVE_VALUE_PATTERNS = [ @@ -44,14 +78,18 @@ SENSITIVE_VALUE_PATTERNS = [ re.compile(r'^-----BEGIN[A-Z\s]+PRIVATE KEY-----', re.DOTALL), ] + def _is_sensitive_key(key: str) -> bool: """Check if a key name indicates sensitive data.""" try: lk = str(key).lower().replace('-', '_') - return lk in SENSITIVE_KEYS or any(s in lk for s in ['password', 'secret', 'token', 'key', 'auth']) + return lk in SENSITIVE_KEYS or any( + s in lk for s in ['password', 'secret', 'token', 'key', 'auth'] + ) except Exception: return False + def _is_sensitive_value(value) -> bool: """Check if a value looks like sensitive data (even if key isn't obviously sensitive).""" try: @@ -61,6 +99,7 @@ def _is_sensitive_value(value) -> bool: except Exception: return False + def _sanitize(obj): """Recursively sanitize objects to redact sensitive data. @@ -88,7 +127,10 @@ def _sanitize(obj): except Exception: return None -def audit(request=None, actor=None, action=None, target=None, status=None, details=None, request_id=None): + +def audit( + request=None, actor=None, action=None, target=None, status=None, details=None, request_id=None +): event = { 'actor': actor, 'action': action, @@ -104,6 +146,4 @@ def audit(request=None, actor=None, action=None, target=None, status=None, detai event['request_id'] = request_id _logger.info(json.dumps(event, separators=(',', ':'))) except Exception: - pass - diff --git a/backend-services/utils/auth_blacklist.py b/backend-services/utils/auth_blacklist.py index 2e309ab..82c1969 100644 --- a/backend-services/utils/auth_blacklist.py +++ b/backend-services/utils/auth_blacklist.py @@ -44,11 +44,10 @@ see limit_throttle_util.py which uses the async Redis client (app.state.redis). - doorman.py app_lifespan() for production Redis requirement enforcement """ -from datetime import datetime, timedelta import heapq import os -from typing import Optional import time +from datetime import datetime, timedelta try: from utils.database import database, revocations_collection @@ -67,6 +66,7 @@ revoked_all_users = set() _redis_client = None _redis_enabled = False + def _init_redis_if_possible(): global _redis_client, _redis_enabled if _redis_client is not None: @@ -84,7 +84,9 @@ def _init_redis_if_possible(): host = os.getenv('REDIS_HOST', 'localhost') port = int(os.getenv('REDIS_PORT', 6379)) db = int(os.getenv('REDIS_DB', 0)) - pool = redis.ConnectionPool(host=host, port=port, db=db, decode_responses=True, max_connections=100) + pool = redis.ConnectionPool( + host=host, port=port, db=db, decode_responses=True, max_connections=100 + ) _redis_client = redis.StrictRedis(connection_pool=pool) try: _redis_client.ping() @@ -96,23 +98,36 @@ def _init_redis_if_possible(): _redis_client = None _redis_enabled = False + def _revoked_jti_key(username: str, jti: str) -> str: return f'jwt:revoked:{username}:{jti}' + def _revoke_all_key(username: str) -> str: return f'jwt:revoke_all:{username}' + def revoke_all_for_user(username: str): """Mark all tokens for a user as revoked (durable if Redis is enabled).""" _init_redis_if_possible() try: - if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None: + if ( + database is not None + and getattr(database, 'memory_only', False) + and revocations_collection is not None + ): try: - existing = revocations_collection.find_one({'type': 'revoke_all', 'username': username}) + existing = revocations_collection.find_one( + {'type': 'revoke_all', 'username': username} + ) if existing: - revocations_collection.update_one({'_id': existing.get('_id')}, {'$set': {'revoke_all': True}}) + revocations_collection.update_one( + {'_id': existing.get('_id')}, {'$set': {'revoke_all': True}} + ) else: - revocations_collection.insert_one({'type': 'revoke_all', 'username': username, 'revoke_all': True}) + revocations_collection.insert_one( + {'type': 'revoke_all', 'username': username, 'revoke_all': True} + ) except Exception: revoked_all_users.add(username) return @@ -123,11 +138,16 @@ def revoke_all_for_user(username: str): except Exception: revoked_all_users.add(username) + def unrevoke_all_for_user(username: str): """Clear 'revoke all' for a user (durable if Redis is enabled).""" _init_redis_if_possible() try: - if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None: + if ( + database is not None + and getattr(database, 'memory_only', False) + and revocations_collection is not None + ): try: revocations_collection.delete_one({'type': 'revoke_all', 'username': username}) except Exception: @@ -140,11 +160,16 @@ def unrevoke_all_for_user(username: str): except Exception: revoked_all_users.discard(username) + def is_user_revoked(username: str) -> bool: """Return True if user is under 'revoke all' (durable check if Redis enabled).""" _init_redis_if_possible() try: - if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None: + if ( + database is not None + and getattr(database, 'memory_only', False) + and revocations_collection is not None + ): try: doc = revocations_collection.find_one({'type': 'revoke_all', 'username': username}) return bool(doc and doc.get('revoke_all')) @@ -156,6 +181,7 @@ def is_user_revoked(username: str) -> bool: except Exception: return username in revoked_all_users + class TimedHeap: def __init__(self, purge_after=timedelta(hours=1)): self.heap = [] @@ -182,7 +208,8 @@ class TimedHeap: return self.heap[0][1] return None -def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None): + +def add_revoked_jti(username: str, jti: str, ttl_seconds: int | None = None): """Add a specific JTI to the revocation list. - If Redis is enabled, store key with TTL so it auto-expires. @@ -192,14 +219,26 @@ def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None): return _init_redis_if_possible() try: - if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None: + if ( + database is not None + and getattr(database, 'memory_only', False) + and revocations_collection is not None + ): try: - exp = int(time.time()) + (max(1, int(ttl_seconds)) if ttl_seconds is not None else 3600) - existing = revocations_collection.find_one({'type': 'jti', 'username': username, 'jti': jti}) + exp = int(time.time()) + ( + max(1, int(ttl_seconds)) if ttl_seconds is not None else 3600 + ) + existing = revocations_collection.find_one( + {'type': 'jti', 'username': username, 'jti': jti} + ) if existing: - revocations_collection.update_one({'_id': existing.get('_id')}, {'$set': {'expires_at': exp}}) + revocations_collection.update_one( + {'_id': existing.get('_id')}, {'$set': {'expires_at': exp}} + ) else: - revocations_collection.insert_one({'type': 'jti', 'username': username, 'jti': jti, 'expires_at': exp}) + revocations_collection.insert_one( + {'type': 'jti', 'username': username, 'jti': jti, 'expires_at': exp} + ) return except Exception: pass @@ -215,15 +254,22 @@ def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None): jwt_blacklist[username] = th th.push(jti) + def is_jti_revoked(username: str, jti: str) -> bool: """Check whether a specific JTI is revoked (durable if Redis enabled).""" if not username or not jti: return False _init_redis_if_possible() try: - if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None: + if ( + database is not None + and getattr(database, 'memory_only', False) + and revocations_collection is not None + ): try: - doc = revocations_collection.find_one({'type': 'jti', 'username': username, 'jti': jti}) + doc = revocations_collection.find_one( + {'type': 'jti', 'username': username, 'jti': jti} + ) if not doc: pass else: @@ -248,13 +294,18 @@ def is_jti_revoked(username: str, jti: str) -> bool: return True return False + async def purge_expired_tokens(): """No-op when Redis-backed; purge DB/in-memory when memory-only.""" _init_redis_if_possible() if _redis_enabled: return try: - if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None: + if ( + database is not None + and getattr(database, 'memory_only', False) + and revocations_collection is not None + ): now = int(time.time()) to_delete = [] for d in revocations_collection.find({'type': 'jti'}): diff --git a/backend-services/utils/auth_util.py b/backend-services/utils/auth_util.py index 0bfc06a..8fd8ac6 100644 --- a/backend-services/utils/auth_util.py +++ b/backend-services/utils/auth_util.py @@ -7,32 +7,32 @@ See https://github.com/pypeople-dev/doorman for more information from datetime import datetime, timedelta try: - from datetime import UTC except Exception: - from datetime import timezone as _timezone - UTC = _timezone.utc + UTC = UTC +import asyncio +import logging import os import uuid + from fastapi import HTTPException, Request -from jose import jwt, JWTError -import asyncio +from jose import JWTError, jwt -from utils.auth_blacklist import is_user_revoked, is_jti_revoked -from utils.database import user_collection, role_collection +from utils.auth_blacklist import is_jti_revoked, is_user_revoked +from utils.database import role_collection, user_collection from utils.doorman_cache_util import doorman_cache -import logging - logger = logging.getLogger('doorman.gateway') SECRET_KEY = os.getenv('JWT_SECRET_KEY') ALGORITHM = 'HS256' + def is_jwt_configured() -> bool: """Return True if a JWT secret key is configured.""" return bool(os.getenv('JWT_SECRET_KEY')) + def _read_int_env(name: str, default: int) -> int: try: raw = os.getenv(name) @@ -47,22 +47,39 @@ def _read_int_env(name: str, default: int) -> int: logger.warning(f'Invalid value for {name}; using default {default}') return default + def _normalize_unit(unit: str) -> str: u = (unit or '').strip().lower() mapping = { - 's': 'seconds', 'sec': 'seconds', 'second': 'seconds', 'seconds': 'seconds', - 'm': 'minutes', 'min': 'minutes', 'minute': 'minutes', 'minutes': 'minutes', - 'h': 'hours', 'hr': 'hours', 'hour': 'hours', 'hours': 'hours', - 'd': 'days', 'day': 'days', 'days': 'days', - 'w': 'weeks', 'wk': 'weeks', 'week': 'weeks', 'weeks': 'weeks', + 's': 'seconds', + 'sec': 'seconds', + 'second': 'seconds', + 'seconds': 'seconds', + 'm': 'minutes', + 'min': 'minutes', + 'minute': 'minutes', + 'minutes': 'minutes', + 'h': 'hours', + 'hr': 'hours', + 'hour': 'hours', + 'hours': 'hours', + 'd': 'days', + 'day': 'days', + 'days': 'days', + 'w': 'weeks', + 'wk': 'weeks', + 'week': 'weeks', + 'weeks': 'weeks', } return mapping.get(u, 'minutes') -def _expiry_from_env(value_key: str, unit_key: str, default_value: int, default_unit: str) -> timedelta: + +def _expiry_from_env( + value_key: str, unit_key: str, default_value: int, default_unit: str +) -> timedelta: value = _read_int_env(value_key, default_value) unit = _normalize_unit(os.getenv(unit_key, default_unit)) try: - return timedelta(**{unit: value}) except Exception: logger.warning( @@ -70,6 +87,7 @@ def _expiry_from_env(value_key: str, unit_key: str, default_value: int, default_ ) return timedelta(**{_normalize_unit(default_unit): default_value}) + async def validate_csrf_double_submit(header_token: str, cookie_token: str) -> bool: try: if not header_token or not cookie_token: @@ -78,6 +96,7 @@ async def validate_csrf_double_submit(header_token: str, cookie_token: str) -> b except Exception: return False + async def auth_required(request: Request) -> dict: """Validate JWT token and CSRF for HTTPS @@ -96,10 +115,7 @@ async def auth_required(request: Request) -> dict: raise HTTPException(status_code=401, detail='Invalid CSRF token') try: payload = jwt.decode( - token, - SECRET_KEY, - algorithms=[ALGORITHM], - options={'verify_signature': True} + token, SECRET_KEY, algorithms=[ALGORITHM], options={'verify_signature': True} ) username = payload.get('sub') jti = payload.get('jti') @@ -112,8 +128,10 @@ async def auth_required(request: Request) -> dict: user = await asyncio.to_thread(user_collection.find_one, {'username': username}) if not user: raise HTTPException(status_code=404, detail='User not found') - if user.get('_id'): del user['_id'] - if user.get('password'): del user['password'] + if user.get('_id'): + del user['_id'] + if user.get('password'): + del user['password'] doorman_cache.set_cache('user_cache', username, user) if not user: raise HTTPException(status_code=404, detail='User not found') @@ -127,6 +145,7 @@ async def auth_required(request: Request) -> dict: logger.error(f'Unexpected error in auth_required: {str(e)}') raise HTTPException(status_code=401, detail='Unauthorized') + def create_access_token(data: dict, refresh: bool = False) -> str: """Create a JWT access token with user permissions. @@ -153,8 +172,10 @@ def create_access_token(data: dict, refresh: bool = False) -> str: if not user: user = user_collection.find_one({'username': username}) if user: - if user.get('_id'): del user['_id'] - if user.get('password'): del user['password'] + if user.get('_id'): + del user['_id'] + if user.get('password'): + del user['password'] doorman_cache.set_cache('user_cache', username, user) if not user: @@ -168,7 +189,8 @@ def create_access_token(data: dict, refresh: bool = False) -> str: if not role: role = role_collection.find_one({'role_name': role_name}) if role: - if role.get('_id'): del role['_id'] + if role.get('_id'): + del role['_id'] doorman_cache.set_cache('role_cache', role_name, role) accesses = { @@ -186,11 +208,9 @@ def create_access_token(data: dict, refresh: bool = False) -> str: 'view_logs': role.get('view_logs', False) if role else False, } - to_encode.update({ - 'exp': datetime.now(UTC) + expire, - 'jti': str(uuid.uuid4()), - 'accesses': accesses - }) + to_encode.update( + {'exp': datetime.now(UTC) + expire, 'jti': str(uuid.uuid4()), 'accesses': accesses} + ) logger.info(f'Creating token for user {username} with accesses: {accesses}') encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) diff --git a/backend-services/utils/bandwidth_util.py b/backend-services/utils/bandwidth_util.py index 4ce563c..accd6c0 100644 --- a/backend-services/utils/bandwidth_util.py +++ b/backend-services/utils/bandwidth_util.py @@ -1,13 +1,14 @@ from __future__ import annotations -from fastapi import Request, HTTPException import time -from typing import Optional -from utils.doorman_cache_util import doorman_cache +from fastapi import HTTPException, Request + from utils.database import user_collection +from utils.doorman_cache_util import doorman_cache -def _window_to_seconds(win: Optional[str]) -> int: + +def _window_to_seconds(win: str | None) -> int: mapping = { 'second': 1, 'minute': 60, @@ -21,14 +22,16 @@ def _window_to_seconds(win: Optional[str]) -> int: w = win.lower().rstrip('s') return mapping.get(w, 86400) -def _bucket_key(username: str, window: str, now: Optional[int] = None) -> tuple[str, int]: + +def _bucket_key(username: str, window: str, now: int | None = None) -> tuple[str, int]: sec = _window_to_seconds(window) now = now or int(time.time()) bucket = (now // sec) * sec key = f'bandwidth_usage:{username}:{sec}:{bucket}' return key, sec -def _get_user(username: str) -> Optional[dict]: + +def _get_user(username: str) -> dict | None: user = doorman_cache.get_cache('user_cache', username) if not user: user = user_collection.find_one({'username': username}) @@ -36,10 +39,12 @@ def _get_user(username: str) -> Optional[dict]: del user['_id'] return user + def _get_client(): return doorman_cache.cache if getattr(doorman_cache, 'is_redis', False) else None -def get_current_usage(username: str, window: Optional[str]) -> int: + +def get_current_usage(username: str, window: str | None) -> int: win = window or 'day' key, ttl = _bucket_key(username, win) client = _get_client() @@ -55,7 +60,8 @@ def get_current_usage(username: str, window: Optional[str]) -> int: except Exception: return 0 -def add_usage(username: str, delta_bytes: int, window: Optional[str]) -> None: + +def add_usage(username: str, delta_bytes: int, window: str | None) -> None: if not delta_bytes: return win = window or 'day' @@ -75,7 +81,8 @@ def add_usage(username: str, delta_bytes: int, window: Optional[str]) -> None: except Exception: pass -async def enforce_pre_request_limit(request: Request, username: Optional[str]) -> None: + +async def enforce_pre_request_limit(request: Request, username: str | None) -> None: if not username: return user = _get_user(username) diff --git a/backend-services/utils/cache_manager_util.py b/backend-services/utils/cache_manager_util.py index 1e256c6..155216a 100644 --- a/backend-services/utils/cache_manager_util.py +++ b/backend-services/utils/cache_manager_util.py @@ -4,17 +4,18 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from fastapi import FastAPI +import os + from aiocache import Cache, caches from aiocache.decorators import cached from dotenv import load_dotenv -import os +from fastapi import FastAPI load_dotenv() + class CacheManager: def __init__(self): - redis_host = os.getenv('REDIS_HOST') redis_port = os.getenv('REDIS_PORT') redis_db = os.getenv('REDIS_DB') @@ -29,26 +30,17 @@ class CacheManager: 'endpoint': redis_host, 'port': port, 'db': db, - 'timeout': 300 + 'timeout': 300, } } except Exception: - self.cache_backend = Cache.MEMORY self.cache_config = { - 'default': { - 'cache': 'aiocache.SimpleMemoryCache', - 'timeout': 300 - } + 'default': {'cache': 'aiocache.SimpleMemoryCache', 'timeout': 300} } else: self.cache_backend = Cache.MEMORY - self.cache_config = { - 'default': { - 'cache': 'aiocache.SimpleMemoryCache', - 'timeout': 300 - } - } + self.cache_config = {'default': {'cache': 'aiocache.SimpleMemoryCache', 'timeout': 300}} caches.set_config(self.cache_config) def init_app(self, app: FastAPI): @@ -58,4 +50,5 @@ class CacheManager: def cached(self, ttl=300, key=None): return cached(ttl=ttl, key=key, cache=self.cache_backend) + cache_manager = CacheManager() diff --git a/backend-services/utils/chaos_util.py b/backend-services/utils/chaos_util.py index 0467f2c..cbcefa9 100644 --- a/backend-services/utils/chaos_util.py +++ b/backend-services/utils/chaos_util.py @@ -1,16 +1,12 @@ -import threading -import time import logging +import threading -_state = { - 'redis_outage': False, - 'mongo_outage': False, - 'error_budget_burn': 0, -} +_state = {'redis_outage': False, 'mongo_outage': False, 'error_budget_burn': 0} _lock = threading.RLock() _logger = logging.getLogger('doorman.chaos') + def enable(backend: str, on: bool): with _lock: key = _key_for(backend) @@ -18,12 +14,14 @@ def enable(backend: str, on: bool): _state[key] = bool(on) _logger.warning(f'chaos: {backend} outage set to {on}') + def enable_for(backend: str, duration_ms: int): enable(backend, True) t = threading.Timer(duration_ms / 1000.0, lambda: enable(backend, False)) t.daemon = True t.start() + def _key_for(backend: str): b = (backend or '').strip().lower() if b == 'redis': @@ -32,6 +30,7 @@ def _key_for(backend: str): return 'mongo_outage' return None + def should_fail(backend: str) -> bool: key = _key_for(backend) if not key: @@ -39,12 +38,15 @@ def should_fail(backend: str) -> bool: with _lock: return bool(_state.get(key)) + def burn_error_budget(backend: str): with _lock: _state['error_budget_burn'] += 1 - _logger.warning(f'chaos: error_budget_burn+1 backend={backend} total={_state["error_budget_burn"]}') + _logger.warning( + f'chaos: error_budget_burn+1 backend={backend} total={_state["error_budget_burn"]}' + ) + def stats() -> dict: with _lock: return dict(_state) - diff --git a/backend-services/utils/constants.py b/backend-services/utils/constants.py index 0f167a9..72fdb8a 100644 --- a/backend-services/utils/constants.py +++ b/backend-services/utils/constants.py @@ -1,6 +1,7 @@ class Headers: REQUEST_ID = 'request_id' + class Defaults: PAGE = 1 PAGE_SIZE = 10 @@ -9,6 +10,7 @@ class Defaults: MAX_MULTIPART_SIZE_BYTES_ENV = 'MAX_MULTIPART_SIZE_BYTES' MAX_MULTIPART_SIZE_BYTES_DEFAULT = 5_242_880 + class Roles: MANAGE_USERS = 'manage_users' MANAGE_APIS = 'manage_apis' @@ -18,6 +20,7 @@ class Roles: EXPORT_LOGS = 'export_logs' MANAGE_ROLES = 'manage_roles' + class ErrorCodes: UNEXPECTED = 'GTW999' HTTP_EXCEPTION = 'GTW998' @@ -29,6 +32,7 @@ class ErrorCodes: REQUEST_FILE_TYPE = 'REQ003' PAGE_SIZE = 'PAG001' + class Messages: UNEXPECTED = 'An unexpected error occurred' FILE_TOO_LARGE = 'Uploaded file too large' diff --git a/backend-services/utils/correlation_util.py b/backend-services/utils/correlation_util.py index 901e7e9..3c56c27 100644 --- a/backend-services/utils/correlation_util.py +++ b/backend-services/utils/correlation_util.py @@ -4,27 +4,29 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from contextvars import ContextVar -from typing import Optional -import uuid import logging +import uuid +from contextvars import ContextVar -correlation_id: ContextVar[Optional[str]] = ContextVar('correlation_id', default=None) +correlation_id: ContextVar[str | None] = ContextVar('correlation_id', default=None) logger = logging.getLogger('doorman.gateway') -def get_correlation_id() -> Optional[str]: + +def get_correlation_id() -> str | None: """ Get the current correlation ID from context. """ return correlation_id.get() + def set_correlation_id(value: str) -> None: """ Set the correlation ID in the current context. """ correlation_id.set(value) + def ensure_correlation_id() -> str: """ Get existing correlation ID or generate a new one. @@ -35,16 +37,18 @@ def ensure_correlation_id() -> str: correlation_id.set(cid) return cid + def log_with_correlation(level: str, message: str, **kwargs) -> None: """ Log a message with the correlation ID automatically prepended. """ cid = get_correlation_id() or 'no-correlation-id' - log_message = f"{cid} | {message}" + log_message = f'{cid} | {message}' log_func = getattr(logger, level.lower(), logger.info) log_func(log_message, **kwargs) -async def run_with_correlation(coro, correlation_id_value: Optional[str] = None): + +async def run_with_correlation(coro, correlation_id_value: str | None = None): """ Run an async coroutine with a correlation ID. """ @@ -56,10 +60,12 @@ async def run_with_correlation(coro, correlation_id_value: Optional[str] = None) finally: pass + class CorrelationContext: """ Context manager for setting correlation ID in a scope. """ + def __init__(self, correlation_id_value: str): self.correlation_id_value = correlation_id_value self.token = None @@ -72,7 +78,10 @@ class CorrelationContext: if self.token: correlation_id.reset(self.token) -async def run_async_with_correlation(func, *args, correlation_id_value: Optional[str] = None, **kwargs): + +async def run_async_with_correlation( + func, *args, correlation_id_value: str | None = None, **kwargs +): """ Run an async function with correlation ID. """ diff --git a/backend-services/utils/credit_util.py b/backend-services/utils/credit_util.py index 1c259e6..d194b57 100644 --- a/backend-services/utils/credit_util.py +++ b/backend-services/utils/credit_util.py @@ -1,7 +1,9 @@ -from utils.database_async import user_credit_collection, credit_def_collection +from datetime import UTC, datetime + from utils.async_db import db_find_one, db_update_one +from utils.database_async import credit_def_collection, user_credit_collection from utils.encryption_util import decrypt_value -from datetime import datetime, timezone + async def deduct_credit(api_credit_group, username): if not api_credit_group: @@ -14,9 +16,14 @@ async def deduct_credit(api_credit_group, username): if not info or info.get('available_credits', 0) <= 0: return False available_credits = info.get('available_credits', 0) - 1 - await db_update_one(user_credit_collection, {'username': username}, {'$set': {f'users_credits.{api_credit_group}.available_credits': available_credits}}) + await db_update_one( + user_credit_collection, + {'username': username}, + {'$set': {f'users_credits.{api_credit_group}.available_credits': available_credits}}, + ) return True + async def get_user_api_key(api_credit_group, username): if not api_credit_group: return None @@ -29,6 +36,7 @@ async def get_user_api_key(api_credit_group, username): dec = decrypt_value(enc) return dec if dec is not None else enc + async def get_credit_api_header(api_credit_group): """ Get credit API header and key, supporting rotation. @@ -61,7 +69,9 @@ async def get_credit_api_header(api_credit_group): if api_key_new_encrypted and rotation_expires: if isinstance(rotation_expires, str): try: - rotation_expires_dt = datetime.fromisoformat(rotation_expires.replace('Z', '+00:00')) + rotation_expires_dt = datetime.fromisoformat( + rotation_expires.replace('Z', '+00:00') + ) except Exception: rotation_expires_dt = None elif isinstance(rotation_expires, datetime): @@ -69,7 +79,7 @@ async def get_credit_api_header(api_credit_group): else: rotation_expires_dt = None - now = datetime.now(timezone.utc) + now = datetime.now(UTC) if rotation_expires_dt and now < rotation_expires_dt: api_key_new = decrypt_value(api_key_new_encrypted) api_key_new = api_key_new if api_key_new is not None else api_key_new_encrypted diff --git a/backend-services/utils/database.py b/backend-services/utils/database.py index 10a1eb7..f495223 100644 --- a/backend-services/utils/database.py +++ b/backend-services/utils/database.py @@ -4,24 +4,22 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from pymongo import MongoClient, IndexModel, ASCENDING -from dotenv import load_dotenv -import os -import uuid import copy -import json -import threading -import secrets -import string as _string import logging +import os +import threading +import uuid -from utils import password_util -from utils import chaos_util +from dotenv import load_dotenv +from pymongo import ASCENDING, IndexModel, MongoClient + +from utils import chaos_util, password_util load_dotenv() logger = logging.getLogger('doorman.gateway') + def _build_admin_seed_doc(email: str, pwd_hash: str) -> dict: """Canonical admin bootstrap document used for both memory and Mongo modes. @@ -47,9 +45,9 @@ def _build_admin_seed_doc(email: str, pwd_hash: str) -> dict: 'active': True, } + class Database: def __init__(self): - mem_flag = os.getenv('MEM_OR_EXTERNAL') if mem_flag is None: mem_flag = os.getenv('MEM_OR_REDIS', 'MEM') @@ -75,17 +73,15 @@ class Database: self.db_existed = True if len(host_list) > 1 and replica_set_name: - connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman?replicaSet={replica_set_name}" + connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman?replicaSet={replica_set_name}' else: - connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman" + connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman' self.client = MongoClient( - connection_uri, - serverSelectionTimeoutMS=5000, - maxPoolSize=100, - minPoolSize=5 + connection_uri, serverSelectionTimeoutMS=5000, maxPoolSize=100, minPoolSize=5 ) self.db = self.client.get_database() + def initialize_collections(self): if self.memory_only: # Resolve admin seed credentials consistently across modes (no auto-generation) @@ -93,7 +89,9 @@ class Database: email = os.getenv('DOORMAN_ADMIN_EMAIL') or 'admin@doorman.dev' pwd = os.getenv('DOORMAN_ADMIN_PASSWORD') if not pwd: - raise RuntimeError('DOORMAN_ADMIN_PASSWORD is required for admin initialization') + raise RuntimeError( + 'DOORMAN_ADMIN_PASSWORD is required for admin initialization' + ) return email, password_util.hash_password(pwd) users = self.db.users @@ -101,37 +99,43 @@ class Database: groups = self.db.groups if not roles.find_one({'role_name': 'admin'}): - roles.insert_one({ - 'role_name': 'admin', - 'role_description': 'Administrator role', - 'manage_users': True, - 'manage_apis': True, - 'manage_endpoints': True, - 'manage_groups': True, - 'manage_roles': True, - 'manage_routings': True, - 'manage_gateway': True, - 'manage_subscriptions': True, - 'manage_credits': True, - 'manage_auth': True, - 'manage_security': True, - 'view_logs': True, - 'export_logs': True, - 'ui_access': True - }) + roles.insert_one( + { + 'role_name': 'admin', + 'role_description': 'Administrator role', + 'manage_users': True, + 'manage_apis': True, + 'manage_endpoints': True, + 'manage_groups': True, + 'manage_roles': True, + 'manage_routings': True, + 'manage_gateway': True, + 'manage_subscriptions': True, + 'manage_credits': True, + 'manage_auth': True, + 'manage_security': True, + 'view_logs': True, + 'export_logs': True, + 'ui_access': True, + } + ) if not groups.find_one({'group_name': 'admin'}): - groups.insert_one({ - 'group_name': 'admin', - 'group_description': 'Administrator group with full access', - 'api_access': [] - }) + groups.insert_one( + { + 'group_name': 'admin', + 'group_description': 'Administrator group with full access', + 'api_access': [], + } + ) if not groups.find_one({'group_name': 'ALL'}): - groups.insert_one({ - 'group_name': 'ALL', - 'group_description': 'Default group with access to all APIs', - 'api_access': [] - }) + groups.insert_one( + { + 'group_name': 'ALL', + 'group_description': 'Default group with access to all APIs', + 'api_access': [], + } + ) if not users.find_one({'username': 'admin'}): _email, _pwd_hash = _admin_seed_creds() @@ -150,16 +154,24 @@ class Database: try: env_pwd = os.getenv('DOORMAN_ADMIN_PASSWORD') if env_pwd: - users.update_one({'username': 'admin'}, {'$set': {'password': password_util.hash_password(env_pwd)}}) + users.update_one( + {'username': 'admin'}, + {'$set': {'password': password_util.hash_password(env_pwd)}}, + ) except Exception: pass try: from datetime import datetime + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) env_logs = os.getenv('LOGS_DIR') - logs_dir = os.path.abspath(env_logs) if env_logs else os.path.join(base_dir, 'platform-logs') + logs_dir = ( + os.path.abspath(env_logs) + if env_logs + else os.path.join(base_dir, 'platform-logs') + ) os.makedirs(logs_dir, exist_ok=True) log_path = os.path.join(logs_dir, 'doorman.log') now = datetime.now() @@ -169,7 +181,7 @@ class Database: ('orders', '/orders/v1/status'), ('weather', '/weather/v1/status'), ] - for api_name, ep in samples: + for _api_name, ep in samples: ts = now.strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] rid = str(uuid.uuid4()) msg = f'{rid} | Username: admin | From: 127.0.0.1:54321 | Endpoint: GET {ep} | Total time: 42ms' @@ -180,7 +192,23 @@ class Database: pass logger.info('Memory-only mode: Core data initialized (admin user/role/groups)') return - collections = ['users', 'apis', 'endpoints', 'groups', 'roles', 'subscriptions', 'routings', 'credit_defs', 'user_credits', 'endpoint_validations', 'settings', 'revocations', 'vault_entries', 'tiers', 'user_tier_assignments'] + collections = [ + 'users', + 'apis', + 'endpoints', + 'groups', + 'roles', + 'subscriptions', + 'routings', + 'credit_defs', + 'user_credits', + 'endpoint_validations', + 'settings', + 'revocations', + 'vault_entries', + 'tiers', + 'user_tier_assignments', + ] for collection in collections: if collection not in self.db.list_collection_names(): self.db_existed = False @@ -193,8 +221,11 @@ class Database: email = os.getenv('DOORMAN_ADMIN_EMAIL') or 'admin@doorman.dev' pwd = os.getenv('DOORMAN_ADMIN_PASSWORD') if not pwd: - raise RuntimeError('DOORMAN_ADMIN_PASSWORD is required for admin initialization') + raise RuntimeError( + 'DOORMAN_ADMIN_PASSWORD is required for admin initialization' + ) return email, password_util.hash_password(pwd) + _email, _pwd_hash = _admin_seed_creds_mongo() self.db.users.insert_one(_build_admin_seed_doc(_email, _pwd_hash)) try: @@ -210,90 +241,109 @@ class Database: if env_pwd: self.db.users.update_one( {'username': 'admin'}, - {'$set': {'password': password_util.hash_password(env_pwd)}} + {'$set': {'password': password_util.hash_password(env_pwd)}}, ) logger.warning('Admin user lacked password; set from DOORMAN_ADMIN_PASSWORD') else: - raise RuntimeError('Admin user missing password and DOORMAN_ADMIN_PASSWORD not set') + raise RuntimeError( + 'Admin user missing password and DOORMAN_ADMIN_PASSWORD not set' + ) except Exception: pass if not self.db.roles.find_one({'role_name': 'admin'}): - self.db.roles.insert_one({ - 'role_name': 'admin', - 'role_description': 'Administrator role', - 'manage_users': True, - 'manage_apis': True, - 'manage_endpoints': True, - 'manage_groups': True, - 'manage_roles': True, - 'manage_routings': True, - 'manage_gateway': True, - 'manage_subscriptions': True, - 'manage_credits': True, - 'manage_auth': True, - 'view_logs': True, - 'export_logs': True, - 'manage_security': True - }) + self.db.roles.insert_one( + { + 'role_name': 'admin', + 'role_description': 'Administrator role', + 'manage_users': True, + 'manage_apis': True, + 'manage_endpoints': True, + 'manage_groups': True, + 'manage_roles': True, + 'manage_routings': True, + 'manage_gateway': True, + 'manage_subscriptions': True, + 'manage_credits': True, + 'manage_auth': True, + 'view_logs': True, + 'export_logs': True, + 'manage_security': True, + } + ) if not self.db.groups.find_one({'group_name': 'admin'}): - self.db.groups.insert_one({ - 'group_name': 'admin', - 'group_description': 'Administrator group with full access', - 'api_access': [] - }) + self.db.groups.insert_one( + { + 'group_name': 'admin', + 'group_description': 'Administrator group with full access', + 'api_access': [], + } + ) if not self.db.groups.find_one({'group_name': 'ALL'}): - self.db.groups.insert_one({ - 'group_name': 'ALL', - 'group_description': 'Default group with access to all APIs', - 'api_access': [] - }) + self.db.groups.insert_one( + { + 'group_name': 'ALL', + 'group_description': 'Default group with access to all APIs', + 'api_access': [], + } + ) def create_indexes(self): if self.memory_only: logger.debug('Memory-only mode: Skipping MongoDB index creation') return - self.db.apis.create_indexes([ - IndexModel([('api_id', ASCENDING)], unique=True), - IndexModel([('name', ASCENDING), ('version', ASCENDING)]) - ]) - self.db.endpoints.create_indexes([ - IndexModel([('endpoint_method', ASCENDING), ('api_name', ASCENDING), ('api_version', ASCENDING), ('endpoint_uri', ASCENDING)], unique=True), - ]) - self.db.users.create_indexes([ - IndexModel([('username', ASCENDING)], unique=True), - IndexModel([('email', ASCENDING)], unique=True) - ]) - self.db.groups.create_indexes([ - IndexModel([('group_name', ASCENDING)], unique=True) - ]) - self.db.roles.create_indexes([ - IndexModel([('role_name', ASCENDING)], unique=True) - ]) - self.db.subscriptions.create_indexes([ - IndexModel([('username', ASCENDING)], unique=True) - ]) - self.db.routings.create_indexes([ - IndexModel([('client_key', ASCENDING)], unique=True) - ]) - self.db.credit_defs.create_indexes([ - IndexModel([('api_credit_group', ASCENDING)], unique=True), - IndexModel([('username', ASCENDING)], unique=True) - ]) - self.db.endpoint_validations.create_indexes([ - IndexModel([('endpoint_id', ASCENDING)], unique=True) - ]) - self.db.vault_entries.create_indexes([ - IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True), - IndexModel([('username', ASCENDING)]) - ]) - self.db.tiers.create_indexes([ - IndexModel([('tier_id', ASCENDING)], unique=True), - IndexModel([('name', ASCENDING)]) - ]) - self.db.user_tier_assignments.create_indexes([ - IndexModel([('user_id', ASCENDING)], unique=True), - IndexModel([('tier_id', ASCENDING)]) - ]) + self.db.apis.create_indexes( + [ + IndexModel([('api_id', ASCENDING)], unique=True), + IndexModel([('name', ASCENDING), ('version', ASCENDING)]), + ] + ) + self.db.endpoints.create_indexes( + [ + IndexModel( + [ + ('endpoint_method', ASCENDING), + ('api_name', ASCENDING), + ('api_version', ASCENDING), + ('endpoint_uri', ASCENDING), + ], + unique=True, + ) + ] + ) + self.db.users.create_indexes( + [ + IndexModel([('username', ASCENDING)], unique=True), + IndexModel([('email', ASCENDING)], unique=True), + ] + ) + self.db.groups.create_indexes([IndexModel([('group_name', ASCENDING)], unique=True)]) + self.db.roles.create_indexes([IndexModel([('role_name', ASCENDING)], unique=True)]) + self.db.subscriptions.create_indexes([IndexModel([('username', ASCENDING)], unique=True)]) + self.db.routings.create_indexes([IndexModel([('client_key', ASCENDING)], unique=True)]) + self.db.credit_defs.create_indexes( + [ + IndexModel([('api_credit_group', ASCENDING)], unique=True), + IndexModel([('username', ASCENDING)], unique=True), + ] + ) + self.db.endpoint_validations.create_indexes( + [IndexModel([('endpoint_id', ASCENDING)], unique=True)] + ) + self.db.vault_entries.create_indexes( + [ + IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True), + IndexModel([('username', ASCENDING)]), + ] + ) + self.db.tiers.create_indexes( + [IndexModel([('tier_id', ASCENDING)], unique=True), IndexModel([('name', ASCENDING)])] + ) + self.db.user_tier_assignments.create_indexes( + [ + IndexModel([('user_id', ASCENDING)], unique=True), + IndexModel([('tier_id', ASCENDING)]), + ] + ) def is_memory_only(self) -> bool: return self.memory_only @@ -303,27 +353,30 @@ class Database: 'mode': 'memory_only' if self.memory_only else 'mongodb', 'mongodb_connected': not self.memory_only and self.client is not None, 'collections_available': not self.memory_only, - 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS')) + 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS')), } + class InMemoryInsertResult: def __init__(self, inserted_id): self.acknowledged = True self.inserted_id = inserted_id + class InMemoryUpdateResult: def __init__(self, modified_count): self.acknowledged = True self.modified_count = modified_count + class InMemoryDeleteResult: def __init__(self, deleted_count): self.acknowledged = True self.deleted_count = deleted_count + class InMemoryCursor: def __init__(self, docs): - self._docs = [copy.deepcopy(d) for d in docs] self._index = 0 @@ -362,11 +415,11 @@ class InMemoryCursor: if length is None: return data try: - return data[: int(length)] except Exception: return data + class InMemoryCollection: def __init__(self, name): self.name = name @@ -453,7 +506,6 @@ class InMemoryCollection: if set_data: for k, v in set_data.items(): - if isinstance(k, str) and '.' in k: parts = k.split('.') cur = updated @@ -545,49 +597,48 @@ class InMemoryCollection: return None def create_indexes(self, *args, **kwargs): - return [] class AsyncInMemoryCollection: """Async wrapper around InMemoryCollection for async/await compatibility""" - + def __init__(self, sync_collection): self._sync = sync_collection self.name = sync_collection.name - + async def find_one(self, query=None): """Async find_one""" return self._sync.find_one(query) - + def find(self, query=None): """Returns cursor (sync method, but cursor supports async iteration)""" return self._sync.find(query) - + async def insert_one(self, doc): """Async insert_one""" return self._sync.insert_one(doc) - + async def update_one(self, query, update): """Async update_one""" return self._sync.update_one(query, update) - + async def delete_one(self, query): """Async delete_one""" return self._sync.delete_one(query) - + async def count_documents(self, query=None): """Async count_documents""" return self._sync.count_documents(query) - + async def replace_one(self, query, replacement): """Async replace_one""" return self._sync.replace_one(query, replacement) - + async def find_one_and_update(self, query, update, return_document=False): """Async find_one_and_update""" return self._sync.find_one_and_update(query, update, return_document) - + def create_indexes(self, *args, **kwargs): return self._sync.create_indexes(*args, **kwargs) @@ -595,8 +646,7 @@ class AsyncInMemoryCollection: class InMemoryDB: def __init__(self, async_mode=False): self._async_mode = async_mode - CollectionClass = AsyncInMemoryCollection if async_mode else InMemoryCollection - + # Create base sync collections self._sync_users = InMemoryCollection('users') self._sync_apis = InMemoryCollection('apis') @@ -614,7 +664,7 @@ class InMemoryDB: self._sync_tiers = InMemoryCollection('tiers') self._sync_user_tier_assignments = InMemoryCollection('user_tier_assignments') self._sync_rate_limit_rules = InMemoryCollection('rate_limit_rules') - + # Expose as async or sync based on mode if async_mode: self.users = AsyncInMemoryCollection(self._sync_users) @@ -653,10 +703,22 @@ class InMemoryDB: def list_collection_names(self): return [ - 'users', 'apis', 'endpoints', 'groups', 'roles', - 'subscriptions', 'routings', 'credit_defs', 'user_credits', - 'endpoint_validations', 'settings', 'revocations', 'vault_entries', - 'tiers', 'user_tier_assignments', 'rate_limit_rules' + 'users', + 'apis', + 'endpoints', + 'groups', + 'roles', + 'subscriptions', + 'routings', + 'credit_defs', + 'user_credits', + 'endpoint_validations', + 'settings', + 'revocations', + 'vault_entries', + 'tiers', + 'user_tier_assignments', + 'rate_limit_rules', ] def create_collection(self, name): @@ -709,11 +771,11 @@ class InMemoryDB: load_coll(self._sync_tiers, data.get('tiers', [])) load_coll(self._sync_user_tier_assignments, data.get('user_tier_assignments', [])) + database = Database() database.initialize_collections() database.create_indexes() if database.memory_only: - db = database.db mongodb_client = None api_collection = db.apis @@ -750,6 +812,7 @@ else: except Exception: vault_entries_collection = None + def close_database_connections(): """ Close all database connections for graceful shutdown. @@ -758,6 +821,6 @@ def close_database_connections(): try: if mongodb_client: mongodb_client.close() - logger.info("MongoDB connections closed") + logger.info('MongoDB connections closed') except Exception as e: - logger.warning(f"Error closing MongoDB connections: {e}") + logger.warning(f'Error closing MongoDB connections: {e}') diff --git a/backend-services/utils/database_async.py b/backend-services/utils/database_async.py index b023e19..fefaa78 100644 --- a/backend-services/utils/database_async.py +++ b/backend-services/utils/database_async.py @@ -10,19 +10,19 @@ try: from motor.motor_asyncio import AsyncIOMotorClient except Exception: AsyncIOMotorClient = None -from dotenv import load_dotenv -import os -import asyncio -from typing import Optional import logging +import os + +from dotenv import load_dotenv -from utils.database import InMemoryDB, InMemoryCollection from utils import password_util +from utils.database import InMemoryDB load_dotenv() logger = logging.getLogger('doorman.gateway') + class AsyncDatabase: """Async database wrapper that supports both Motor (MongoDB) and in-memory modes.""" @@ -37,6 +37,7 @@ class AsyncDatabase: # async-compatible collection interfaces (wrapping the same sync collections). try: from utils.database import database as _sync_db + self.client = None self.db_existed = getattr(_sync_db, 'db_existed', False) @@ -46,6 +47,7 @@ class AsyncDatabase: self._sync = sync_db # Wrap collections with async facade using the same underlying storage from utils.database import AsyncInMemoryCollection as _AIC + self.users = _AIC(sync_db._sync_users) self.apis = _AIC(sync_db._sync_apis) self.endpoints = _AIC(sync_db._sync_endpoints) @@ -95,17 +97,16 @@ class AsyncDatabase: self.db_existed = True if len(host_list) > 1 and replica_set_name: - connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman?replicaSet={replica_set_name}" + connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman?replicaSet={replica_set_name}' else: - connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman" + connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman' if AsyncIOMotorClient is None: - raise RuntimeError('motor is required for async MongoDB mode; install motor or set MEM_OR_EXTERNAL=MEM') + raise RuntimeError( + 'motor is required for async MongoDB mode; install motor or set MEM_OR_EXTERNAL=MEM' + ) self.client = AsyncIOMotorClient( - connection_uri, - serverSelectionTimeoutMS=5000, - maxPoolSize=100, - minPoolSize=5 + connection_uri, serverSelectionTimeoutMS=5000, maxPoolSize=100, minPoolSize=5 ) self.db = self.client.get_database() @@ -113,13 +114,24 @@ class AsyncDatabase: """Initialize collections and default data.""" if self.memory_only: from utils.database import database + database.initialize_collections() return collections = [ - 'users', 'apis', 'endpoints', 'groups', 'roles', 'subscriptions', - 'routings', 'credit_defs', 'user_credits', 'endpoint_validations', - 'settings', 'revocations', 'vault_entries' + 'users', + 'apis', + 'endpoints', + 'groups', + 'roles', + 'subscriptions', + 'routings', + 'credit_defs', + 'user_credits', + 'endpoint_validations', + 'settings', + 'revocations', + 'vault_entries', ] existing_collections = await self.db.list_collection_names() @@ -136,82 +148,91 @@ class AsyncDatabase: email = os.getenv('DOORMAN_ADMIN_EMAIL') or 'admin@doorman.dev' pwd = os.getenv('DOORMAN_ADMIN_PASSWORD') if not pwd: - raise RuntimeError('DOORMAN_ADMIN_PASSWORD is required for admin initialization') + raise RuntimeError( + 'DOORMAN_ADMIN_PASSWORD is required for admin initialization' + ) pwd_hash = password_util.hash_password(pwd) - await self.db.users.insert_one({ - 'username': 'admin', - 'email': email, - 'password': pwd_hash, - 'role': 'admin', - 'groups': ['ALL', 'admin'], - 'rate_limit_duration': 1, - 'rate_limit_duration_type': 'second', - 'throttle_duration': 1, - 'throttle_duration_type': 'second', - 'throttle_wait_duration': 0, - 'throttle_wait_duration_type': 'second', - 'custom_attributes': {'custom_key': 'custom_value'}, - 'active': True, - 'throttle_queue_limit': 1, - 'throttle_enabled': None, - 'ui_access': True - }) + await self.db.users.insert_one( + { + 'username': 'admin', + 'email': email, + 'password': pwd_hash, + 'role': 'admin', + 'groups': ['ALL', 'admin'], + 'rate_limit_duration': 1, + 'rate_limit_duration_type': 'second', + 'throttle_duration': 1, + 'throttle_duration_type': 'second', + 'throttle_wait_duration': 0, + 'throttle_wait_duration_type': 'second', + 'custom_attributes': {'custom_key': 'custom_value'}, + 'active': True, + 'throttle_queue_limit': 1, + 'throttle_enabled': None, + 'ui_access': True, + } + ) try: adm = await self.db.users.find_one({'username': 'admin'}) if adm and adm.get('ui_access') is not True: - await self.db.users.update_one( - {'username': 'admin'}, - {'$set': {'ui_access': True}} - ) + await self.db.users.update_one({'username': 'admin'}, {'$set': {'ui_access': True}}) if adm and not adm.get('password'): env_pwd = os.getenv('DOORMAN_ADMIN_PASSWORD') if env_pwd: await self.db.users.update_one( {'username': 'admin'}, - {'$set': {'password': password_util.hash_password(env_pwd)}} + {'$set': {'password': password_util.hash_password(env_pwd)}}, ) logger.warning('Admin user lacked password; set from DOORMAN_ADMIN_PASSWORD') else: - raise RuntimeError('Admin user missing password and DOORMAN_ADMIN_PASSWORD not set') + raise RuntimeError( + 'Admin user missing password and DOORMAN_ADMIN_PASSWORD not set' + ) except Exception: pass admin_role = await self.db.roles.find_one({'role_name': 'admin'}) if not admin_role: - await self.db.roles.insert_one({ - 'role_name': 'admin', - 'role_description': 'Administrator role', - 'manage_users': True, - 'manage_apis': True, - 'manage_endpoints': True, - 'manage_groups': True, - 'manage_roles': True, - 'manage_routings': True, - 'manage_gateway': True, - 'manage_subscriptions': True, - 'manage_credits': True, - 'manage_auth': True, - 'view_logs': True, - 'export_logs': True, - 'manage_security': True - }) + await self.db.roles.insert_one( + { + 'role_name': 'admin', + 'role_description': 'Administrator role', + 'manage_users': True, + 'manage_apis': True, + 'manage_endpoints': True, + 'manage_groups': True, + 'manage_roles': True, + 'manage_routings': True, + 'manage_gateway': True, + 'manage_subscriptions': True, + 'manage_credits': True, + 'manage_auth': True, + 'view_logs': True, + 'export_logs': True, + 'manage_security': True, + } + ) admin_group = await self.db.groups.find_one({'group_name': 'admin'}) if not admin_group: - await self.db.groups.insert_one({ - 'group_name': 'admin', - 'group_description': 'Administrator group with full access', - 'api_access': [] - }) + await self.db.groups.insert_one( + { + 'group_name': 'admin', + 'group_description': 'Administrator group with full access', + 'api_access': [], + } + ) all_group = await self.db.groups.find_one({'group_name': 'ALL'}) if not all_group: - await self.db.groups.insert_one({ - 'group_name': 'ALL', - 'group_description': 'Default group with access to all APIs', - 'api_access': [] - }) + await self.db.groups.insert_one( + { + 'group_name': 'ALL', + 'group_description': 'Default group with access to all APIs', + 'api_access': [], + } + ) async def create_indexes(self): """Create database indexes for performance.""" @@ -219,56 +240,65 @@ class AsyncDatabase: logger.debug('Async Memory-only mode: Skipping MongoDB index creation') return - from pymongo import IndexModel, ASCENDING + from pymongo import ASCENDING, IndexModel - await self.db.apis.create_indexes([ - IndexModel([('api_id', ASCENDING)], unique=True), - IndexModel([('name', ASCENDING), ('version', ASCENDING)]) - ]) + await self.db.apis.create_indexes( + [ + IndexModel([('api_id', ASCENDING)], unique=True), + IndexModel([('name', ASCENDING), ('version', ASCENDING)]), + ] + ) - await self.db.endpoints.create_indexes([ - IndexModel([ - ('endpoint_method', ASCENDING), - ('api_name', ASCENDING), - ('api_version', ASCENDING), - ('endpoint_uri', ASCENDING) - ], unique=True), - ]) + await self.db.endpoints.create_indexes( + [ + IndexModel( + [ + ('endpoint_method', ASCENDING), + ('api_name', ASCENDING), + ('api_version', ASCENDING), + ('endpoint_uri', ASCENDING), + ], + unique=True, + ) + ] + ) - await self.db.users.create_indexes([ - IndexModel([('username', ASCENDING)], unique=True), - IndexModel([('email', ASCENDING)], unique=True) - ]) + await self.db.users.create_indexes( + [ + IndexModel([('username', ASCENDING)], unique=True), + IndexModel([('email', ASCENDING)], unique=True), + ] + ) - await self.db.groups.create_indexes([ - IndexModel([('group_name', ASCENDING)], unique=True) - ]) + await self.db.groups.create_indexes([IndexModel([('group_name', ASCENDING)], unique=True)]) - await self.db.roles.create_indexes([ - IndexModel([('role_name', ASCENDING)], unique=True) - ]) + await self.db.roles.create_indexes([IndexModel([('role_name', ASCENDING)], unique=True)]) - await self.db.subscriptions.create_indexes([ - IndexModel([('username', ASCENDING)], unique=True) - ]) + await self.db.subscriptions.create_indexes( + [IndexModel([('username', ASCENDING)], unique=True)] + ) - await self.db.routings.create_indexes([ - IndexModel([('client_key', ASCENDING)], unique=True) - ]) + await self.db.routings.create_indexes( + [IndexModel([('client_key', ASCENDING)], unique=True)] + ) - await self.db.credit_defs.create_indexes([ - IndexModel([('api_credit_group', ASCENDING)], unique=True), - IndexModel([('username', ASCENDING)], unique=True) - ]) + await self.db.credit_defs.create_indexes( + [ + IndexModel([('api_credit_group', ASCENDING)], unique=True), + IndexModel([('username', ASCENDING)], unique=True), + ] + ) - await self.db.endpoint_validations.create_indexes([ - IndexModel([('endpoint_id', ASCENDING)], unique=True) - ]) + await self.db.endpoint_validations.create_indexes( + [IndexModel([('endpoint_id', ASCENDING)], unique=True)] + ) - await self.db.vault_entries.create_indexes([ - IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True), - IndexModel([('username', ASCENDING)]) - ]) + await self.db.vault_entries.create_indexes( + [ + IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True), + IndexModel([('username', ASCENDING)]), + ] + ) def is_memory_only(self) -> bool: """Check if running in memory-only mode.""" @@ -280,14 +310,15 @@ class AsyncDatabase: 'mode': 'memory_only' if self.memory_only else 'mongodb', 'mongodb_connected': not self.memory_only and self.client is not None, 'collections_available': not self.memory_only, - 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS')) + 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS')), } async def close(self): """Close database connections gracefully.""" if self.client: self.client.close() - logger.info("Async MongoDB connections closed") + logger.info('Async MongoDB connections closed') + async_database = AsyncDatabase() @@ -332,6 +363,7 @@ else: except Exception: vault_entries_collection = None + async def close_async_database_connections(): """Close all async database connections for graceful shutdown.""" await async_database.close() diff --git a/backend-services/utils/demo_seed_util.py b/backend-services/utils/demo_seed_util.py index c2ce2c0..dfc7b7e 100644 --- a/backend-services/utils/demo_seed_util.py +++ b/backend-services/utils/demo_seed_util.py @@ -1,39 +1,73 @@ from __future__ import annotations + import os -import uuid import random import string +import uuid from datetime import datetime, timedelta from utils import password_util from utils.database import ( - api_collection, + credit_def_collection, endpoint_collection, group_collection, role_collection, subscriptions_collection, - credit_def_collection, - user_credit_collection, user_collection, + user_credit_collection, ) -from utils.metrics_util import metrics_store, MinuteBucket from utils.encryption_util import encrypt_value +from utils.metrics_util import MinuteBucket, metrics_store + def _rand_choice(seq): return random.choice(seq) + def _rand_word(min_len=4, max_len=10) -> str: length = random.randint(min_len, max_len) return ''.join(random.choices(string.ascii_lowercase, k=length)) + def _rand_name() -> str: - firsts = ['alex','casey','morgan','sam','taylor','riley','jamie','jordan','drew','quinn','kyle','parker','blake','devon'] - lasts = ['lee','kim','patel','garcia','nguyen','williams','brown','davis','miller','wilson','moore','taylor','thomas'] + firsts = [ + 'alex', + 'casey', + 'morgan', + 'sam', + 'taylor', + 'riley', + 'jamie', + 'jordan', + 'drew', + 'quinn', + 'kyle', + 'parker', + 'blake', + 'devon', + ] + lasts = [ + 'lee', + 'kim', + 'patel', + 'garcia', + 'nguyen', + 'williams', + 'brown', + 'davis', + 'miller', + 'wilson', + 'moore', + 'taylor', + 'thomas', + ] return f'{_rand_choice(firsts)}.{_rand_choice(lasts)}' + def _rand_domain() -> str: - return _rand_choice(['example.com','acme.io','contoso.net','demo.dev']) + return _rand_choice(['example.com', 'acme.io', 'contoso.net', 'demo.dev']) + def _rand_password() -> str: upp = _rand_choice(string.ascii_uppercase) @@ -44,11 +78,23 @@ def _rand_password() -> str: raw = upp + low + dig + spc + tail return ''.join(random.sample(raw, len(raw))) + def ensure_roles() -> list[str]: - roles = [('developer', dict(manage_apis=True, manage_endpoints=True, manage_subscriptions=True, manage_credits=True, view_logs=True)), - ('analyst', dict(view_logs=True, export_logs=True)), - ('viewer', dict(view_logs=True)), - ('ops', dict(manage_gateway=True, view_logs=True, export_logs=True, manage_security=True))] + roles = [ + ( + 'developer', + dict( + manage_apis=True, + manage_endpoints=True, + manage_subscriptions=True, + manage_credits=True, + view_logs=True, + ), + ), + ('analyst', dict(view_logs=True, export_logs=True)), + ('viewer', dict(view_logs=True)), + ('ops', dict(manage_gateway=True, view_logs=True, export_logs=True, manage_security=True)), + ] created = [] for role_name, extra in roles: if not role_collection.find_one({'role_name': role_name}): @@ -58,58 +104,103 @@ def ensure_roles() -> list[str]: created.append(role_name) return ['admin', *created] + def seed_groups(n: int, api_keys: list[str]) -> list[str]: names = [] for i in range(n): - gname = f'team-{_rand_word(3,6)}-{i}' + gname = f'team-{_rand_word(3, 6)}-{i}' if group_collection.find_one({'group_name': gname}): names.append(gname) continue - access = sorted(set(random.sample(api_keys, k=min(len(api_keys), random.randint(1, max(1, len(api_keys)//3))))) ) if api_keys else [] - group_collection.insert_one({'group_name': gname, 'group_description': f'Auto group {gname}', 'api_access': access}) + access = ( + sorted( + set( + random.sample( + api_keys, + k=min(len(api_keys), random.randint(1, max(1, len(api_keys) // 3))), + ) + ) + ) + if api_keys + else [] + ) + group_collection.insert_one( + {'group_name': gname, 'group_description': f'Auto group {gname}', 'api_access': access} + ) names.append(gname) for base in ('ALL', 'admin'): if not group_collection.find_one({'group_name': base}): - group_collection.insert_one({'group_name': base, 'group_description': f'{base} group', 'api_access': []}) + group_collection.insert_one( + {'group_name': base, 'group_description': f'{base} group', 'api_access': []} + ) if base not in names: names.append(base) return names + def seed_users(n: int, roles: list[str], groups: list[str]) -> list[str]: usernames = [] for i in range(n): uname = f'{_rand_name()}_{i}' - email = f"{uname.replace('.', '_')}@{_rand_domain()}" + email = f'{uname.replace(".", "_")}@{_rand_domain()}' if user_collection.find_one({'username': uname}): usernames.append(uname) continue hashed = password_util.hash_password(_rand_password()) - ugrps = sorted(set(random.sample(groups, k=min(len(groups), random.randint(1, min(3, max(1, len(groups)))))))) + ugrps = sorted( + set( + random.sample( + groups, k=min(len(groups), random.randint(1, min(3, max(1, len(groups))))) + ) + ) + ) role = _rand_choice(roles) - user_collection.insert_one({ - 'username': uname, - 'email': email, - 'password': hashed, - 'role': role, - 'groups': ugrps, - 'rate_limit_duration': random.randint(100, 10000), - 'rate_limit_duration_type': _rand_choice(['minute','hour','day']), - 'throttle_duration': random.randint(1000, 100000), - 'throttle_duration_type': _rand_choice(['second','minute']), - 'throttle_wait_duration': random.randint(100, 10000), - 'throttle_wait_duration_type': _rand_choice(['seconds','minutes']), - 'custom_attributes': {'dept': _rand_choice(['sales','eng','support','ops'])}, - 'active': True, - 'ui_access': _rand_choice([True, False]) - }) + user_collection.insert_one( + { + 'username': uname, + 'email': email, + 'password': hashed, + 'role': role, + 'groups': ugrps, + 'rate_limit_duration': random.randint(100, 10000), + 'rate_limit_duration_type': _rand_choice(['minute', 'hour', 'day']), + 'throttle_duration': random.randint(1000, 100000), + 'throttle_duration_type': _rand_choice(['second', 'minute']), + 'throttle_wait_duration': random.randint(100, 10000), + 'throttle_wait_duration_type': _rand_choice(['seconds', 'minutes']), + 'custom_attributes': {'dept': _rand_choice(['sales', 'eng', 'support', 'ops'])}, + 'active': True, + 'ui_access': _rand_choice([True, False]), + } + ) usernames.append(uname) return usernames -def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str,str]]: + +def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str, str]]: pairs = [] - for i in range(n): - name = _rand_choice(['customers','orders','billing','weather','news','crypto','search','inventory','shipping','payments','alerts','metrics','recommendations']) + f'-{_rand_word(3,6)}' - ver = _rand_choice(['v1','v2','v3']) + for _i in range(n): + name = ( + _rand_choice( + [ + 'customers', + 'orders', + 'billing', + 'weather', + 'news', + 'crypto', + 'search', + 'inventory', + 'shipping', + 'payments', + 'alerts', + 'metrics', + 'recommendations', + ] + ) + + f'-{_rand_word(3, 6)}' + ) + ver = _rand_choice(['v1', 'v2', 'v3']) if api_collection.find_one({'api_name': name, 'api_version': ver}): pairs.append((name, ver)) continue @@ -118,9 +209,17 @@ def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str,str 'api_name': name, 'api_version': ver, 'api_description': f'Auto API {name}/{ver}', - 'api_allowed_roles': sorted(set(random.sample(roles, k=min(len(roles), random.randint(1, min(3, len(roles))))))), - 'api_allowed_groups': sorted(set(random.sample(groups, k=min(len(groups), random.randint(1, min(5, len(groups))))))), - 'api_servers': [f'http://localhost:{8000+random.randint(0,999)}'], + 'api_allowed_roles': sorted( + set(random.sample(roles, k=min(len(roles), random.randint(1, min(3, len(roles)))))) + ), + 'api_allowed_groups': sorted( + set( + random.sample( + groups, k=min(len(groups), random.randint(1, min(5, len(groups)))) + ) + ) + ), + 'api_servers': [f'http://localhost:{8000 + random.randint(0, 999)}'], 'api_type': 'REST', 'api_allowed_retry_count': random.randint(0, 3), 'api_id': api_id, @@ -130,10 +229,22 @@ def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str,str pairs.append((name, ver)) return pairs -def seed_endpoints(apis: list[tuple[str,str]], per_api: int) -> None: - methods = ['GET','POST','PUT','DELETE','PATCH'] - bases = ['/status','/health','/items','/items/{id}','/search','/reports','/export','/metrics','/list','/detail/{id}'] - for (name, ver) in apis: + +def seed_endpoints(apis: list[tuple[str, str]], per_api: int) -> None: + methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'] + bases = [ + '/status', + '/health', + '/items', + '/items/{id}', + '/search', + '/reports', + '/export', + '/metrics', + '/list', + '/detail/{id}', + ] + for name, ver in apis: created = set() for _ in range(per_api): m = _rand_choice(methods) @@ -142,24 +253,49 @@ def seed_endpoints(apis: list[tuple[str,str]], per_api: int) -> None: if key in created: continue created.add(key) - if endpoint_collection.find_one({'api_name': name, 'api_version': ver, 'endpoint_method': m, 'endpoint_uri': u}): + if endpoint_collection.find_one( + {'api_name': name, 'api_version': ver, 'endpoint_method': m, 'endpoint_uri': u} + ): continue - endpoint_collection.insert_one({ - 'api_name': name, - 'api_version': ver, - 'endpoint_method': m, - 'endpoint_uri': u, - 'endpoint_description': f'{m} {u} for {name}', - 'api_id': api_collection.find_one({'api_name': name, 'api_version': ver}).get('api_id'), - 'endpoint_id': str(uuid.uuid4()), - }) + endpoint_collection.insert_one( + { + 'api_name': name, + 'api_version': ver, + 'endpoint_method': m, + 'endpoint_uri': u, + 'endpoint_description': f'{m} {u} for {name}', + 'api_id': api_collection.find_one({'api_name': name, 'api_version': ver}).get( + 'api_id' + ), + 'endpoint_id': str(uuid.uuid4()), + } + ) + def seed_credits() -> list[str]: - groups = ['ai-basic','ai-pro','maps-basic','maps-pro','news-tier','weather-tier'] + groups = ['ai-basic', 'ai-pro', 'maps-basic', 'maps-pro', 'news-tier', 'weather-tier'] tiers_catalog = [ - {'tier_name': 'basic', 'credits': 100, 'input_limit': 100, 'output_limit': 100, 'reset_frequency': 'monthly'}, - {'tier_name': 'pro', 'credits': 1000, 'input_limit': 500, 'output_limit': 500, 'reset_frequency': 'monthly'}, - {'tier_name': 'enterprise', 'credits': 10000, 'input_limit': 2000, 'output_limit': 2000, 'reset_frequency': 'monthly'}, + { + 'tier_name': 'basic', + 'credits': 100, + 'input_limit': 100, + 'output_limit': 100, + 'reset_frequency': 'monthly', + }, + { + 'tier_name': 'pro', + 'credits': 1000, + 'input_limit': 500, + 'output_limit': 500, + 'reset_frequency': 'monthly', + }, + { + 'tier_name': 'enterprise', + 'credits': 10000, + 'input_limit': 2000, + 'output_limit': 2000, + 'reset_frequency': 'monthly', + }, ] created = [] for g in groups: @@ -167,64 +303,84 @@ def seed_credits() -> list[str]: created.append(g) continue tiers = random.sample(tiers_catalog, k=random.randint(1, 3)) - credit_def_collection.insert_one({ - 'api_credit_group': g, - 'api_key': encrypt_value(uuid.uuid4().hex), - 'api_key_header': _rand_choice(['x-api-key','authorization','x-token']), - 'credit_tiers': tiers, - }) + credit_def_collection.insert_one( + { + 'api_credit_group': g, + 'api_key': encrypt_value(uuid.uuid4().hex), + 'api_key_header': _rand_choice(['x-api-key', 'authorization', 'x-token']), + 'credit_tiers': tiers, + } + ) created.append(g) return created + def seed_user_credits(usernames: list[str], credit_groups: list[str]) -> None: - pick_users = random.sample(usernames, k=min(len(usernames), max(1, len(usernames)//2))) if usernames else [] + pick_users = ( + random.sample(usernames, k=min(len(usernames), max(1, len(usernames) // 2))) + if usernames + else [] + ) for u in pick_users: users_credits = {} for g in random.sample(credit_groups, k=random.randint(1, min(3, len(credit_groups)))): users_credits[g] = { - 'tier_name': _rand_choice(['basic','pro','enterprise']), + 'tier_name': _rand_choice(['basic', 'pro', 'enterprise']), 'available_credits': random.randint(10, 10000), - 'reset_date': (datetime.utcnow() + timedelta(days=random.randint(1, 30))).strftime('%Y-%m-%d'), + 'reset_date': (datetime.utcnow() + timedelta(days=random.randint(1, 30))).strftime( + '%Y-%m-%d' + ), 'user_api_key': encrypt_value(uuid.uuid4().hex), } existing = user_credit_collection.find_one({'username': u}) if existing: - user_credit_collection.update_one({'username': u}, {'$set': {'users_credits': users_credits}}) + user_credit_collection.update_one( + {'username': u}, {'$set': {'users_credits': users_credits}} + ) else: user_credit_collection.insert_one({'username': u, 'users_credits': users_credits}) -def seed_subscriptions(usernames: list[str], apis: list[tuple[str,str]]) -> None: + +def seed_subscriptions(usernames: list[str], apis: list[tuple[str, str]]) -> None: api_keys = [f'{a}/{v}' for a, v in apis] for u in usernames: - subs = sorted(set(random.sample(api_keys, k=random.randint(1, min(5, len(api_keys))))) ) if api_keys else [] + subs = ( + sorted(set(random.sample(api_keys, k=random.randint(1, min(5, len(api_keys)))))) + if api_keys + else [] + ) existing = subscriptions_collection.find_one({'username': u}) if existing: subscriptions_collection.update_one({'username': u}, {'$set': {'apis': subs}}) else: subscriptions_collection.insert_one({'username': u, 'apis': subs}) -def seed_logs(n: int, usernames: list[str], apis: list[tuple[str,str]]) -> None: + +def seed_logs(n: int, usernames: list[str], apis: list[tuple[str, str]]) -> None: base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = os.path.abspath(os.path.join(base_dir, '..')) logs_dir = os.path.join(base_dir, 'logs') os.makedirs(logs_dir, exist_ok=True) log_path = os.path.join(logs_dir, 'doorman.log') - methods = ['GET','POST','PUT','DELETE','PATCH'] - uris = ['/status','/list','/items','/items/123','/search?q=test','/export','/metrics'] + methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'] + uris = ['/status', '/list', '/items', '/items/123', '/search?q=test', '/export', '/metrics'] now = datetime.now() with open(log_path, 'a', encoding='utf-8') as lf: for _ in range(n): - api = _rand_choice(apis) if apis else ('demo','v1') + api = _rand_choice(apis) if apis else ('demo', 'v1') method = _rand_choice(methods) uri = _rand_choice(uris) - ts = (now - timedelta(seconds=random.randint(0, 3600))).strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] + ts = (now - timedelta(seconds=random.randint(0, 3600))).strftime( + '%Y-%m-%d %H:%M:%S,%f' + )[:-3] rid = str(uuid.uuid4()) user = _rand_choice(usernames) if usernames else 'admin' port = random.randint(10000, 65000) - msg = f'{rid} | Username: {user} | From: 127.0.0.1:{port} | Endpoint: {method} /{api[0]}/{api[1]}{uri} | Total time: {random.randint(5,500)}ms' + msg = f'{rid} | Username: {user} | From: 127.0.0.1:{port} | Endpoint: {method} /{api[0]}/{api[1]}{uri} | Total time: {random.randint(5, 500)}ms' lf.write(f'{ts} - doorman.gateway - INFO - {msg}\n') -def seed_protos(n: int, apis: list[tuple[str,str]]) -> None: + +def seed_protos(n: int, apis: list[tuple[str, str]]) -> None: base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = os.path.abspath(os.path.join(base_dir, '..')) proto_dir = os.path.join(base_dir, 'proto') @@ -235,7 +391,7 @@ def seed_protos(n: int, apis: list[tuple[str,str]]) -> None: for name, ver in picked: key = f'{name}_{ver}'.replace('-', '_') svc = ''.join([p.capitalize() for p in name.split('-')]) - content = f'''syntax = "proto3"; + content = f"""syntax = "proto3"; package {key}; @@ -251,12 +407,13 @@ message StatusReply {{ string status = 1; string message = 2; }} -''' +""" path = os.path.join(proto_dir, f'{key}.proto') with open(path, 'w', encoding='utf-8') as f: f.write(content) -def seed_metrics(usernames: list[str], apis: list[tuple[str,str]], minutes: int = 400) -> None: + +def seed_metrics(usernames: list[str], apis: list[tuple[str, str]], minutes: int = 400) -> None: now = datetime.utcnow() for i in range(minutes, 0, -1): minute_start = int(((now - timedelta(minutes=i)).timestamp()) // 60) * 60 @@ -264,7 +421,7 @@ def seed_metrics(usernames: list[str], apis: list[tuple[str,str]], minutes: int count = random.randint(0, 50) for _ in range(count): dur = random.uniform(10, 400) - status = _rand_choice([200,200,200,201,204,400,401,403,404,500]) + status = _rand_choice([200, 200, 200, 201, 204, 400, 401, 403, 404, 500]) b.add(dur, status) metrics_store.total_requests += 1 metrics_store.total_ms += dur @@ -276,11 +433,12 @@ def seed_metrics(usernames: list[str], apis: list[tuple[str,str]], minutes: int metrics_store.api_counts[f'rest:{_rand_choice(apis)[0]}'] += 1 metrics_store._buckets.append(b) + def run_seed(users=30, apis=12, endpoints=5, groups=6, protos=5, logs=1000, seed=None): if seed is not None: random.seed(seed) roles = ensure_roles() - api_pairs = seed_apis(apis, roles, ['ALL','admin']) + api_pairs = seed_apis(apis, roles, ['ALL', 'admin']) group_names = seed_groups(groups, [f'{a}/{v}' for a, v in api_pairs]) usernames = seed_users(users, roles, group_names) seed_endpoints(api_pairs, endpoints) diff --git a/backend-services/utils/doorman_cache_async.py b/backend-services/utils/doorman_cache_async.py index e128edd..b23c320 100644 --- a/backend-services/utils/doorman_cache_async.py +++ b/backend-services/utils/doorman_cache_async.py @@ -6,16 +6,18 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -import redis.asyncio as aioredis import json -import os -from typing import Dict, Any, Optional import logging +import os +from typing import Any + +import redis.asyncio as aioredis from utils.doorman_cache_util import MemoryCache logger = logging.getLogger('doorman.gateway') + class AsyncDoormanCacheManager: """Async cache manager supporting both Redis (async) and in-memory modes.""" @@ -52,7 +54,7 @@ class AsyncDoormanCacheManager: 'endpoint_server_cache': 'endpoint_server_cache:', 'client_routing_cache': 'client_routing_cache:', 'token_def_cache': 'token_def_cache:', - 'credit_def_cache': 'credit_def_cache:' + 'credit_def_cache': 'credit_def_cache:', } self.default_ttls = { @@ -70,7 +72,7 @@ class AsyncDoormanCacheManager: 'endpoint_server_cache': 86400, 'client_routing_cache': 86400, 'token_def_cache': 86400, - 'credit_def_cache': 86400 + 'credit_def_cache': 86400, } def _to_json_serializable(self, value): @@ -99,6 +101,7 @@ class AsyncDoormanCacheManager: if self._init_lock: import asyncio + while self._init_lock: await asyncio.sleep(0.01) return @@ -114,7 +117,7 @@ class AsyncDoormanCacheManager: port=redis_port, db=redis_db, decode_responses=True, - max_connections=100 + max_connections=100, ) self.cache = aioredis.Redis(connection_pool=self._redis_pool) @@ -148,7 +151,7 @@ class AsyncDoormanCacheManager: else: self.cache.setex(cache_key, ttl, payload) - async def get_cache(self, cache_name: str, key: str) -> Optional[Any]: + async def get_cache(self, cache_name: str, key: str) -> Any | None: """Get cache value (async).""" if self.is_redis: await self._ensure_redis_connection() @@ -200,13 +203,13 @@ class AsyncDoormanCacheManager: for cache_name in self.prefixes.keys(): await self.clear_cache(cache_name) - async def get_cache_info(self) -> Dict[str, Any]: + async def get_cache_info(self) -> dict[str, Any]: """Get cache information (async).""" info = { 'type': self.cache_type, 'is_redis': self.is_redis, 'prefixes': list(self.prefixes.keys()), - 'default_ttl': self.default_ttls + 'default_ttl': self.default_ttls, } if not self.is_redis and hasattr(self.cache, 'get_cache_stats'): @@ -257,6 +260,7 @@ class AsyncDoormanCacheManager: """ try: import inspect + if inspect.iscoroutine(operation): result = await operation else: @@ -268,7 +272,7 @@ class AsyncDoormanCacheManager: await self.delete_cache(cache_name, key) return result - except Exception as e: + except Exception: await self.delete_cache(cache_name, key) raise @@ -278,10 +282,12 @@ class AsyncDoormanCacheManager: await self.cache.close() if self._redis_pool: await self._redis_pool.disconnect() - logger.info("Async Redis connections closed") + logger.info('Async Redis connections closed') + async_doorman_cache = AsyncDoormanCacheManager() + async def close_async_cache_connections(): """Close all async cache connections for graceful shutdown.""" await async_doorman_cache.close() diff --git a/backend-services/utils/doorman_cache_util.py b/backend-services/utils/doorman_cache_util.py index 7499e0b..689aecc 100644 --- a/backend-services/utils/doorman_cache_util.py +++ b/backend-services/utils/doorman_cache_util.py @@ -4,18 +4,21 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -import redis +import asyncio import json +import logging import os import threading -from typing import Dict, Any, Optional -import asyncio -import logging +from typing import Any + +import redis + from utils import chaos_util + class MemoryCache: def __init__(self, maxsize: int = 10000): - self._cache: Dict[str, Dict[str, Any]] = {} + self._cache: dict[str, dict[str, Any]] = {} self._lock = threading.RLock() self._maxsize = maxsize self._access_order = [] @@ -29,16 +32,13 @@ class MemoryCache: lru_key = self._access_order.pop(0) self._cache.pop(lru_key, None) - self._cache[key] = { - 'value': value, - 'expires_at': self._get_current_time() + ttl - } + self._cache[key] = {'value': value, 'expires_at': self._get_current_time() + ttl} if key in self._access_order: self._access_order.remove(key) self._access_order.append(key) - def get(self, key: str) -> Optional[str]: + def get(self, key: str) -> str | None: with self._lock: if key in self._cache: cache_entry = self._cache[key] @@ -70,42 +70,45 @@ class MemoryCache: def _get_current_time(self) -> int: import time + return int(time.time()) - def get_cache_stats(self) -> Dict[str, Any]: + def get_cache_stats(self) -> dict[str, Any]: with self._lock: current_time = self._get_current_time() total_entries = len(self._cache) - expired_entries = sum(1 for entry in self._cache.values() - if current_time >= entry['expires_at']) + expired_entries = sum( + 1 for entry in self._cache.values() if current_time >= entry['expires_at'] + ) active_entries = total_entries - expired_entries return { 'total_entries': total_entries, 'active_entries': active_entries, 'expired_entries': expired_entries, 'maxsize': self._maxsize, - 'usage_percent': (total_entries / self._maxsize * 100) if self._maxsize > 0 else 0 + 'usage_percent': (total_entries / self._maxsize * 100) if self._maxsize > 0 else 0, } def _cleanup_expired(self): current_time = self._get_current_time() expired_keys = [ - key for key, entry in self._cache.items() - if current_time >= entry['expires_at'] + key for key, entry in self._cache.items() if current_time >= entry['expires_at'] ] for key in expired_keys: del self._cache[key] if key in self._access_order: self._access_order.remove(key) if expired_keys: - logging.getLogger('doorman.cache').info(f'Cleaned up {len(expired_keys)} expired cache entries') + logging.getLogger('doorman.cache').info( + f'Cleaned up {len(expired_keys)} expired cache entries' + ) def stop_auto_save(self): return + class DoormanCacheManager: def __init__(self): - cache_flag = os.getenv('MEM_OR_EXTERNAL') if cache_flag is None: cache_flag = os.getenv('MEM_OR_REDIS', 'MEM') @@ -124,12 +127,14 @@ class DoormanCacheManager: port=redis_port, db=redis_db, decode_responses=True, - max_connections=100 + max_connections=100, ) self.cache = redis.StrictRedis(connection_pool=pool) self.is_redis = True except Exception as e: - logging.getLogger('doorman.cache').warning(f'Redis connection failed, falling back to memory cache: {e}') + logging.getLogger('doorman.cache').warning( + f'Redis connection failed, falling back to memory cache: {e}' + ) maxsize = int(os.getenv('CACHE_MAX_SIZE', 10000)) self.cache = MemoryCache(maxsize=maxsize) self.is_redis = False @@ -150,7 +155,7 @@ class DoormanCacheManager: 'endpoint_server_cache': 'endpoint_server_cache:', 'client_routing_cache': 'client_routing_cache:', 'token_def_cache': 'token_def_cache:', - 'credit_def_cache': 'credit_def_cache:' + 'credit_def_cache': 'credit_def_cache:', } self.default_ttls = { 'api_cache': 86400, @@ -167,7 +172,7 @@ class DoormanCacheManager: 'endpoint_server_cache': 86400, 'client_routing_cache': 86400, 'token_def_cache': 86400, - 'credit_def_cache': 86400 + 'credit_def_cache': 86400, } def _get_key(self, cache_name, key): @@ -255,7 +260,7 @@ class DoormanCacheManager: 'type': self.cache_type, 'is_redis': self.is_redis, 'prefixes': list(self.prefixes.keys()), - 'default_ttl': self.default_ttls + 'default_ttl': self.default_ttls, } if not self.is_redis and hasattr(self.cache, 'get_cache_stats'): info['memory_stats'] = self.cache.get_cache_stats() @@ -267,7 +272,6 @@ class DoormanCacheManager: self.cache._cleanup_expired() def force_save_cache(self): - return def stop_cache_persistence(self): @@ -317,8 +321,9 @@ class DoormanCacheManager: elif hasattr(result, 'deleted_count') and result.deleted_count > 0: self.delete_cache(cache_name, key) return result - except Exception as e: + except Exception: self.delete_cache(cache_name, key) raise + doorman_cache = DoormanCacheManager() diff --git a/backend-services/utils/encryption_util.py b/backend-services/utils/encryption_util.py index a11150e..6d1a781 100644 --- a/backend-services/utils/encryption_util.py +++ b/backend-services/utils/encryption_util.py @@ -1,10 +1,11 @@ -from typing import Optional -import os import base64 import hashlib +import os + from cryptography.fernet import Fernet -def _get_cipher() -> Optional[Fernet]: + +def _get_cipher() -> Fernet | None: """Return a Fernet cipher derived from TOKEN_ENCRYPTION_KEY or MEM_ENCRYPTION_KEY. If neither is set, returns None (plaintext compatibility mode). """ @@ -12,16 +13,15 @@ def _get_cipher() -> Optional[Fernet]: if not key: return None try: - Fernet(key) fkey = key except Exception: - digest = hashlib.sha256(key.encode('utf-8')).digest() fkey = base64.urlsafe_b64encode(digest) return Fernet(fkey) -def encrypt_value(value: Optional[str]) -> Optional[str]: + +def encrypt_value(value: str | None) -> str | None: if value is None: return None cipher = _get_cipher() @@ -30,7 +30,8 @@ def encrypt_value(value: Optional[str]) -> Optional[str]: token = cipher.encrypt(value.encode('utf-8')).decode('utf-8') return f'enc:{token}' -def decrypt_value(value: Optional[str]) -> Optional[str]: + +def decrypt_value(value: str | None) -> str | None: if value is None: return None if not isinstance(value, str): @@ -39,11 +40,9 @@ def decrypt_value(value: Optional[str]) -> Optional[str]: return value cipher = _get_cipher() if not cipher: - return None try: raw = value[4:] return cipher.decrypt(raw.encode('utf-8')).decode('utf-8') except Exception: return None - diff --git a/backend-services/utils/enhanced_metrics_util.py b/backend-services/utils/enhanced_metrics_util.py index f123ff5..b645de9 100644 --- a/backend-services/utils/enhanced_metrics_util.py +++ b/backend-services/utils/enhanced_metrics_util.py @@ -10,31 +10,26 @@ Extends the existing metrics_util.py with: """ from __future__ import annotations -import time -import os -from collections import defaultdict, deque -from typing import Dict, List, Optional, Deque -from models.analytics_models import ( - EnhancedMinuteBucket, - PercentileMetrics, - EndpointMetrics, - AnalyticsSnapshot -) +import os +import time +from collections import defaultdict, deque + +from models.analytics_models import AnalyticsSnapshot, EnhancedMinuteBucket, PercentileMetrics from utils.analytics_aggregator import analytics_aggregator class EnhancedMetricsStore: """ Enhanced version of MetricsStore with analytics capabilities. - + Backward compatible with existing metrics_util.py while adding: - Per-endpoint performance tracking - Full percentile calculations - Unique user counting - Automatic aggregation to 5-min/hourly/daily buckets """ - + def __init__(self, max_minutes: int = 60 * 24): # 24 hours of minute-level data # Global counters (backward compatible) self.total_requests: int = 0 @@ -43,60 +38,60 @@ class EnhancedMetricsStore: self.total_bytes_out: int = 0 self.total_upstream_timeouts: int = 0 self.total_retries: int = 0 - self.status_counts: Dict[int, int] = defaultdict(int) - self.username_counts: Dict[str, int] = defaultdict(int) - self.api_counts: Dict[str, int] = defaultdict(int) - + self.status_counts: dict[int, int] = defaultdict(int) + self.username_counts: dict[str, int] = defaultdict(int) + self.api_counts: dict[str, int] = defaultdict(int) + # Enhanced: Use EnhancedMinuteBucket instead of MinuteBucket - self._buckets: Deque[EnhancedMinuteBucket] = deque() + self._buckets: deque[EnhancedMinuteBucket] = deque() self._max_minutes = max_minutes - + # Track last aggregation times self._last_aggregation_check = 0 - + @staticmethod def _minute_floor(ts: float) -> int: """Floor timestamp to nearest minute.""" return int(ts // 60) * 60 - + def _ensure_bucket(self, minute_start: int) -> EnhancedMinuteBucket: """Get or create bucket for the given minute.""" if self._buckets and self._buckets[-1].start_ts == minute_start: return self._buckets[-1] - + # Create new enhanced bucket mb = EnhancedMinuteBucket(start_ts=minute_start) self._buckets.append(mb) - + # Maintain max size while len(self._buckets) > self._max_minutes: - old_bucket = self._buckets.popleft() + self._buckets.popleft() # Trigger aggregation for old buckets self._maybe_aggregate() - + return mb - + def record( self, status: int, duration_ms: float, - username: Optional[str] = None, - api_key: Optional[str] = None, - endpoint_uri: Optional[str] = None, - method: Optional[str] = None, + username: str | None = None, + api_key: str | None = None, + endpoint_uri: str | None = None, + method: str | None = None, bytes_in: int = 0, - bytes_out: int = 0 + bytes_out: int = 0, ) -> None: """ Record a request with enhanced tracking. - + Backward compatible with existing record() calls while supporting new parameters for per-endpoint tracking. """ now = time.time() minute_start = self._minute_floor(now) bucket = self._ensure_bucket(minute_start) - + # Record with enhanced tracking bucket.add_request( ms=duration_ms, @@ -106,9 +101,9 @@ class EnhancedMetricsStore: endpoint_uri=endpoint_uri, method=method, bytes_in=bytes_in, - bytes_out=bytes_out + bytes_out=bytes_out, ) - + # Update global counters (backward compatible) self.total_requests += 1 self.total_ms += duration_ms @@ -117,17 +112,17 @@ class EnhancedMetricsStore: self.total_bytes_out += int(bytes_out or 0) except Exception: pass - + self.status_counts[status] += 1 if username: self.username_counts[username] += 1 if api_key: self.api_counts[api_key] += 1 - + # Check if aggregation should run self._maybe_aggregate() - - def record_retry(self, api_key: Optional[str] = None) -> None: + + def record_retry(self, api_key: str | None = None) -> None: """Record a retry event.""" now = time.time() minute_start = self._minute_floor(now) @@ -137,8 +132,8 @@ class EnhancedMetricsStore: self.total_retries += 1 except Exception: pass - - def record_upstream_timeout(self, api_key: Optional[str] = None) -> None: + + def record_upstream_timeout(self, api_key: str | None = None) -> None: """Record an upstream timeout event.""" now = time.time() minute_start = self._minute_floor(now) @@ -148,48 +143,45 @@ class EnhancedMetricsStore: self.total_upstream_timeouts += 1 except Exception: pass - + def _maybe_aggregate(self) -> None: """ Check if aggregation should run and trigger if needed. - + Runs aggregation jobs based on time elapsed: - 5-minute aggregation: Every 5 minutes - Hourly aggregation: Every hour - Daily aggregation: Once per day """ now = int(time.time()) - + # Only check every minute to avoid overhead if now - self._last_aggregation_check < 60: return - + self._last_aggregation_check = now - + # Check what aggregations should run should_run = analytics_aggregator.should_aggregate() - + if should_run.get('5minute'): # Get last 5 minutes of buckets minute_buckets = list(self._buckets)[-5:] if minute_buckets: analytics_aggregator.aggregate_to_5minute(minute_buckets) - + if should_run.get('hourly'): analytics_aggregator.aggregate_to_hourly() - + if should_run.get('daily'): analytics_aggregator.aggregate_to_daily() - + def get_snapshot( - self, - start_ts: int, - end_ts: int, - granularity: str = 'auto' + self, start_ts: int, end_ts: int, granularity: str = 'auto' ) -> AnalyticsSnapshot: """ Get analytics snapshot for a time range. - + Automatically selects best aggregation level based on range. """ # Determine which buckets to use @@ -207,7 +199,7 @@ class EnhancedMetricsStore: buckets = [b for b in self._buckets if start_ts <= b.start_ts <= end_ts] else: buckets = analytics_aggregator.get_buckets_for_range(start_ts, end_ts) - + if not buckets: # Return empty snapshot return AnalyticsSnapshot( @@ -225,98 +217,107 @@ class EnhancedMetricsStore: top_apis=[], top_users=[], top_endpoints=[], - status_distribution={} + status_distribution={}, ) - + # Aggregate data from buckets total_requests = sum(b.count for b in buckets) total_errors = sum(b.error_count for b in buckets) total_ms = sum(b.total_ms for b in buckets) total_bytes_in = sum(b.bytes_in for b in buckets) total_bytes_out = sum(b.bytes_out for b in buckets) - + # Collect all latencies for percentile calculation - all_latencies: List[float] = [] + all_latencies: list[float] = [] for bucket in buckets: all_latencies.extend(list(bucket.latencies)) - - percentiles = PercentileMetrics.calculate(all_latencies) if all_latencies else PercentileMetrics() - + + percentiles = ( + PercentileMetrics.calculate(all_latencies) if all_latencies else PercentileMetrics() + ) + # Count unique users unique_users = set() for bucket in buckets: unique_users.update(bucket.unique_users) - + # Aggregate status counts - status_distribution: Dict[str, int] = defaultdict(int) + status_distribution: dict[str, int] = defaultdict(int) for bucket in buckets: for status, count in bucket.status_counts.items(): status_distribution[str(status)] += count - + # Aggregate API counts - api_counts: Dict[str, int] = defaultdict(int) + api_counts: dict[str, int] = defaultdict(int) for bucket in buckets: for api, count in bucket.api_counts.items(): api_counts[api] += count - + # Aggregate user counts - user_counts: Dict[str, int] = defaultdict(int) + user_counts: dict[str, int] = defaultdict(int) for bucket in buckets: for user, count in bucket.user_counts.items(): user_counts[user] += count - + # Aggregate endpoint metrics - endpoint_metrics: Dict[str, Dict] = defaultdict(lambda: { - 'count': 0, - 'error_count': 0, - 'total_ms': 0.0, - 'latencies': [] - }) + endpoint_metrics: dict[str, dict] = defaultdict( + lambda: {'count': 0, 'error_count': 0, 'total_ms': 0.0, 'latencies': []} + ) for bucket in buckets: for endpoint_key, ep_metrics in bucket.endpoint_metrics.items(): endpoint_metrics[endpoint_key]['count'] += ep_metrics.count endpoint_metrics[endpoint_key]['error_count'] += ep_metrics.error_count endpoint_metrics[endpoint_key]['total_ms'] += ep_metrics.total_ms endpoint_metrics[endpoint_key]['latencies'].extend(list(ep_metrics.latencies)) - + # Build top endpoints list top_endpoints = [] for endpoint_key, metrics in endpoint_metrics.items(): method, uri = endpoint_key.split(':', 1) avg_ms = metrics['total_ms'] / metrics['count'] if metrics['count'] > 0 else 0.0 - ep_percentiles = PercentileMetrics.calculate(metrics['latencies']) if metrics['latencies'] else PercentileMetrics() - - top_endpoints.append({ - 'endpoint_uri': uri, - 'method': method, - 'count': metrics['count'], - 'error_count': metrics['error_count'], - 'error_rate': metrics['error_count'] / metrics['count'] if metrics['count'] > 0 else 0.0, - 'avg_ms': avg_ms, - 'percentiles': ep_percentiles.to_dict() - }) - + ep_percentiles = ( + PercentileMetrics.calculate(metrics['latencies']) + if metrics['latencies'] + else PercentileMetrics() + ) + + top_endpoints.append( + { + 'endpoint_uri': uri, + 'method': method, + 'count': metrics['count'], + 'error_count': metrics['error_count'], + 'error_rate': metrics['error_count'] / metrics['count'] + if metrics['count'] > 0 + else 0.0, + 'avg_ms': avg_ms, + 'percentiles': ep_percentiles.to_dict(), + } + ) + # Sort by count (most used) top_endpoints.sort(key=lambda x: x['count'], reverse=True) - + # Build time-series data series = [] for bucket in buckets: avg_ms = bucket.total_ms / bucket.count if bucket.count > 0 else 0.0 bucket_percentiles = bucket.get_percentiles() - - series.append({ - 'timestamp': bucket.start_ts, - 'count': bucket.count, - 'error_count': bucket.error_count, - 'error_rate': bucket.error_count / bucket.count if bucket.count > 0 else 0.0, - 'avg_ms': avg_ms, - 'percentiles': bucket_percentiles.to_dict(), - 'bytes_in': bucket.bytes_in, - 'bytes_out': bucket.bytes_out, - 'unique_users': bucket.get_unique_user_count() - }) - + + series.append( + { + 'timestamp': bucket.start_ts, + 'count': bucket.count, + 'error_count': bucket.error_count, + 'error_rate': bucket.error_count / bucket.count if bucket.count > 0 else 0.0, + 'avg_ms': avg_ms, + 'percentiles': bucket_percentiles.to_dict(), + 'bytes_in': bucket.bytes_in, + 'bytes_out': bucket.bytes_out, + 'unique_users': bucket.get_unique_user_count(), + } + ) + # Create snapshot return AnalyticsSnapshot( start_ts=start_ts, @@ -333,32 +334,30 @@ class EnhancedMetricsStore: top_apis=sorted(api_counts.items(), key=lambda x: x[1], reverse=True)[:10], top_users=sorted(user_counts.items(), key=lambda x: x[1], reverse=True)[:10], top_endpoints=top_endpoints[:10], - status_distribution=dict(status_distribution) + status_distribution=dict(status_distribution), ) - - def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> Dict: + + def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> dict: """ Backward compatible snapshot method for existing monitor endpoints. """ - range_to_minutes = { - '1h': 60, - '24h': 60 * 24, - '7d': 60 * 24 * 7, - '30d': 60 * 24 * 30, - } + range_to_minutes = {'1h': 60, '24h': 60 * 24, '7d': 60 * 24 * 7, '30d': 60 * 24 * 30} minutes = range_to_minutes.get(range_key, 60 * 24) - buckets: List[EnhancedMinuteBucket] = list(self._buckets)[-minutes:] + buckets: list[EnhancedMinuteBucket] = list(self._buckets)[-minutes:] series = [] - + if group == 'day': from collections import defaultdict - day_map: Dict[int, Dict[str, float]] = defaultdict(lambda: { - 'count': 0, - 'error_count': 0, - 'total_ms': 0.0, - 'bytes_in': 0, - 'bytes_out': 0, - }) + + day_map: dict[int, dict[str, float]] = defaultdict( + lambda: { + 'count': 0, + 'error_count': 0, + 'total_ms': 0.0, + 'bytes_in': 0, + 'bytes_out': 0, + } + ) for b in buckets: day_ts = int((b.start_ts // 86400) * 86400) d = day_map[day_ts] @@ -369,38 +368,44 @@ class EnhancedMetricsStore: d['bytes_out'] += b.bytes_out for day_ts, d in day_map.items(): avg_ms = (d['total_ms'] / d['count']) if d['count'] else 0.0 - series.append({ - 'timestamp': day_ts, - 'count': int(d['count']), - 'error_count': int(d['error_count']), - 'avg_ms': avg_ms, - 'bytes_in': int(d['bytes_in']), - 'bytes_out': int(d['bytes_out']), - 'error_rate': (int(d['error_count']) / int(d['count'])) if d['count'] else 0.0, - }) + series.append( + { + 'timestamp': day_ts, + 'count': int(d['count']), + 'error_count': int(d['error_count']), + 'avg_ms': avg_ms, + 'bytes_in': int(d['bytes_in']), + 'bytes_out': int(d['bytes_out']), + 'error_rate': (int(d['error_count']) / int(d['count'])) + if d['count'] + else 0.0, + } + ) else: for b in buckets: avg_ms = (b.total_ms / b.count) if b.count else 0.0 percentiles = b.get_percentiles() - series.append({ - 'timestamp': b.start_ts, - 'count': b.count, - 'error_count': b.error_count, - 'avg_ms': avg_ms, - 'p95_ms': percentiles.p95, - 'bytes_in': b.bytes_in, - 'bytes_out': b.bytes_out, - 'error_rate': (b.error_count / b.count) if b.count else 0.0, - 'upstream_timeouts': b.upstream_timeouts, - 'retries': b.retries, - }) - - reverse = (str(sort).lower() == 'desc') + series.append( + { + 'timestamp': b.start_ts, + 'count': b.count, + 'error_count': b.error_count, + 'avg_ms': avg_ms, + 'p95_ms': percentiles.p95, + 'bytes_in': b.bytes_in, + 'bytes_out': b.bytes_out, + 'error_rate': (b.error_count / b.count) if b.count else 0.0, + 'upstream_timeouts': b.upstream_timeouts, + 'retries': b.retries, + } + ) + + reverse = str(sort).lower() == 'desc' try: series.sort(key=lambda x: x.get('timestamp', 0), reverse=reverse) except Exception: pass - + total = self.total_requests avg_total_ms = (self.total_ms / total) if total else 0.0 status = {str(k): v for k, v in self.status_counts.items()} @@ -413,33 +418,36 @@ class EnhancedMetricsStore: 'total_retries': self.total_retries, 'status_counts': status, 'series': series, - 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[:10], + 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[ + :10 + ], 'top_apis': sorted(self.api_counts.items(), key=lambda kv: kv[1], reverse=True)[:10], } - + def save_to_file(self, path: str) -> None: """Save metrics to file for persistence.""" try: os.makedirs(os.path.dirname(path), exist_ok=True) except Exception: pass - + try: import json + tmp = path + '.tmp' data = { 'total_requests': self.total_requests, 'total_ms': self.total_ms, 'total_bytes_in': self.total_bytes_in, 'total_bytes_out': self.total_bytes_out, - 'buckets': [b.to_dict() for b in list(self._buckets)] + 'buckets': [b.to_dict() for b in list(self._buckets)], } with open(tmp, 'w', encoding='utf-8') as f: json.dump(data, f) os.replace(tmp, path) except Exception: pass - + def load_from_file(self, path: str) -> None: """Load metrics from file.""" # Note: Simplified version - full implementation would reconstruct diff --git a/backend-services/utils/error_codes.py b/backend-services/utils/error_codes.py index 5a40ecf..8c4e0c0 100644 --- a/backend-services/utils/error_codes.py +++ b/backend-services/utils/error_codes.py @@ -22,6 +22,7 @@ Usage: ).dict() """ + class ErrorCode: """ Centralized error code constants. @@ -217,7 +218,9 @@ class ErrorCode: RATE_LIMIT_EXCEEDED = 'RATE_LIMIT_EXCEEDED' + # Alias for backward compatibility class ErrorCodes(ErrorCode): """Deprecated: Use ErrorCode instead.""" + pass diff --git a/backend-services/utils/error_util.py b/backend-services/utils/error_util.py index d84efb0..c7bfb96 100644 --- a/backend-services/utils/error_util.py +++ b/backend-services/utils/error_util.py @@ -2,18 +2,20 @@ Standardized error response utilities """ +from typing import Any + from models.response_model import ResponseModel from utils.response_util import process_response -from typing import Optional, Dict, Any + def create_error_response( status_code: int, error_code: str, error_message: str, - request_id: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - api_type: str = 'rest' -) -> Dict[str, Any]: + request_id: str | None = None, + data: dict[str, Any] | None = None, + api_type: str = 'rest', +) -> dict[str, Any]: """ Create a standardized error response using ResponseModel. """ @@ -25,17 +27,18 @@ def create_error_response( response_headers=response_headers, error_code=error_code, error_message=error_message, - response=data + response=data, ) return process_response(response_model.dict(), api_type) + def success_response( status_code: int = 200, - message: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - api_type: str = 'rest' -) -> Dict[str, Any]: + message: str | None = None, + data: dict[str, Any] | None = None, + request_id: str | None = None, + api_type: str = 'rest', +) -> dict[str, Any]: """ Create a standardized success response using ResponseModel. """ @@ -43,9 +46,6 @@ def success_response( if request_id: response_headers['request_id'] = request_id response_model = ResponseModel( - status_code=status_code, - response_headers=response_headers, - message=message, - response=data + status_code=status_code, response_headers=response_headers, message=message, response=data ) return process_response(response_model.dict(), api_type) diff --git a/backend-services/utils/gateway_utils.py b/backend-services/utils/gateway_utils.py index 7be23ce..92c28ac 100644 --- a/backend-services/utils/gateway_utils.py +++ b/backend-services/utils/gateway_utils.py @@ -1,7 +1,7 @@ -import re -from typing import Dict, List -from fastapi import Request import logging +import re + +from fastapi import Request _logger = logging.getLogger('doorman.gateway') @@ -17,6 +17,7 @@ SENSITIVE_HEADERS = { 'csrf-token', } + def sanitize_headers(value: str): """Sanitize header values to prevent injection attacks. @@ -37,6 +38,7 @@ def sanitize_headers(value: str): except Exception: return '' + def redact_sensitive_header(header_name: str, header_value: str) -> str: """Redact sensitive header values for logging purposes. @@ -69,7 +71,8 @@ def redact_sensitive_header(header_name: str, header_value: str) -> str: except Exception: return '[REDACTION_ERROR]' -def log_headers_safely(request: Request, allowed_headers: List[str] = None, redact: bool = True): + +def log_headers_safely(request: Request, allowed_headers: list[str] = None, redact: bool = True): """Log request headers safely with redaction. Args: @@ -100,12 +103,13 @@ def log_headers_safely(request: Request, allowed_headers: List[str] = None, reda headers_to_log[key] = sanitized if headers_to_log: - _logger.debug(f"Request headers: {headers_to_log}") + _logger.debug(f'Request headers: {headers_to_log}') except Exception as e: - _logger.debug(f"Failed to log headers safely: {e}") + _logger.debug(f'Failed to log headers safely: {e}') -async def get_headers(request: Request, allowed_headers: List[str]): + +async def get_headers(request: Request, allowed_headers: list[str]): """Extract and sanitize allowed headers from request. This function is used for forwarding headers to upstream services. diff --git a/backend-services/utils/geo_lookup.py b/backend-services/utils/geo_lookup.py index 96ed6f9..6967c0c 100644 --- a/backend-services/utils/geo_lookup.py +++ b/backend-services/utils/geo_lookup.py @@ -6,9 +6,7 @@ Uses MaxMind GeoLite2 database (free) or can be extended to use commercial servi """ import logging -from typing import Optional, Dict, List, Set from dataclasses import dataclass -from datetime import datetime from utils.redis_client import RedisClient, get_redis_client @@ -18,45 +16,46 @@ logger = logging.getLogger(__name__) @dataclass class GeoLocation: """Geographic location information""" + ip: str - country_code: Optional[str] = None - country_name: Optional[str] = None - region: Optional[str] = None - city: Optional[str] = None - latitude: Optional[float] = None - longitude: Optional[float] = None - timezone: Optional[str] = None + country_code: str | None = None + country_name: str | None = None + region: str | None = None + city: str | None = None + latitude: float | None = None + longitude: float | None = None + timezone: str | None = None class GeoLookup: """ Geographic IP lookup and country-based rate limiting - + Note: This is a simplified implementation. In production, integrate with: - MaxMind GeoLite2 (free): https://dev.maxmind.com/geoip/geolite2-free-geolocation-data - MaxMind GeoIP2 (paid): More accurate - IP2Location: Alternative service """ - - def __init__(self, redis_client: Optional[RedisClient] = None): + + def __init__(self, redis_client: RedisClient | None = None): """Initialize geo lookup""" self.redis = redis_client or get_redis_client() - + # Cache TTL for geo lookups (24 hours) self.cache_ttl = 86400 - + def lookup_ip(self, ip: str) -> GeoLocation: """ Lookup geographic information for IP - + In production, this would use MaxMind GeoIP2 or similar service. For now, returns cached data or placeholder. """ try: # Check cache first - cache_key = f"geo:cache:{ip}" + cache_key = f'geo:cache:{ip}' cached = self.redis.hgetall(cache_key) - + if cached: return GeoLocation( ip=ip, @@ -66,9 +65,9 @@ class GeoLookup: city=cached.get('city'), latitude=float(cached['latitude']) if cached.get('latitude') else None, longitude=float(cached['longitude']) if cached.get('longitude') else None, - timezone=cached.get('timezone') + timezone=cached.get('timezone'), ) - + # TODO: Integrate with MaxMind GeoIP2 # Example integration: # import geoip2.database @@ -79,23 +78,23 @@ class GeoLookup: # city = response.city.name # latitude = response.location.latitude # longitude = response.location.longitude - + # For now, return placeholder with unknown location geo = GeoLocation(ip=ip, country_code='UNKNOWN') - + # Cache the result self._cache_geo_data(ip, geo) - + return geo - + except Exception as e: - logger.error(f"Error looking up IP geolocation: {e}") + logger.error(f'Error looking up IP geolocation: {e}') return GeoLocation(ip=ip) - + def _cache_geo_data(self, ip: str, geo: GeoLocation) -> None: """Cache geo data in Redis""" try: - cache_key = f"geo:cache:{ip}" + cache_key = f'geo:cache:{ip}' data = { 'country_code': geo.country_code or '', 'country_name': geo.country_name or '', @@ -103,197 +102,197 @@ class GeoLookup: 'city': geo.city or '', 'latitude': str(geo.latitude) if geo.latitude else '', 'longitude': str(geo.longitude) if geo.longitude else '', - 'timezone': geo.timezone or '' + 'timezone': geo.timezone or '', } self.redis.hmset(cache_key, data) self.redis.expire(cache_key, self.cache_ttl) except Exception as e: - logger.error(f"Error caching geo data: {e}") - + logger.error(f'Error caching geo data: {e}') + def is_country_blocked(self, country_code: str) -> bool: """Check if country is blocked""" try: return bool(self.redis.sismember('geo:blocked_countries', country_code)) except Exception as e: - logger.error(f"Error checking blocked country: {e}") + logger.error(f'Error checking blocked country: {e}') return False - + def is_country_allowed(self, country_code: str) -> bool: """ Check if country is in allowlist - + If allowlist is empty, all countries are allowed. If allowlist has entries, only those countries are allowed. """ try: # Get allowlist allowed = self.redis.smembers('geo:allowed_countries') - + # If no allowlist, all countries allowed if not allowed: return True - + # Check if country in allowlist return country_code in allowed except Exception as e: - logger.error(f"Error checking allowed country: {e}") + logger.error(f'Error checking allowed country: {e}') return True - + def block_country(self, country_code: str) -> bool: """Add country to blocklist""" try: self.redis.sadd('geo:blocked_countries', country_code) - logger.info(f"Blocked country: {country_code}") + logger.info(f'Blocked country: {country_code}') return True except Exception as e: - logger.error(f"Error blocking country: {e}") + logger.error(f'Error blocking country: {e}') return False - + def unblock_country(self, country_code: str) -> bool: """Remove country from blocklist""" try: self.redis.srem('geo:blocked_countries', country_code) - logger.info(f"Unblocked country: {country_code}") + logger.info(f'Unblocked country: {country_code}') return True except Exception as e: - logger.error(f"Error unblocking country: {e}") + logger.error(f'Error unblocking country: {e}') return False - + def add_to_allowlist(self, country_code: str) -> bool: """Add country to allowlist""" try: self.redis.sadd('geo:allowed_countries', country_code) - logger.info(f"Added country to allowlist: {country_code}") + logger.info(f'Added country to allowlist: {country_code}') return True except Exception as e: - logger.error(f"Error adding to allowlist: {e}") + logger.error(f'Error adding to allowlist: {e}') return False - + def remove_from_allowlist(self, country_code: str) -> bool: """Remove country from allowlist""" try: self.redis.srem('geo:allowed_countries', country_code) - logger.info(f"Removed country from allowlist: {country_code}") + logger.info(f'Removed country from allowlist: {country_code}') return True except Exception as e: - logger.error(f"Error removing from allowlist: {e}") + logger.error(f'Error removing from allowlist: {e}') return False - - def get_country_rate_limit(self, country_code: str) -> Optional[int]: + + def get_country_rate_limit(self, country_code: str) -> int | None: """ Get custom rate limit for country - + Returns None if no custom limit set. """ try: - limit_key = f"geo:rate_limit:{country_code}" + limit_key = f'geo:rate_limit:{country_code}' limit = self.redis.get(limit_key) return int(limit) if limit else None except Exception as e: - logger.error(f"Error getting country rate limit: {e}") + logger.error(f'Error getting country rate limit: {e}') return None - + def set_country_rate_limit(self, country_code: str, limit: int) -> bool: """Set custom rate limit for country""" try: - limit_key = f"geo:rate_limit:{country_code}" + limit_key = f'geo:rate_limit:{country_code}' self.redis.set(limit_key, str(limit)) - logger.info(f"Set rate limit for {country_code}: {limit}") + logger.info(f'Set rate limit for {country_code}: {limit}') return True except Exception as e: - logger.error(f"Error setting country rate limit: {e}") + logger.error(f'Error setting country rate limit: {e}') return False - - def check_geographic_access(self, ip: str) -> tuple[bool, Optional[str]]: + + def check_geographic_access(self, ip: str) -> tuple[bool, str | None]: """ Check if IP's geographic location is allowed - + Returns: (allowed, reason) tuple """ try: geo = self.lookup_ip(ip) - + if not geo.country_code or geo.country_code == 'UNKNOWN': # Allow unknown locations by default return True, None - + # Check blocklist if self.is_country_blocked(geo.country_code): - return False, f"Country {geo.country_code} is blocked" - + return False, f'Country {geo.country_code} is blocked' + # Check allowlist if not self.is_country_allowed(geo.country_code): - return False, f"Country {geo.country_code} is not in allowlist" - + return False, f'Country {geo.country_code} is not in allowlist' + return True, None - + except Exception as e: - logger.error(f"Error checking geographic access: {e}") + logger.error(f'Error checking geographic access: {e}') # Allow by default on error return True, None - + def track_country_request(self, country_code: str) -> None: """Track request from country for analytics""" try: - counter_key = f"geo:requests:{country_code}" + counter_key = f'geo:requests:{country_code}' self.redis.incr(counter_key) self.redis.expire(counter_key, 86400) # 24 hour window except Exception as e: - logger.error(f"Error tracking country request: {e}") - - def get_geographic_distribution(self) -> List[tuple]: + logger.error(f'Error tracking country request: {e}') + + def get_geographic_distribution(self) -> list[tuple]: """ Get request distribution by country - + Returns: List of (country_code, request_count) tuples """ try: - pattern = "geo:requests:*" + pattern = 'geo:requests:*' keys = [] - + cursor = 0 while True: cursor, batch = self.redis.scan(cursor, match=pattern, count=100) keys.extend(batch) if cursor == 0: break - + # Get counts for each country country_counts = [] for key in keys: country_code = key.replace('geo:requests:', '') count = int(self.redis.get(key) or 0) country_counts.append((country_code, count)) - + # Sort by count country_counts.sort(key=lambda x: x[1], reverse=True) return country_counts - + except Exception as e: - logger.error(f"Error getting geographic distribution: {e}") + logger.error(f'Error getting geographic distribution: {e}') return [] - - def get_blocked_countries(self) -> Set[str]: + + def get_blocked_countries(self) -> set[str]: """Get list of blocked countries""" try: return self.redis.smembers('geo:blocked_countries') except Exception as e: - logger.error(f"Error getting blocked countries: {e}") + logger.error(f'Error getting blocked countries: {e}') return set() - - def get_allowed_countries(self) -> Set[str]: + + def get_allowed_countries(self) -> set[str]: """Get list of allowed countries""" try: return self.redis.smembers('geo:allowed_countries') except Exception as e: - logger.error(f"Error getting allowed countries: {e}") + logger.error(f'Error getting allowed countries: {e}') return set() # Global instance -_geo_lookup: Optional[GeoLookup] = None +_geo_lookup: GeoLookup | None = None def get_geo_lookup() -> GeoLookup: diff --git a/backend-services/utils/group_util.py b/backend-services/utils/group_util.py index 66e70bd..602da96 100644 --- a/backend-services/utils/group_util.py +++ b/backend-services/utils/group_util.py @@ -5,17 +5,19 @@ See https://github.com/pypeople-dev/doorman for more information """ import logging + from fastapi import HTTPException, Request -from utils.doorman_cache_util import doorman_cache from services.user_service import UserService -from utils.database_async import api_collection from utils.async_db import db_find_one from utils.auth_util import auth_required +from utils.database_async import api_collection +from utils.doorman_cache_util import doorman_cache logger = logging.getLogger('doorman.gateway') -async def group_required(request: Request = None, full_path: str = None, user_to_subscribe = None): + +async def group_required(request: Request = None, full_path: str = None, user_to_subscribe=None): try: payload = await auth_required(request) username = payload.get('sub') @@ -37,7 +39,7 @@ async def group_required(request: Request = None, full_path: str = None, user_to prefix = '/api/grpc/' if request: postfix = '/' + request.headers.get('X-API-Version', 'v0') - path = full_path[len(prefix):] if full_path.startswith(prefix) else full_path + path = full_path[len(prefix) :] if full_path.startswith(prefix) else full_path api_and_version = '/'.join(path.split('/')[:2]) + postfix if not api_and_version or '/' not in api_and_version: raise HTTPException(status_code=404, detail='Invalid API path format') @@ -46,11 +48,15 @@ async def group_required(request: Request = None, full_path: str = None, user_to user = await UserService.get_user_by_username_helper(user_to_subscribe) else: user = await UserService.get_user_by_username_helper(username) - api = doorman_cache.get_cache('api_cache', api_and_version) or await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version}) + api = doorman_cache.get_cache('api_cache', api_and_version) or await db_find_one( + api_collection, {'api_name': api_name, 'api_version': api_version} + ) if not api: raise HTTPException(status_code=404, detail='API not found') if not set(user.get('groups') or []).intersection(api.get('api_allowed_groups') or []): - raise HTTPException(status_code=401, detail='You do not have the correct group for this') + raise HTTPException( + status_code=401, detail='You do not have the correct group for this' + ) except HTTPException as e: raise HTTPException(status_code=e.status_code, detail=e.detail) return payload diff --git a/backend-services/utils/health_check_util.py b/backend-services/utils/health_check_util.py index f314a16..4a104e1 100644 --- a/backend-services/utils/health_check_util.py +++ b/backend-services/utils/health_check_util.py @@ -4,23 +4,24 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -import psutil -import time import logging -from datetime import timedelta -from redis.asyncio import Redis import os +import time +from datetime import timedelta -from utils.database import mongodb_client, database +import psutil +from redis.asyncio import Redis + +from utils.database import database, mongodb_client from utils.doorman_cache_util import doorman_cache logger = logging.getLogger('doorman.gateway') START_TIME = time.time() + async def check_mongodb(): try: - if database.memory_only: return True if mongodb_client is None: @@ -31,14 +32,14 @@ async def check_mongodb(): logger.error(f'MongoDB health check failed: {str(e)}') return False + async def check_redis(): try: - if not getattr(doorman_cache, 'is_redis', False): return True redis = Redis.from_url( f'redis://{os.getenv("REDIS_HOST")}:{os.getenv("REDIS_PORT")}/{os.getenv("REDIS_DB")}', - decode_responses=True + decode_responses=True, ) await redis.ping() @@ -47,6 +48,7 @@ async def check_redis(): logger.error(f'Redis health check failed: {str(e)}') return False + def get_memory_usage(): try: process = psutil.Process(os.getpid()) @@ -58,6 +60,7 @@ def get_memory_usage(): logger.error(f'Memory usage check failed: {str(e)}') return 'unknown' + def get_active_connections(): try: process = psutil.Process(os.getpid()) @@ -67,6 +70,7 @@ def get_active_connections(): logger.error(f'Active connections check failed: {str(e)}') return 0 + def get_uptime(): try: uptime_seconds = time.time() - START_TIME diff --git a/backend-services/utils/hot_reload_config.py b/backend-services/utils/hot_reload_config.py index d8c7ba1..c43bb82 100644 --- a/backend-services/utils/hot_reload_config.py +++ b/backend-services/utils/hot_reload_config.py @@ -18,16 +18,19 @@ Usage: hot_config.reload() """ -import os import json -import yaml import logging +import os import threading -from typing import Any, Dict, Callable, Optional +from collections.abc import Callable from pathlib import Path +from typing import Any + +import yaml logger = logging.getLogger('doorman.gateway') + class HotReloadConfig: """ Thread-safe configuration manager with hot reload support. @@ -39,10 +42,10 @@ class HotReloadConfig: - Callbacks for configuration changes """ - def __init__(self, config_file: Optional[str] = None): + def __init__(self, config_file: str | None = None): self._lock = threading.RLock() - self._config: Dict[str, Any] = {} - self._callbacks: Dict[str, list] = {} + self._config: dict[str, Any] = {} + self._callbacks: dict[str, list] = {} self._config_file = config_file or os.getenv('DOORMAN_CONFIG_FILE') self._load_initial_config() @@ -64,7 +67,7 @@ class HotReloadConfig: """Load configuration from YAML or JSON file""" path = Path(filepath) - with open(filepath, 'r') as f: + with open(filepath) as f: if path.suffix in ['.yaml', '.yml']: file_config = yaml.safe_load(f) or {} elif path.suffix == '.json': @@ -80,29 +83,22 @@ class HotReloadConfig: 'LOG_LEVEL', 'LOG_FORMAT', 'LOG_FILE', - 'GATEWAY_TIMEOUT', 'UPSTREAM_TIMEOUT', 'CONNECTION_TIMEOUT', - 'RATE_LIMIT_ENABLED', 'RATE_LIMIT_REQUESTS', 'RATE_LIMIT_WINDOW', - 'CACHE_TTL', 'CACHE_MAX_SIZE', - 'CIRCUIT_BREAKER_ENABLED', 'CIRCUIT_BREAKER_THRESHOLD', 'CIRCUIT_BREAKER_TIMEOUT', - 'RETRY_ENABLED', 'RETRY_MAX_ATTEMPTS', 'RETRY_BACKOFF', - 'METRICS_ENABLED', 'METRICS_INTERVAL', - 'FEATURE_REQUEST_REPLAY', 'FEATURE_AB_TESTING', 'FEATURE_COST_ANALYTICS', @@ -249,7 +245,7 @@ class HotReloadConfig: logger.info('Configuration reload complete') - def dump(self) -> Dict[str, Any]: + def dump(self) -> dict[str, Any]: """Dump current configuration (for debugging)""" with self._lock: config = self._config.copy() @@ -259,10 +255,12 @@ class HotReloadConfig: config[key] = self._parse_value(env_value) return config + hot_config = HotReloadConfig() + # Convenience functions for common config patterns -def get_timeout_config() -> Dict[str, int]: +def get_timeout_config() -> dict[str, int]: """Get all timeout configurations""" return { 'gateway_timeout': hot_config.get_int('GATEWAY_TIMEOUT', 30), @@ -270,7 +268,8 @@ def get_timeout_config() -> Dict[str, int]: 'connection_timeout': hot_config.get_int('CONNECTION_TIMEOUT', 10), } -def get_rate_limit_config() -> Dict[str, Any]: + +def get_rate_limit_config() -> dict[str, Any]: """Get rate limiting configuration""" return { 'enabled': hot_config.get_bool('RATE_LIMIT_ENABLED', True), @@ -278,14 +277,16 @@ def get_rate_limit_config() -> Dict[str, Any]: 'window': hot_config.get_int('RATE_LIMIT_WINDOW', 60), } -def get_cache_config() -> Dict[str, Any]: + +def get_cache_config() -> dict[str, Any]: """Get cache configuration""" return { 'ttl': hot_config.get_int('CACHE_TTL', 300), 'max_size': hot_config.get_int('CACHE_MAX_SIZE', 1000), } -def get_circuit_breaker_config() -> Dict[str, Any]: + +def get_circuit_breaker_config() -> dict[str, Any]: """Get circuit breaker configuration""" return { 'enabled': hot_config.get_bool('CIRCUIT_BREAKER_ENABLED', True), @@ -293,7 +294,8 @@ def get_circuit_breaker_config() -> Dict[str, Any]: 'timeout': hot_config.get_int('CIRCUIT_BREAKER_TIMEOUT', 60), } -def get_retry_config() -> Dict[str, Any]: + +def get_retry_config() -> dict[str, Any]: """Get retry configuration""" return { 'enabled': hot_config.get_bool('RETRY_ENABLED', True), diff --git a/backend-services/utils/http_client.py b/backend-services/utils/http_client.py index f3985d3..1d1919f 100644 --- a/backend-services/utils/http_client.py +++ b/backend-services/utils/http_client.py @@ -15,30 +15,34 @@ Usage: from __future__ import annotations import asyncio +import logging import os import random import time from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any import httpx -import logging + from utils.metrics_util import metrics_store logger = logging.getLogger('doorman.gateway') + class CircuitOpenError(Exception): pass + @dataclass class _BreakerState: failures: int = 0 opened_at: float = 0.0 state: str = 'closed' + class _CircuitManager: def __init__(self) -> None: - self._states: Dict[str, _BreakerState] = {} + self._states: dict[str, _BreakerState] = {} def get(self, key: str) -> _BreakerState: st = self._states.get(key) @@ -75,9 +79,11 @@ class _CircuitManager: st.state = 'open' st.opened_at = self.now() + circuit_manager = _CircuitManager() -def _build_timeout(api_config: Optional[dict]) -> httpx.Timeout: + +def _build_timeout(api_config: dict | None) -> httpx.Timeout: # Per-API overrides if present on document; otherwise env defaults def _f(key: str, env_key: str, default: float) -> float: try: @@ -93,28 +99,31 @@ def _build_timeout(api_config: Optional[dict]) -> httpx.Timeout: pool = _f('api_pool_timeout', 'HTTP_TIMEOUT', 30.0) return httpx.Timeout(connect=connect, read=read, write=write, pool=pool) + def _should_retry_status(status: int) -> bool: return status in (500, 502, 503, 504) + def _backoff_delay(attempt: int) -> float: base = float(os.getenv('HTTP_RETRY_BASE_DELAY', 0.25)) cap = float(os.getenv('HTTP_RETRY_MAX_DELAY', 2.0)) delay = min(cap, base * (2 ** max(0, attempt - 1))) return random.uniform(0, delay) + async def request_with_resilience( client: httpx.AsyncClient, method: str, url: str, *, api_key: str, - headers: Optional[Dict[str, str]] = None, - params: Optional[Dict[str, Any]] = None, + headers: dict[str, str] | None = None, + params: dict[str, Any] | None = None, data: Any = None, json: Any = None, content: Any = None, retries: int = 0, - api_config: Optional[dict] = None, + api_config: dict | None = None, ) -> httpx.Response: """Perform an HTTP request with retries, backoff, and circuit breaker. @@ -132,8 +141,8 @@ async def request_with_resilience( if enabled: circuit_manager.check(api_key, open_seconds) - last_exc: Optional[BaseException] = None - response: Optional[httpx.Response] = None + last_exc: BaseException | None = None + response: httpx.Response | None = None for attempt in range(1, attempts + 1): if attempt > 1: try: @@ -143,13 +152,18 @@ async def request_with_resilience( await asyncio.sleep(_backoff_delay(attempt)) try: try: - requester = getattr(client, 'request') + requester = client.request except Exception: requester = None if requester is not None: response = await requester( - method.upper(), url, - headers=headers, params=params, data=data, json=json, content=content, + method.upper(), + url, + headers=headers, + params=params, + data=data, + json=json, + content=content, timeout=timeout, ) else: @@ -165,10 +179,7 @@ async def request_with_resilience( kwargs['json'] = json elif data is not None: kwargs['json'] = data - response = await meth( - url, - **kwargs, - ) + response = await meth(url, **kwargs) if _should_retry_status(response.status_code) and attempt < attempts: if enabled: diff --git a/backend-services/utils/ip_policy_util.py b/backend-services/utils/ip_policy_util.py index 404b222..e64bf50 100644 --- a/backend-services/utils/ip_policy_util.py +++ b/backend-services/utils/ip_policy_util.py @@ -1,13 +1,14 @@ from __future__ import annotations -from fastapi import Request, HTTPException -from typing import Optional, List - -from utils.security_settings_util import get_cached_settings -from utils.audit_util import audit import os -def _get_client_ip(request: Request, trust_xff: bool) -> Optional[str]: +from fastapi import HTTPException, Request + +from utils.audit_util import audit +from utils.security_settings_util import get_cached_settings + + +def _get_client_ip(request: Request, trust_xff: bool) -> str | None: """Determine client IP with optional proxy trust. When `trust_xff` is True, this prefers headers supplied by trusted proxies. @@ -28,7 +29,14 @@ def _get_client_ip(request: Request, trust_xff: bool) -> Optional[str]: return _ip_in_list(src_ip, trusted) if src_ip else False if trust_xff and _from_trusted_proxy(): - for header in ('x-forwarded-for', 'X-Forwarded-For', 'x-real-ip', 'X-Real-IP', 'cf-connecting-ip', 'CF-Connecting-IP'): + for header in ( + 'x-forwarded-for', + 'X-Forwarded-For', + 'x-real-ip', + 'X-Real-IP', + 'cf-connecting-ip', + 'CF-Connecting-IP', + ): val = request.headers.get(header) if val: ip = val.split(',')[0].strip() @@ -38,11 +46,13 @@ def _get_client_ip(request: Request, trust_xff: bool) -> Optional[str]: except Exception: return request.client.host if request.client else None -def _ip_in_list(ip: str, patterns: List[str]) -> bool: + +def _ip_in_list(ip: str, patterns: list[str]) -> bool: try: import ipaddress + ip_obj = ipaddress.ip_address(ip) - for pat in (patterns or []): + for pat in patterns or []: p = (pat or '').strip() if not p: continue @@ -60,17 +70,20 @@ def _ip_in_list(ip: str, patterns: List[str]) -> bool: except Exception: return False -def _is_loopback(ip: Optional[str]) -> bool: + +def _is_loopback(ip: str | None) -> bool: try: if not ip: return False if ip in ('testserver', 'localhost'): return True import ipaddress + return ipaddress.ip_address(ip).is_loopback except Exception: return False + def enforce_api_ip_policy(request: Request, api: dict): """ Enforce per-API IP policy. @@ -81,16 +94,36 @@ def enforce_api_ip_policy(request: Request, api: dict): """ try: settings = get_cached_settings() - trust_xff = bool(api.get('api_trust_x_forwarded_for')) if api.get('api_trust_x_forwarded_for') is not None else bool(settings.get('trust_x_forwarded_for')) + trust_xff = ( + bool(api.get('api_trust_x_forwarded_for')) + if api.get('api_trust_x_forwarded_for') is not None + else bool(settings.get('trust_x_forwarded_for')) + ) client_ip = _get_client_ip(request, trust_xff) if not client_ip: return try: settings = get_cached_settings() env_flag = os.getenv('LOCAL_HOST_IP_BYPASS') - allow_local = (env_flag.lower() == 'true') if isinstance(env_flag, str) and env_flag.strip() != '' else bool(settings.get('allow_localhost_bypass')) + allow_local = ( + (env_flag.lower() == 'true') + if isinstance(env_flag, str) and env_flag.strip() != '' + else bool(settings.get('allow_localhost_bypass')) + ) direct_ip = getattr(getattr(request, 'client', None), 'host', None) - has_forward = any(request.headers.get(h) for h in ('x-forwarded-for','X-Forwarded-For','x-real-ip','X-Real-IP','cf-connecting-ip','CF-Connecting-IP','forwarded','Forwarded')) + has_forward = any( + request.headers.get(h) + for h in ( + 'x-forwarded-for', + 'X-Forwarded-For', + 'x-real-ip', + 'X-Real-IP', + 'cf-connecting-ip', + 'CF-Connecting-IP', + 'forwarded', + 'Forwarded', + ) + ) if allow_local and direct_ip and _is_loopback(direct_ip) and not has_forward: return except Exception: @@ -100,14 +133,28 @@ def enforce_api_ip_policy(request: Request, api: dict): bl = api.get('api_ip_blacklist') or [] if bl and _ip_in_list(client_ip, bl): try: - audit(request, actor=None, action='ip.api_deny', target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'), status='blocked', details={'reason': 'blacklisted', 'effective_ip': client_ip}) + audit( + request, + actor=None, + action='ip.api_deny', + target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'), + status='blocked', + details={'reason': 'blacklisted', 'effective_ip': client_ip}, + ) except Exception: pass raise HTTPException(status_code=403, detail='API011') if mode == 'whitelist': if not wl or not _ip_in_list(client_ip, wl): try: - audit(request, actor=None, action='ip.api_deny', target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'), status='blocked', details={'reason': 'not_in_whitelist', 'effective_ip': client_ip}) + audit( + request, + actor=None, + action='ip.api_deny', + target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'), + status='blocked', + details={'reason': 'not_in_whitelist', 'effective_ip': client_ip}, + ) except Exception: pass raise HTTPException(status_code=403, detail='API010') diff --git a/backend-services/utils/ip_rate_limiter.py b/backend-services/utils/ip_rate_limiter.py index 06ebe7c..c21afae 100644 --- a/backend-services/utils/ip_rate_limiter.py +++ b/backend-services/utils/ip_rate_limiter.py @@ -6,14 +6,14 @@ and IP reputation scoring. """ import logging -from typing import Optional, List, Set -from datetime import datetime, timedelta +import time from dataclasses import dataclass +from datetime import datetime + from fastapi import Request -import time -from utils.redis_client import RedisClient, get_redis_client from utils.rate_limiter import RateLimiter, RateLimitResult +from utils.redis_client import RedisClient, get_redis_client logger = logging.getLogger(__name__) @@ -21,14 +21,15 @@ logger = logging.getLogger(__name__) @dataclass class IPInfo: """Information about an IP address""" + ip: str is_whitelisted: bool = False is_blacklisted: bool = False reputation_score: int = 100 # 0-100, lower is worse request_count: int = 0 - last_seen: Optional[datetime] = None - countries: Set[str] = None - + last_seen: datetime | None = None + countries: set[str] = None + def __post_init__(self): if self.countries is None: self.countries = set() @@ -38,24 +39,24 @@ class IPRateLimiter: """ IP-based rate limiting with whitelist/blacklist and reputation scoring """ - - def __init__(self, redis_client: Optional[RedisClient] = None): + + def __init__(self, redis_client: RedisClient | None = None): """Initialize IP rate limiter""" self.redis = redis_client or get_redis_client() self.rate_limiter = RateLimiter(redis_client) - + # Default IP rate limits (can be overridden per IP) self.default_ip_limit_per_minute = 60 self.default_ip_limit_per_hour = 1000 - + # Reputation thresholds self.suspicious_threshold = 50 # Below this is suspicious self.ban_threshold = 20 # Below this gets banned - + def extract_client_ip(self, request: Request) -> str: """ Extract client IP from request, handling proxy headers - + Priority order: 1. X-Forwarded-For (first IP in chain) 2. X-Real-IP @@ -66,239 +67,229 @@ class IPRateLimiter: if forwarded_for: # Take the first IP in the chain (original client) return forwarded_for.split(',')[0].strip() - + # Check X-Real-IP real_ip = request.headers.get('X-Real-IP') if real_ip: return real_ip.strip() - + # Fallback to direct connection IP if request.client and request.client.host: return request.client.host - + return 'unknown' - + def is_whitelisted(self, ip: str) -> bool: """Check if IP is whitelisted""" try: return bool(self.redis.sismember('ip:whitelist', ip)) except Exception as e: - logger.error(f"Error checking whitelist: {e}") + logger.error(f'Error checking whitelist: {e}') return False - + def is_blacklisted(self, ip: str) -> bool: """Check if IP is blacklisted""" try: return bool(self.redis.sismember('ip:blacklist', ip)) except Exception as e: - logger.error(f"Error checking blacklist: {e}") + logger.error(f'Error checking blacklist: {e}') return False - + def add_to_whitelist(self, ip: str) -> bool: """Add IP to whitelist""" try: self.redis.sadd('ip:whitelist', ip) - logger.info(f"Added IP to whitelist: {ip}") + logger.info(f'Added IP to whitelist: {ip}') return True except Exception as e: - logger.error(f"Error adding to whitelist: {e}") + logger.error(f'Error adding to whitelist: {e}') return False - - def add_to_blacklist(self, ip: str, duration_seconds: Optional[int] = None) -> bool: + + def add_to_blacklist(self, ip: str, duration_seconds: int | None = None) -> bool: """ Add IP to blacklist - + Args: ip: IP address to blacklist duration_seconds: Optional duration for temporary ban """ try: self.redis.sadd('ip:blacklist', ip) - + if duration_seconds: # Set expiration for temporary ban - ban_key = f"ip:ban:{ip}" + ban_key = f'ip:ban:{ip}' self.redis.setex(ban_key, duration_seconds, '1') - logger.info(f"Temporarily banned IP for {duration_seconds}s: {ip}") + logger.info(f'Temporarily banned IP for {duration_seconds}s: {ip}') else: - logger.info(f"Permanently blacklisted IP: {ip}") - + logger.info(f'Permanently blacklisted IP: {ip}') + return True except Exception as e: - logger.error(f"Error adding to blacklist: {e}") + logger.error(f'Error adding to blacklist: {e}') return False - + def remove_from_whitelist(self, ip: str) -> bool: """Remove IP from whitelist""" try: self.redis.srem('ip:whitelist', ip) - logger.info(f"Removed IP from whitelist: {ip}") + logger.info(f'Removed IP from whitelist: {ip}') return True except Exception as e: - logger.error(f"Error removing from whitelist: {e}") + logger.error(f'Error removing from whitelist: {e}') return False - + def remove_from_blacklist(self, ip: str) -> bool: """Remove IP from blacklist""" try: self.redis.srem('ip:blacklist', ip) - ban_key = f"ip:ban:{ip}" + ban_key = f'ip:ban:{ip}' self.redis.delete(ban_key) - logger.info(f"Removed IP from blacklist: {ip}") + logger.info(f'Removed IP from blacklist: {ip}') return True except Exception as e: - logger.error(f"Error removing from blacklist: {e}") + logger.error(f'Error removing from blacklist: {e}') return False - + def get_reputation_score(self, ip: str) -> int: """ Get reputation score for IP (0-100) - + Lower scores indicate worse reputation. """ try: - score_key = f"ip:reputation:{ip}" + score_key = f'ip:reputation:{ip}' score = self.redis.get(score_key) return int(score) if score else 100 except Exception as e: - logger.error(f"Error getting reputation score: {e}") + logger.error(f'Error getting reputation score: {e}') return 100 - + def update_reputation_score(self, ip: str, delta: int) -> int: """ Update reputation score for IP - + Args: ip: IP address delta: Change in score (positive or negative) - + Returns: New reputation score """ try: - score_key = f"ip:reputation:{ip}" + score_key = f'ip:reputation:{ip}' current_score = self.get_reputation_score(ip) new_score = max(0, min(100, current_score + delta)) - + self.redis.setex(score_key, 86400 * 7, str(new_score)) # 7 day TTL - + # Auto-ban if score too low if new_score <= self.ban_threshold: self.add_to_blacklist(ip, duration_seconds=3600) # 1 hour ban - logger.warning(f"Auto-banned IP due to low reputation: {ip} (score: {new_score})") - + logger.warning(f'Auto-banned IP due to low reputation: {ip} (score: {new_score})') + return new_score except Exception as e: - logger.error(f"Error updating reputation score: {e}") + logger.error(f'Error updating reputation score: {e}') return 100 - + def track_request(self, ip: str) -> None: """Track request from IP for analytics""" try: # Increment request counter - counter_key = f"ip:requests:{ip}" + counter_key = f'ip:requests:{ip}' self.redis.incr(counter_key) self.redis.expire(counter_key, 86400) # 24 hour window - + # Update last seen - last_seen_key = f"ip:last_seen:{ip}" + last_seen_key = f'ip:last_seen:{ip}' self.redis.set(last_seen_key, datetime.now().isoformat()) self.redis.expire(last_seen_key, 86400 * 7) # 7 days - + except Exception as e: - logger.error(f"Error tracking request: {e}") - + logger.error(f'Error tracking request: {e}') + def get_ip_info(self, ip: str) -> IPInfo: """Get comprehensive information about an IP""" try: - request_count_key = f"ip:requests:{ip}" + request_count_key = f'ip:requests:{ip}' request_count = int(self.redis.get(request_count_key) or 0) - - last_seen_key = f"ip:last_seen:{ip}" + + last_seen_key = f'ip:last_seen:{ip}' last_seen_str = self.redis.get(last_seen_key) last_seen = datetime.fromisoformat(last_seen_str) if last_seen_str else None - + return IPInfo( ip=ip, is_whitelisted=self.is_whitelisted(ip), is_blacklisted=self.is_blacklisted(ip), reputation_score=self.get_reputation_score(ip), request_count=request_count, - last_seen=last_seen + last_seen=last_seen, ) except Exception as e: - logger.error(f"Error getting IP info: {e}") + logger.error(f'Error getting IP info: {e}') return IPInfo(ip=ip) - + def check_ip_rate_limit( - self, - ip: str, - limit_per_minute: Optional[int] = None, - limit_per_hour: Optional[int] = None + self, ip: str, limit_per_minute: int | None = None, limit_per_hour: int | None = None ) -> RateLimitResult: """ Check rate limit for specific IP - + Args: ip: IP address limit_per_minute: Override default per-minute limit limit_per_hour: Override default per-hour limit - + Returns: RateLimitResult indicating if request is allowed """ # Check whitelist (always allow) if self.is_whitelisted(ip): - return RateLimitResult(allowed=True, limit=999999, remaining=999999, reset_at=int(time.time()) + 60) - + return RateLimitResult( + allowed=True, limit=999999, remaining=999999, reset_at=int(time.time()) + 60 + ) + # Check blacklist (always deny) if self.is_blacklisted(ip): return RateLimitResult( - allowed=False, - limit=0, - remaining=0, - reset_at=int(time.time()) + 60 + allowed=False, limit=0, remaining=0, reset_at=int(time.time()) + 60 ) - + # Check reputation score reputation = self.get_reputation_score(ip) if reputation < self.suspicious_threshold: # Reduce limits for suspicious IPs limit_per_minute = int((limit_per_minute or self.default_ip_limit_per_minute) * 0.5) limit_per_hour = int((limit_per_hour or self.default_ip_limit_per_hour) * 0.5) - logger.warning(f"Reduced limits for suspicious IP: {ip} (reputation: {reputation})") - + logger.warning(f'Reduced limits for suspicious IP: {ip} (reputation: {reputation})') + # Use defaults if not specified limit_per_minute = limit_per_minute or self.default_ip_limit_per_minute limit_per_hour = limit_per_hour or self.default_ip_limit_per_hour - + # Check per-minute limit - minute_key = f"ip:limit:minute:{ip}" + minute_key = f'ip:limit:minute:{ip}' minute_count = int(self.redis.get(minute_key) or 0) - + if minute_count >= limit_per_minute: # Decrease reputation for rate limit violations self.update_reputation_score(ip, -5) return RateLimitResult( - allowed=False, - limit=limit_per_minute, - remaining=0, - reset_at=int(time.time()) + 60 + allowed=False, limit=limit_per_minute, remaining=0, reset_at=int(time.time()) + 60 ) - + # Check per-hour limit - hour_key = f"ip:limit:hour:{ip}" + hour_key = f'ip:limit:hour:{ip}' hour_count = int(self.redis.get(hour_key) or 0) - + if hour_count >= limit_per_hour: self.update_reputation_score(ip, -5) return RateLimitResult( - allowed=False, - limit=limit_per_hour, - remaining=0, - reset_at=int(time.time()) + 3600 + allowed=False, limit=limit_per_hour, remaining=0, reset_at=int(time.time()) + 3600 ) - + # Increment counters pipe = self.redis.pipeline() pipe.incr(minute_key) @@ -306,83 +297,83 @@ class IPRateLimiter: pipe.incr(hour_key) pipe.expire(hour_key, 3600) pipe.execute() - + # Track request self.track_request(ip) - + return RateLimitResult( allowed=True, limit=limit_per_minute, remaining=limit_per_minute - minute_count - 1, - reset_at=int(time.time()) + 60 + reset_at=int(time.time()) + 60, ) - - def get_top_ips(self, limit: int = 10) -> List[tuple]: + + def get_top_ips(self, limit: int = 10) -> list[tuple]: """ Get top IPs by request volume - + Returns: List of (ip, request_count) tuples """ try: # Scan for all IP request counters - pattern = "ip:requests:*" + pattern = 'ip:requests:*' keys = [] - + cursor = 0 while True: cursor, batch = self.redis.scan(cursor, match=pattern, count=100) keys.extend(batch) if cursor == 0: break - + # Get counts for each IP ip_counts = [] for key in keys[:1000]: # Limit to prevent overwhelming ip = key.replace('ip:requests:', '') count = int(self.redis.get(key) or 0) ip_counts.append((ip, count)) - + # Sort by count and return top N ip_counts.sort(key=lambda x: x[1], reverse=True) return ip_counts[:limit] - + except Exception as e: - logger.error(f"Error getting top IPs: {e}") + logger.error(f'Error getting top IPs: {e}') return [] - + def detect_suspicious_activity(self, ip: str) -> bool: """ Detect suspicious activity patterns - + Returns: True if suspicious activity detected """ try: # Check request rate - minute_key = f"ip:limit:minute:{ip}" + minute_key = f'ip:limit:minute:{ip}' minute_count = int(self.redis.get(minute_key) or 0) - + # Suspicious if > 80% of limit if minute_count > self.default_ip_limit_per_minute * 0.8: - logger.warning(f"Suspicious activity detected for IP: {ip} (high request rate)") + logger.warning(f'Suspicious activity detected for IP: {ip} (high request rate)') self.update_reputation_score(ip, -10) return True - + # Check reputation reputation = self.get_reputation_score(ip) if reputation < self.suspicious_threshold: return True - + return False - + except Exception as e: - logger.error(f"Error detecting suspicious activity: {e}") + logger.error(f'Error detecting suspicious activity: {e}') return False # Global instance -_ip_rate_limiter: Optional[IPRateLimiter] = None +_ip_rate_limiter: IPRateLimiter | None = None def get_ip_rate_limiter() -> IPRateLimiter: diff --git a/backend-services/utils/limit_throttle_util.py b/backend-services/utils/limit_throttle_util.py index 6a77a87..b0202fa 100644 --- a/backend-services/utils/limit_throttle_util.py +++ b/backend-services/utils/limit_throttle_util.py @@ -1,18 +1,19 @@ -from fastapi import Request, HTTPException import asyncio -import time import logging import os +import time +from fastapi import HTTPException, Request + +from utils.async_db import db_find_one from utils.auth_util import auth_required from utils.database_async import user_collection -from utils.async_db import db_find_one -import asyncio from utils.doorman_cache_util import doorman_cache from utils.ip_policy_util import _get_client_ip logger = logging.getLogger('doorman.gateway') + class InMemoryWindowCounter: """Simple in-memory counter with TTL semantics to mimic required Redis ops. @@ -36,6 +37,7 @@ class InMemoryWindowCounter: See: doorman.py app_lifespan() for multi-worker validation """ + def __init__(self): self._store = {} @@ -45,7 +47,6 @@ class InMemoryWindowCounter: if entry and entry['expires_at'] > now: entry['count'] += 1 else: - entry = {'count': 1, 'expires_at': now + 1} self._store[key] = entry return entry['count'] @@ -57,8 +58,10 @@ class InMemoryWindowCounter: entry['expires_at'] = now + int(ttl_seconds) self._store[key] = entry + _fallback_counter = InMemoryWindowCounter() + def duration_to_seconds(duration: str) -> int: mapping = { 'second': 1, @@ -67,7 +70,7 @@ def duration_to_seconds(duration: str) -> int: 'day': 86400, 'week': 604800, 'month': 2592000, - 'year': 31536000 + 'year': 31536000, } if not duration: return 60 @@ -75,13 +78,14 @@ def duration_to_seconds(duration: str) -> int: duration = duration[:-1] return mapping.get(duration.lower(), 60) + async def limit_and_throttle(request: Request): """Enforce user-level rate limiting and throttling. - + **Rate Limiting Hierarchy:** 1. Tier-based limits (checked by TierRateLimitMiddleware first) 2. User-specific overrides (checked here) - + This function provides user-specific rate/throttle settings that override or supplement tier-based limits. The TierRateLimitMiddleware runs first and enforces tier limits, then this function applies user-specific rules. @@ -163,7 +167,9 @@ async def limit_and_throttle(request: Request): await _fallback_counter.expire(throttle_key, throttle_window) try: if os.getenv('DOORMAN_TEST_MODE', 'false').lower() == 'true': - logger.info(f'[throttle] key={throttle_key} count={throttle_count} qlimit={int(user.get("throttle_queue_limit") or 10)} window={throttle_window}s') + logger.info( + f'[throttle] key={throttle_key} count={throttle_count} qlimit={int(user.get("throttle_queue_limit") or 10)} window={throttle_window}s' + ) except Exception: pass throttle_queue_limit = int(user.get('throttle_queue_limit') or 10) @@ -181,13 +187,18 @@ async def limit_and_throttle(request: Request): throttle_wait *= duration_to_seconds(throttle_wait_duration) dynamic_wait = throttle_wait * (throttle_count - throttle_limit) try: - import sys as _sys, os as _os - if _os.getenv('DOORMAN_TEST_MODE', 'false').lower() == 'true' and _sys.version_info >= (3, 13): + import os as _os + import sys as _sys + + if _os.getenv( + 'DOORMAN_TEST_MODE', 'false' + ).lower() == 'true' and _sys.version_info >= (3, 13): dynamic_wait = max(dynamic_wait, 0.2) except Exception: pass await asyncio.sleep(dynamic_wait) + def reset_counters(): """Reset in-memory rate/throttle counters (used by tests and cache clears). Has no effect when a real Redis client is configured. @@ -197,6 +208,7 @@ def reset_counters(): except Exception: pass + async def limit_by_ip(request: Request, limit: int = 10, window: int = 60): """IP-based rate limiting for endpoints that don't require authentication. @@ -232,12 +244,7 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60): try: if os.getenv('LOGIN_IP_RATE_DISABLED', 'false').lower() == 'true': now = int(time.time()) - return { - 'limit': limit, - 'remaining': limit, - 'reset': now + window, - 'window': window - } + return {'limit': limit, 'remaining': limit, 'reset': now + window, 'window': window} client_ip = _get_client_ip(request, trust_xff=True) if not client_ip: logger.warning('Unable to determine client IP for rate limiting, allowing request') @@ -245,7 +252,7 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60): 'limit': limit, 'remaining': limit, 'reset': int(time.time()) + window, - 'window': window + 'window': window, } now = int(time.time()) @@ -273,7 +280,7 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60): 'limit': limit, 'remaining': remaining, 'reset': reset_time, - 'window': window + 'window': window, } if count > limit: @@ -283,14 +290,14 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60): detail={ 'error_code': 'IP_RATE_LIMIT', 'message': f'Too many requests from your IP address. Please wait {retry_after} seconds before trying again. Limit: {limit} requests per {window} seconds.', - 'retry_after': retry_after + 'retry_after': retry_after, }, headers={ 'Retry-After': str(retry_after), 'X-RateLimit-Limit': str(limit), 'X-RateLimit-Remaining': '0', - 'X-RateLimit-Reset': str(reset_time) - } + 'X-RateLimit-Reset': str(reset_time), + }, ) if count > (limit * 0.8): @@ -306,5 +313,5 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60): 'limit': limit, 'remaining': limit, 'reset': int(time.time()) + window, - 'window': window + 'window': window, } diff --git a/backend-services/utils/memory_dump_util.py b/backend-services/utils/memory_dump_util.py index c870285..0f33731 100644 --- a/backend-services/utils/memory_dump_util.py +++ b/backend-services/utils/memory_dump_util.py @@ -2,30 +2,28 @@ Utilities to dump and restore in-memory database state with encryption. """ -import os -from pathlib import Path -import json import base64 -from typing import Optional, Any -from datetime import datetime, timezone -from cryptography.hazmat.primitives.kdf.hkdf import HKDF +import json +import os +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF from .database import database _PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_DUMP_PATH = os.getenv('MEM_DUMP_PATH', str(_PROJECT_ROOT / 'generated' / 'memory_dump.bin')) + def _derive_key(key_material: str, salt: bytes) -> bytes: - hkdf = HKDF( - algorithm=hashes.SHA256(), - length=32, - salt=salt, - info=b'doorman-mem-dump-v1', - ) + hkdf = HKDF(algorithm=hashes.SHA256(), length=32, salt=salt, info=b'doorman-mem-dump-v1') return hkdf.derive(key_material.encode('utf-8')) + def _encrypt_blob(plaintext: bytes, key_str: str) -> bytes: if not key_str or len(key_str) < 8: raise ValueError('MEM_ENCRYPTION_KEY must be set and at least 8 characters') @@ -37,6 +35,7 @@ def _encrypt_blob(plaintext: bytes, key_str: str) -> bytes: return b'DMP1' + salt + nonce + ct + def _decrypt_blob(blob: bytes, key_str: str) -> bytes: if not key_str or len(key_str) < 8: raise ValueError('MEM_ENCRYPTION_KEY must be set and at least 8 characters') @@ -51,7 +50,8 @@ def _decrypt_blob(blob: bytes, key_str: str) -> bytes: aesgcm = AESGCM(key) return aesgcm.decrypt(nonce, ct, None) -def _split_dir_and_stem(path_hint: Optional[str]) -> tuple[str, str]: + +def _split_dir_and_stem(path_hint: str | None) -> tuple[str, str]: """Return (directory, stem) for naming timestamped dump files. - If hint is a directory (or endswith '/'), use it and default stem 'memory_dump'. @@ -73,8 +73,10 @@ def _split_dir_and_stem(path_hint: Optional[str]) -> tuple[str, str]: stem = stem or 'memory_dump' return dump_dir, stem + BYTES_KEY_PREFIX = '__byteskey__:' + def _to_jsonable(obj: Any) -> Any: """Recursively convert arbitrary objects to JSON-serializable structures. @@ -90,7 +92,6 @@ def _to_jsonable(obj: Any) -> Any: if isinstance(obj, dict): out = {} for k, v in obj.items(): - if isinstance(k, bytes): sk = BYTES_KEY_PREFIX + base64.b64encode(k).decode('ascii') elif isinstance(k, (str, int, float, bool)) or k is None: @@ -112,15 +113,16 @@ def _to_jsonable(obj: Any) -> Any: except Exception: return None + def _json_default(o: Any) -> Any: if isinstance(o, bytes): return {'__type__': 'bytes', 'data': base64.b64encode(o).decode('ascii')} try: - return str(o) except Exception: return None + def _from_jsonable(obj: Any) -> Any: """Inverse of _to_jsonable for the specific encodings we apply.""" if isinstance(obj, dict): @@ -133,7 +135,7 @@ def _from_jsonable(obj: Any) -> Any: for k, v in obj.items(): rk: Any = k if isinstance(k, str) and k.startswith(BYTES_KEY_PREFIX): - b64 = k[len(BYTES_KEY_PREFIX):] + b64 = k[len(BYTES_KEY_PREFIX) :] try: rk = base64.b64decode(b64) except Exception: @@ -144,23 +146,41 @@ def _from_jsonable(obj: Any) -> Any: return [_from_jsonable(v) for v in obj] return obj + def _sanitize_for_dump(data: Any) -> Any: """ Remove sensitive data before dumping to prevent secret exposure. """ SENSITIVE_KEYS = { - 'password', 'secret', 'token', 'key', 'api_key', - 'access_token', 'refresh_token', 'jwt', 'jwt_secret', - 'csrf_token', 'session', 'cookie', - 'credential', 'auth', 'authorization', - 'ssn', 'credit_card', 'cvv', 'private_key', - 'encryption_key', 'signing_key' + 'password', + 'secret', + 'token', + 'key', + 'api_key', + 'access_token', + 'refresh_token', + 'jwt', + 'jwt_secret', + 'csrf_token', + 'session', + 'cookie', + 'credential', + 'auth', + 'authorization', + 'ssn', + 'credit_card', + 'cvv', + 'private_key', + 'encryption_key', + 'signing_key', } + def should_redact(key: str) -> bool: if not isinstance(key, str): return False key_lower = key.lower() return any(s in key_lower for s in SENSITIVE_KEYS) + def redact_value(obj: Any) -> Any: if isinstance(obj, dict): return { @@ -175,20 +195,22 @@ def _sanitize_for_dump(data: Any) -> Any: if cleaned.isalnum(): return '[REDACTED-TOKEN]' return obj + return redact_value(data) -def dump_memory_to_file(path: Optional[str] = None) -> str: + +def dump_memory_to_file(path: str | None = None) -> str: if not database.memory_only: raise RuntimeError('Memory dump is only available in memory-only mode') dump_dir, stem = _split_dir_and_stem(path) os.makedirs(dump_dir, exist_ok=True) - ts = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ') + ts = datetime.now(UTC).strftime('%Y%m%dT%H%M%SZ') dump_path = os.path.join(dump_dir, f'{stem}-{ts}.bin') raw_data = database.db.dump_data() sanitized_data = _sanitize_for_dump(_to_jsonable(raw_data)) payload = { 'version': 1, - 'created_at': datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'), + 'created_at': datetime.now(UTC).isoformat().replace('+00:00', 'Z'), 'sanitized': True, 'note': 'Sensitive fields (passwords, tokens, secrets) have been redacted', 'data': sanitized_data, @@ -200,7 +222,8 @@ def dump_memory_to_file(path: Optional[str] = None) -> str: f.write(blob) return dump_path -def restore_memory_from_file(path: Optional[str] = None) -> dict: + +def restore_memory_from_file(path: str | None = None) -> dict: if not database.memory_only: raise RuntimeError('Memory restore is only available in memory-only mode') dump_path = path or DEFAULT_DUMP_PATH @@ -215,34 +238,41 @@ def restore_memory_from_file(path: Optional[str] = None) -> dict: data = _from_jsonable(payload.get('data', {})) database.db.load_data(data) try: - from utils.database import user_collection - from utils import password_util as _pw import os as _os + + from utils import password_util as _pw + from utils.database import user_collection + admin = user_collection.find_one({'username': 'admin'}) if admin is not None and not isinstance(admin.get('password'), (bytes, bytearray)): pwd = _os.getenv('DOORMAN_ADMIN_PASSWORD') if not pwd: raise RuntimeError('DOORMAN_ADMIN_PASSWORD must be set in environment') - user_collection.update_one({'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}}) + user_collection.update_one( + {'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}} + ) except Exception: pass return {'version': payload.get('version', 1), 'created_at': payload.get('created_at')} -def find_latest_dump_path(path_hint: Optional[str] = None) -> Optional[str]: + +def find_latest_dump_path(path_hint: str | None = None) -> str | None: """Return the most recent dump file path based on a hint. - If `path_hint` is a file and exists, return it. - If `path_hint` is a directory, search for .bin files and pick the newest. - If no hint or not found, try DEFAULT_DUMP_PATH, or its directory for .bin files. """ - def newest_bin_in_dir(d: str, stem: Optional[str] = None) -> Optional[str]: + + def newest_bin_in_dir(d: str, stem: str | None = None) -> str | None: try: if not os.path.isdir(d): return None candidates = [ os.path.join(d, f) for f in os.listdir(d) - if f.lower().endswith('.bin') and os.path.isfile(os.path.join(d, f)) + if f.lower().endswith('.bin') + and os.path.isfile(os.path.join(d, f)) and (stem is None or f.startswith(stem + '-')) ] if not candidates: diff --git a/backend-services/utils/metrics_util.py b/backend-services/utils/metrics_util.py index ad109bb..5af0b96 100644 --- a/backend-services/utils/metrics_util.py +++ b/backend-services/utils/metrics_util.py @@ -4,12 +4,13 @@ Records count, status code distribution, and response time stats, with per-minut """ from __future__ import annotations + +import json +import os import time from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Deque, Dict, List, Optional -import json -import os + @dataclass class MinuteBucket: @@ -22,13 +23,21 @@ class MinuteBucket: upstream_timeouts: int = 0 retries: int = 0 - status_counts: Dict[int, int] = field(default_factory=dict) - api_counts: Dict[str, int] = field(default_factory=dict) - api_error_counts: Dict[str, int] = field(default_factory=dict) - user_counts: Dict[str, int] = field(default_factory=dict) - latencies: Deque[float] = field(default_factory=deque) + status_counts: dict[int, int] = field(default_factory=dict) + api_counts: dict[str, int] = field(default_factory=dict) + api_error_counts: dict[str, int] = field(default_factory=dict) + user_counts: dict[str, int] = field(default_factory=dict) + latencies: deque[float] = field(default_factory=deque) - def add(self, ms: float, status: int, username: Optional[str], api_key: Optional[str], bytes_in: int = 0, bytes_out: int = 0) -> None: + def add( + self, + ms: float, + status: int, + username: str | None, + api_key: str | None, + bytes_in: int = 0, + bytes_out: int = 0, + ) -> None: self.count += 1 if status >= 400: self.error_count += 1 @@ -68,7 +77,7 @@ class MinuteBucket: except Exception: pass - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return { 'start_ts': self.start_ts, 'count': self.count, @@ -85,7 +94,7 @@ class MinuteBucket: } @staticmethod - def from_dict(d: Dict) -> 'MinuteBucket': + def from_dict(d: dict) -> MinuteBucket: mb = MinuteBucket( start_ts=int(d.get('start_ts', 0)), count=int(d.get('count', 0)), @@ -105,6 +114,7 @@ class MinuteBucket: pass return mb + class MetricsStore: def __init__(self, max_minutes: int = 60 * 24 * 30): self.total_requests: int = 0 @@ -113,10 +123,10 @@ class MetricsStore: self.total_bytes_out: int = 0 self.total_upstream_timeouts: int = 0 self.total_retries: int = 0 - self.status_counts: Dict[int, int] = defaultdict(int) - self.username_counts: Dict[str, int] = defaultdict(int) - self.api_counts: Dict[str, int] = defaultdict(int) - self._buckets: Deque[MinuteBucket] = deque() + self.status_counts: dict[int, int] = defaultdict(int) + self.username_counts: dict[str, int] = defaultdict(int) + self.api_counts: dict[str, int] = defaultdict(int) + self._buckets: deque[MinuteBucket] = deque() self._max_minutes = max_minutes @staticmethod @@ -134,7 +144,15 @@ class MetricsStore: self._buckets.popleft() return mb - def record(self, status: int, duration_ms: float, username: Optional[str] = None, api_key: Optional[str] = None, bytes_in: int = 0, bytes_out: int = 0) -> None: + def record( + self, + status: int, + duration_ms: float, + username: str | None = None, + api_key: str | None = None, + bytes_in: int = 0, + bytes_out: int = 0, + ) -> None: now = time.time() minute_start = self._minute_floor(now) bucket = self._ensure_bucket(minute_start) @@ -152,7 +170,7 @@ class MetricsStore: if api_key: self.api_counts[api_key] += 1 - def record_retry(self, api_key: Optional[str] = None) -> None: + def record_retry(self, api_key: str | None = None) -> None: now = time.time() minute_start = self._minute_floor(now) bucket = self._ensure_bucket(minute_start) @@ -162,7 +180,7 @@ class MetricsStore: except Exception: pass - def record_upstream_timeout(self, api_key: Optional[str] = None) -> None: + def record_upstream_timeout(self, api_key: str | None = None) -> None: now = time.time() minute_start = self._minute_floor(now) bucket = self._ensure_bucket(minute_start) @@ -172,27 +190,24 @@ class MetricsStore: except Exception: pass - def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> Dict: - - range_to_minutes = { - '1h': 60, - '24h': 60 * 24, - '7d': 60 * 24 * 7, - '30d': 60 * 24 * 30, - } + def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> dict: + range_to_minutes = {'1h': 60, '24h': 60 * 24, '7d': 60 * 24 * 7, '30d': 60 * 24 * 30} minutes = range_to_minutes.get(range_key, 60 * 24) - buckets: List[MinuteBucket] = list(self._buckets)[-minutes:] + buckets: list[MinuteBucket] = list(self._buckets)[-minutes:] series = [] if group == 'day': from collections import defaultdict - day_map: Dict[int, Dict[str, float]] = defaultdict(lambda: { - 'count': 0, - 'error_count': 0, - 'total_ms': 0.0, - 'bytes_in': 0, - 'bytes_out': 0, - }) + + day_map: dict[int, dict[str, float]] = defaultdict( + lambda: { + 'count': 0, + 'error_count': 0, + 'total_ms': 0.0, + 'bytes_in': 0, + 'bytes_out': 0, + } + ) for b in buckets: day_ts = int((b.start_ts // 86400) * 86400) d = day_map[day_ts] @@ -203,15 +218,19 @@ class MetricsStore: d['bytes_out'] += b.bytes_out for day_ts, d in day_map.items(): avg_ms = (d['total_ms'] / d['count']) if d['count'] else 0.0 - series.append({ - 'timestamp': day_ts, - 'count': int(d['count']), - 'error_count': int(d['error_count']), - 'avg_ms': avg_ms, - 'bytes_in': int(d['bytes_in']), - 'bytes_out': int(d['bytes_out']), - 'error_rate': (int(d['error_count']) / int(d['count'])) if d['count'] else 0.0, - }) + series.append( + { + 'timestamp': day_ts, + 'count': int(d['count']), + 'error_count': int(d['error_count']), + 'avg_ms': avg_ms, + 'bytes_in': int(d['bytes_in']), + 'bytes_out': int(d['bytes_out']), + 'error_rate': (int(d['error_count']) / int(d['count'])) + if d['count'] + else 0.0, + } + ) else: for b in buckets: avg_ms = (b.total_ms / b.count) if b.count else 0.0 @@ -224,20 +243,22 @@ class MetricsStore: p95 = float(arr[k]) except Exception: p95 = 0.0 - series.append({ - 'timestamp': b.start_ts, - 'count': b.count, - 'error_count': b.error_count, - 'avg_ms': avg_ms, - 'p95_ms': p95, - 'bytes_in': b.bytes_in, - 'bytes_out': b.bytes_out, - 'error_rate': (b.error_count / b.count) if b.count else 0.0, - 'upstream_timeouts': b.upstream_timeouts, - 'retries': b.retries, - }) + series.append( + { + 'timestamp': b.start_ts, + 'count': b.count, + 'error_count': b.error_count, + 'avg_ms': avg_ms, + 'p95_ms': p95, + 'bytes_in': b.bytes_in, + 'bytes_out': b.bytes_out, + 'error_rate': (b.error_count / b.count) if b.count else 0.0, + 'upstream_timeouts': b.upstream_timeouts, + 'retries': b.retries, + } + ) - reverse = (str(sort).lower() == 'desc') + reverse = str(sort).lower() == 'desc' try: series.sort(key=lambda x: x.get('timestamp', 0), reverse=reverse) except Exception: @@ -255,11 +276,13 @@ class MetricsStore: 'total_retries': self.total_retries, 'status_counts': status, 'series': series, - 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[:10], + 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[ + :10 + ], 'top_apis': sorted(self.api_counts.items(), key=lambda kv: kv[1], reverse=True)[:10], } - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return { 'total_requests': int(self.total_requests), 'total_ms': float(self.total_ms), @@ -271,7 +294,7 @@ class MetricsStore: 'buckets': [b.to_dict() for b in list(self._buckets)], } - def load_dict(self, data: Dict) -> None: + def load_dict(self, data: dict) -> None: try: self.total_requests = int(data.get('total_requests', 0)) self.total_ms = float(data.get('total_ms', 0.0)) @@ -308,11 +331,12 @@ class MetricsStore: try: if not os.path.exists(path): return - with open(path, 'r', encoding='utf-8') as f: + with open(path, encoding='utf-8') as f: data = json.load(f) if isinstance(data, dict): self.load_dict(data) except Exception: pass + metrics_store = MetricsStore() diff --git a/backend-services/utils/paging_util.py b/backend-services/utils/paging_util.py index 7d9ac3f..7f9dd6d 100644 --- a/backend-services/utils/paging_util.py +++ b/backend-services/utils/paging_util.py @@ -2,6 +2,7 @@ import os from utils.constants import Defaults + def max_page_size() -> int: try: env = os.getenv(Defaults.MAX_PAGE_SIZE_ENV) @@ -11,6 +12,7 @@ def max_page_size() -> int: except Exception: return Defaults.MAX_PAGE_SIZE_DEFAULT + def validate_page_params(page: int, page_size: int) -> tuple[int, int]: p = int(page) ps = int(page_size) @@ -22,4 +24,3 @@ def validate_page_params(page: int, page_size: int) -> tuple[int, int]: if ps > m: raise ValueError(f'page_size must be <= {m}') return p, ps - diff --git a/backend-services/utils/password_util.py b/backend-services/utils/password_util.py index d51f021..fdc4680 100644 --- a/backend-services/utils/password_util.py +++ b/backend-services/utils/password_util.py @@ -6,14 +6,17 @@ See https://github.com/pypeople-dev/doorman for more information import bcrypt + def hash_password(password: str): hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()) return hashed_password + def verify_password(password: str, hashed_password: str): password = password.encode('utf-8') return bcrypt.checkpw(password, hashed_password) + def is_secure_password(password: str): if len(password) < 16: return False @@ -25,4 +28,4 @@ def is_secure_password(password: str): return False if not any(c in '!@#$%^&*()-_=+[]{};:,.<>?/' for c in password): return False - return True \ No newline at end of file + return True diff --git a/backend-services/utils/quota_tracker.py b/backend-services/utils/quota_tracker.py index ccc108f..a30697f 100644 --- a/backend-services/utils/quota_tracker.py +++ b/backend-services/utils/quota_tracker.py @@ -5,13 +5,12 @@ Tracks usage quotas (monthly, daily) for users and APIs. Supports quota limits, rollover, and exhaustion detection. """ -import time import logging -from datetime import datetime, timedelta -from typing import Optional, Dict, List from dataclasses import dataclass -from models.rate_limit_models import QuotaUsage, QuotaType, generate_quota_key -from utils.redis_client import get_redis_client, RedisClient +from datetime import datetime, timedelta + +from models.rate_limit_models import QuotaType, QuotaUsage, generate_quota_key +from utils.redis_client import RedisClient, get_redis_client logger = logging.getLogger(__name__) @@ -19,6 +18,7 @@ logger = logging.getLogger(__name__) @dataclass class QuotaCheckResult: """Result of quota check""" + allowed: bool current_usage: int limit: int @@ -32,7 +32,7 @@ class QuotaCheckResult: class QuotaTracker: """ Quota tracker for managing usage quotas - + Features: - Monthly and daily quota tracking - Automatic reset at period boundaries @@ -40,49 +40,45 @@ class QuotaTracker: - Quota exhaustion detection - Historical usage tracking """ - - def __init__(self, redis_client: Optional[RedisClient] = None): + + def __init__(self, redis_client: RedisClient | None = None): """ Initialize quota tracker - + Args: redis_client: Redis client instance """ self.redis = redis_client or get_redis_client() - + def check_quota( - self, - user_id: str, - quota_type: QuotaType, - limit: int, - period: str = 'month' + self, user_id: str, quota_type: QuotaType, limit: int, period: str = 'month' ) -> QuotaCheckResult: """ Check if user has quota available - + Args: user_id: User identifier quota_type: Type of quota (requests, bandwidth, etc.) limit: Quota limit period: 'month' or 'day' - + Returns: QuotaCheckResult with current status """ # Get current period key period_key = self._get_period_key(period) quota_key = generate_quota_key(user_id, quota_type, period_key) - + try: # Read usage and reset keys stored separately - usage = self.redis.get(f"{quota_key}:usage") - reset_at_raw = self.redis.get(f"{quota_key}:reset_at") - + usage = self.redis.get(f'{quota_key}:usage') + reset_at_raw = self.redis.get(f'{quota_key}:reset_at') + if usage is None: current_usage = 0 else: current_usage = int(usage) - + if reset_at_raw: try: reset_at = datetime.fromisoformat(reset_at_raw) @@ -92,26 +88,26 @@ class QuotaTracker: reset_at = self._get_next_reset(period) # Initialize reset_at for future checks try: - self.redis.set(f"{quota_key}:reset_at", reset_at.isoformat()) + self.redis.set(f'{quota_key}:reset_at', reset_at.isoformat()) except Exception: pass - + # Reset if period elapsed if datetime.now() >= reset_at: current_usage = 0 reset_at = self._get_next_reset(period) try: - self.redis.set(f"{quota_key}:usage", 0) - self.redis.set(f"{quota_key}:reset_at", reset_at.isoformat()) + self.redis.set(f'{quota_key}:usage', 0) + self.redis.set(f'{quota_key}:reset_at', reset_at.isoformat()) except Exception: pass - + remaining = max(0, limit - current_usage) percentage_used = (current_usage / limit * 100) if limit > 0 else 0 is_warning = percentage_used >= 80 is_critical = percentage_used >= 95 allowed = current_usage < limit - + # Attach derived exhaustion flag for test expectations result = QuotaCheckResult( allowed=allowed, @@ -121,17 +117,17 @@ class QuotaTracker: reset_at=reset_at, percentage_used=percentage_used, is_warning=is_warning, - is_critical=is_critical + is_critical=is_critical, ) # Inject attribute expected by tests try: - setattr(result, 'is_exhausted', not allowed) + result.is_exhausted = not allowed except Exception: pass return result - + except Exception as e: - logger.error(f"Quota check error for {user_id}: {e}") + logger.error(f'Quota check error for {user_id}: {e}') # Graceful degradation: allow on error res = QuotaCheckResult( allowed=True, @@ -139,190 +135,173 @@ class QuotaTracker: limit=limit, remaining=limit, reset_at=self._get_next_reset(period), - percentage_used=0.0 + percentage_used=0.0, ) try: - setattr(res, 'is_exhausted', False) + res.is_exhausted = False except Exception: pass return res - + def increment_quota( - self, - user_id: str, - quota_type: QuotaType, - amount: int = 1, - period: str = 'month' + self, user_id: str, quota_type: QuotaType, amount: int = 1, period: str = 'month' ) -> int: """ Increment quota usage - + Args: user_id: User identifier quota_type: Type of quota amount: Amount to increment period: 'month' or 'day' - + Returns: New usage value """ period_key = self._get_period_key(period) quota_key = generate_quota_key(user_id, quota_type, period_key) - + try: # Increment usage (string amount tolerated by Redis mock in tests) - new_usage = self.redis.incr(f"{quota_key}:usage", amount) - + new_usage = self.redis.incr(f'{quota_key}:usage', amount) + # Ensure reset_at is set - if not self.redis.exists(f"{quota_key}:reset_at"): + if not self.redis.exists(f'{quota_key}:reset_at'): reset_at = self._get_next_reset(period) try: - self.redis.set(f"{quota_key}:reset_at", reset_at.isoformat()) + self.redis.set(f'{quota_key}:reset_at', reset_at.isoformat()) except Exception: pass - + return new_usage - + except Exception as e: - logger.error(f"Error incrementing quota for {user_id}: {e}") + logger.error(f'Error incrementing quota for {user_id}: {e}') return 0 - + def get_quota_usage( - self, - user_id: str, - quota_type: QuotaType, - limit: int, - period: str = 'month' + self, user_id: str, quota_type: QuotaType, limit: int, period: str = 'month' ) -> QuotaUsage: """ Get current quota usage - + Args: user_id: User identifier quota_type: Type of quota limit: Quota limit period: 'month' or 'day' - + Returns: QuotaUsage object """ period_key = self._get_period_key(period) quota_key = generate_quota_key(user_id, quota_type, period_key) - + try: usage_data = self.redis.hmget(quota_key, ['usage', 'reset_at']) - + if usage_data[0] is None: current_usage = 0 reset_at = self._get_next_reset(period) else: current_usage = int(usage_data[0]) reset_at = datetime.fromisoformat(usage_data[1]) - + # Check if needs reset if datetime.now() >= reset_at: current_usage = 0 reset_at = self._get_next_reset(period) - + return QuotaUsage( key=quota_key, quota_type=quota_type, current_usage=current_usage, limit=limit, - reset_at=reset_at + reset_at=reset_at, ) - + except Exception as e: - logger.error(f"Error getting quota usage for {user_id}: {e}") + logger.error(f'Error getting quota usage for {user_id}: {e}') return QuotaUsage( key=quota_key, quota_type=quota_type, current_usage=0, limit=limit, - reset_at=self._get_next_reset(period) + reset_at=self._get_next_reset(period), ) - - def reset_quota( - self, - user_id: str, - quota_type: QuotaType, - period: str = 'month' - ) -> bool: + + def reset_quota(self, user_id: str, quota_type: QuotaType, period: str = 'month') -> bool: """ Reset quota for user (admin function) - + Args: user_id: User identifier quota_type: Type of quota period: 'month' or 'day' - + Returns: True if successful """ try: period_key = self._get_period_key(period) quota_key = generate_quota_key(user_id, quota_type, period_key) - + # Delete quota key self.redis.delete(quota_key) - logger.info(f"Reset quota for user {user_id}") + logger.info(f'Reset quota for user {user_id}') return True - + except Exception as e: - logger.error(f"Error resetting quota: {e}") + logger.error(f'Error resetting quota: {e}') return False - - def get_all_quotas( - self, - user_id: str, - limits: Dict[QuotaType, int] - ) -> List[QuotaUsage]: + + def get_all_quotas(self, user_id: str, limits: dict[QuotaType, int]) -> list[QuotaUsage]: """ Get all quota usages for a user - + Args: user_id: User identifier limits: Dictionary of quota type to limit - + Returns: List of QuotaUsage objects """ quotas = [] - + for quota_type, limit in limits.items(): usage = self.get_quota_usage(user_id, quota_type, limit) quotas.append(usage) - + return quotas - + def check_and_increment( self, user_id: str, quota_type: QuotaType, limit: int, amount: int = 1, - period: str = 'month' + period: str = 'month', ) -> QuotaCheckResult: """ Check quota and increment if allowed (atomic operation) - + Args: user_id: User identifier quota_type: Type of quota limit: Quota limit amount: Amount to increment period: 'month' or 'day' - + Returns: QuotaCheckResult """ # First check if quota is available check_result = self.check_quota(user_id, quota_type, limit, period) - + if check_result.allowed: # Increment quota new_usage = self.increment_quota(user_id, quota_type, amount, period) - + # Update result with new values check_result.current_usage = new_usage check_result.remaining = max(0, limit - new_usage) @@ -330,40 +309,40 @@ class QuotaTracker: check_result.is_warning = check_result.percentage_used >= 80 check_result.is_critical = check_result.percentage_used >= 95 check_result.allowed = new_usage <= limit - + return check_result - + def _get_period_key(self, period: str) -> str: """ Get period key for current time - + Args: period: 'month' or 'day' - + Returns: Period key (e.g., '2025-12' for month, '2025-12-02' for day) """ now = datetime.now() - + if period == 'month': return now.strftime('%Y-%m') elif period == 'day': return now.strftime('%Y-%m-%d') else: - raise ValueError(f"Invalid period: {period}") - + raise ValueError(f'Invalid period: {period}') + def _get_next_reset(self, period: str) -> datetime: """ Get next reset time for period - + Args: period: 'month' or 'day' - + Returns: Next reset datetime """ now = datetime.now() - + if period == 'month': # Next month, first day, midnight if now.month == 12: @@ -375,61 +354,53 @@ class QuotaTracker: tomorrow = now + timedelta(days=1) return datetime(tomorrow.year, tomorrow.month, tomorrow.day, 0, 0, 0) else: - raise ValueError(f"Invalid period: {period}") - + raise ValueError(f'Invalid period: {period}') + def _initialize_quota(self, quota_key: str, reset_at: datetime): """ Initialize quota in Redis - + Args: quota_key: Redis key for quota reset_at: Reset datetime """ try: - self.redis.hmset(quota_key, { - 'usage': 0, - 'reset_at': reset_at.isoformat() - }) - + self.redis.hmset(quota_key, {'usage': 0, 'reset_at': reset_at.isoformat()}) + # Set TTL to slightly after reset time ttl_seconds = int((reset_at - datetime.now()).total_seconds()) + 3600 self.redis.expire(quota_key, ttl_seconds) - + except Exception as e: - logger.error(f"Error initializing quota: {e}") - - def get_quota_history( - self, - user_id: str, - quota_type: QuotaType, - months: int = 6 - ) -> List[Dict]: + logger.error(f'Error initializing quota: {e}') + + def get_quota_history(self, user_id: str, quota_type: QuotaType, months: int = 6) -> list[dict]: """ Get historical quota usage (placeholder for future implementation) - + Args: user_id: User identifier quota_type: Type of quota months: Number of months to retrieve - + Returns: List of historical usage data """ # TODO: Implement historical tracking # This would query MongoDB or time-series database - logger.warning("Quota history not yet implemented") + logger.warning('Quota history not yet implemented') return [] # Global quota tracker instance -_quota_tracker: Optional[QuotaTracker] = None +_quota_tracker: QuotaTracker | None = None def get_quota_tracker() -> QuotaTracker: """Get or create global quota tracker instance""" global _quota_tracker - + if _quota_tracker is None: _quota_tracker = QuotaTracker() - + return _quota_tracker diff --git a/backend-services/utils/rate_limit_simulator.py b/backend-services/utils/rate_limit_simulator.py index f86f32c..b85ddd0 100644 --- a/backend-services/utils/rate_limit_simulator.py +++ b/backend-services/utils/rate_limit_simulator.py @@ -6,13 +6,12 @@ and preview the impact of rule changes. """ import logging -from typing import List, Dict, Optional +import random from dataclasses import dataclass from datetime import datetime, timedelta -import random from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow -from utils.rate_limiter import RateLimiter, RateLimitResult +from utils.rate_limiter import RateLimiter logger = logging.getLogger(__name__) @@ -20,6 +19,7 @@ logger = logging.getLogger(__name__) @dataclass class SimulationRequest: """Simulated request""" + timestamp: datetime user_id: str endpoint: str @@ -29,6 +29,7 @@ class SimulationRequest: @dataclass class SimulationResult: """Result of simulation""" + total_requests: int allowed_requests: int blocked_requests: int @@ -36,118 +37,123 @@ class SimulationResult: success_rate: float average_remaining: float peak_usage: int - requests_by_second: Dict[int, int] + requests_by_second: dict[int, int] class RateLimitSimulator: """ Simulate rate limiting scenarios without real traffic """ - + def __init__(self): """Initialize simulator""" self.rate_limiter = RateLimiter() - + def generate_requests( - self, - num_requests: int, - duration_seconds: int, - pattern: str = "uniform" - ) -> List[SimulationRequest]: + self, num_requests: int, duration_seconds: int, pattern: str = 'uniform' + ) -> list[SimulationRequest]: """ Generate simulated requests - + Args: num_requests: Number of requests to generate duration_seconds: Duration over which to spread requests pattern: Distribution pattern (uniform, burst, spike, gradual) - + Returns: List of simulated requests """ requests = [] start_time = datetime.now() - - if pattern == "uniform": + + if pattern == 'uniform': # Evenly distributed interval = duration_seconds / num_requests for i in range(num_requests): timestamp = start_time + timedelta(seconds=i * interval) - requests.append(SimulationRequest( - timestamp=timestamp, - user_id="sim_user", - endpoint="/api/test", - ip="192.168.1.1" - )) - - elif pattern == "burst": + requests.append( + SimulationRequest( + timestamp=timestamp, + user_id='sim_user', + endpoint='/api/test', + ip='192.168.1.1', + ) + ) + + elif pattern == 'burst': # All requests in first 10% of duration burst_duration = duration_seconds * 0.1 interval = burst_duration / num_requests for i in range(num_requests): timestamp = start_time + timedelta(seconds=i * interval) - requests.append(SimulationRequest( - timestamp=timestamp, - user_id="sim_user", - endpoint="/api/test", - ip="192.168.1.1" - )) - - elif pattern == "spike": + requests.append( + SimulationRequest( + timestamp=timestamp, + user_id='sim_user', + endpoint='/api/test', + ip='192.168.1.1', + ) + ) + + elif pattern == 'spike': # Spike in the middle spike_start = duration_seconds * 0.4 spike_duration = duration_seconds * 0.2 interval = spike_duration / num_requests for i in range(num_requests): timestamp = start_time + timedelta(seconds=spike_start + i * interval) - requests.append(SimulationRequest( - timestamp=timestamp, - user_id="sim_user", - endpoint="/api/test", - ip="192.168.1.1" - )) - - elif pattern == "gradual": + requests.append( + SimulationRequest( + timestamp=timestamp, + user_id='sim_user', + endpoint='/api/test', + ip='192.168.1.1', + ) + ) + + elif pattern == 'gradual': # Gradually increasing rate for i in range(num_requests): # Quadratic distribution (more requests toward end) progress = (i / num_requests) ** 2 timestamp = start_time + timedelta(seconds=progress * duration_seconds) - requests.append(SimulationRequest( - timestamp=timestamp, - user_id="sim_user", - endpoint="/api/test", - ip="192.168.1.1" - )) - - elif pattern == "random": + requests.append( + SimulationRequest( + timestamp=timestamp, + user_id='sim_user', + endpoint='/api/test', + ip='192.168.1.1', + ) + ) + + elif pattern == 'random': # Random distribution for i in range(num_requests): random_offset = random.uniform(0, duration_seconds) timestamp = start_time + timedelta(seconds=random_offset) - requests.append(SimulationRequest( - timestamp=timestamp, - user_id="sim_user", - endpoint="/api/test", - ip="192.168.1.1" - )) + requests.append( + SimulationRequest( + timestamp=timestamp, + user_id='sim_user', + endpoint='/api/test', + ip='192.168.1.1', + ) + ) # Sort by timestamp requests.sort(key=lambda r: r.timestamp) - + return requests - + def simulate_rule( - self, - rule: RateLimitRule, - requests: List[SimulationRequest] + self, rule: RateLimitRule, requests: list[SimulationRequest] ) -> SimulationResult: """ Simulate rate limit rule against requests - + Args: rule: Rate limit rule to test requests: List of simulated requests - + Returns: Simulation result with statistics """ @@ -156,25 +162,26 @@ class RateLimitSimulator: burst_used = 0 remaining_values = [] requests_by_second = {} - + # Track usage by second for request in requests: second = int(request.timestamp.timestamp()) requests_by_second[second] = requests_by_second.get(second, 0) + 1 - + # Simulate each request for request in requests: # In real scenario, would check actual Redis counters # For simulation, we'll use simplified logic - + # Calculate current window usage - window_start = request.timestamp - timedelta(seconds=self._get_window_seconds(rule.time_window)) + window_start = request.timestamp - timedelta( + seconds=self._get_window_seconds(rule.time_window) + ) window_requests = [ - r for r in requests - if window_start <= r.timestamp <= request.timestamp + r for r in requests if window_start <= r.timestamp <= request.timestamp ] current_usage = len(window_requests) - + # Check if within limit if current_usage <= rule.limit: allowed += 1 @@ -189,13 +196,13 @@ class RateLimitSimulator: else: blocked += 1 remaining_values.append(0) - + # Calculate statistics total = len(requests) success_rate = (allowed / total * 100) if total > 0 else 0 avg_remaining = sum(remaining_values) / len(remaining_values) if remaining_values else 0 peak_usage = max(requests_by_second.values()) if requests_by_second else 0 - + return SimulationResult( total_requests=total, allowed_requests=allowed, @@ -204,9 +211,9 @@ class RateLimitSimulator: success_rate=success_rate, average_remaining=avg_remaining, peak_usage=peak_usage, - requests_by_second=requests_by_second + requests_by_second=requests_by_second, ) - + def _get_window_seconds(self, window: TimeWindow) -> int: """Get window duration in seconds""" window_map = { @@ -214,120 +221,108 @@ class RateLimitSimulator: TimeWindow.MINUTE: 60, TimeWindow.HOUR: 3600, TimeWindow.DAY: 86400, - TimeWindow.MONTH: 2592000 # 30 days + TimeWindow.MONTH: 2592000, # 30 days } return window_map.get(window, 60) - + def compare_rules( - self, - rules: List[RateLimitRule], - requests: List[SimulationRequest] - ) -> Dict[str, SimulationResult]: + self, rules: list[RateLimitRule], requests: list[SimulationRequest] + ) -> dict[str, SimulationResult]: """ Compare multiple rules against same request pattern - + Args: rules: List of rules to compare requests: Simulated requests - + Returns: Dictionary mapping rule_id to simulation result """ results = {} - + for rule in rules: result = self.simulate_rule(rule, requests) results[rule.rule_id] = result - + return results - + def preview_rule_change( self, current_rule: RateLimitRule, new_rule: RateLimitRule, - historical_pattern: str = "uniform", - duration_minutes: int = 60 - ) -> Dict[str, SimulationResult]: + historical_pattern: str = 'uniform', + duration_minutes: int = 60, + ) -> dict[str, SimulationResult]: """ Preview impact of changing a rule - + Args: current_rule: Current rule configuration new_rule: Proposed new rule configuration historical_pattern: Traffic pattern to simulate duration_minutes: Duration to simulate - + Returns: Comparison of current vs new rule performance """ # Estimate request volume based on current limit estimated_requests = int(current_rule.limit * 1.5) # 150% of limit - + # Generate requests requests = self.generate_requests( num_requests=estimated_requests, duration_seconds=duration_minutes * 60, - pattern=historical_pattern + pattern=historical_pattern, ) - + # Compare rules return self.compare_rules([current_rule, new_rule], requests) - + def test_burst_effectiveness( - self, - base_limit: int, - burst_allowances: List[int], - spike_intensity: float = 2.0 - ) -> Dict[int, SimulationResult]: + self, base_limit: int, burst_allowances: list[int], spike_intensity: float = 2.0 + ) -> dict[int, SimulationResult]: """ Test effectiveness of different burst allowances - + Args: base_limit: Base rate limit burst_allowances: List of burst allowances to test spike_intensity: Multiplier for spike (2.0 = 2x normal rate) - + Returns: Results for each burst allowance """ # Generate spike pattern - normal_requests = base_limit spike_requests = int(base_limit * spike_intensity) - + requests = self.generate_requests( - num_requests=spike_requests, - duration_seconds=60, - pattern="spike" + num_requests=spike_requests, duration_seconds=60, pattern='spike' ) - + results = {} - + for burst in burst_allowances: rule = RateLimitRule( - rule_id=f"burst_{burst}", + rule_id=f'burst_{burst}', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, limit=base_limit, - burst_allowance=burst + burst_allowance=burst, ) - + result = self.simulate_rule(rule, requests) results[burst] = result - + return results - - def generate_report( - self, - rule: RateLimitRule, - result: SimulationResult - ) -> str: + + def generate_report(self, rule: RateLimitRule, result: SimulationResult) -> str: """ Generate human-readable simulation report - + Args: rule: Rule that was simulated result: Simulation result - + Returns: Formatted report string """ @@ -347,7 +342,7 @@ Simulation Results: Allowed: {result.allowed_requests} ({result.success_rate:.1f}%) Blocked: {result.blocked_requests} Burst Used: {result.burst_used_count} - + Performance Metrics: Success Rate: {result.success_rate:.1f}% Average Remaining: {result.average_remaining:.1f} @@ -355,58 +350,56 @@ Performance Metrics: Recommendation: """ - + # Add recommendations if result.success_rate < 90: - report += " ⚠️ Consider increasing limit or burst allowance\n" + report += ' ⚠️ Consider increasing limit or burst allowance\n' elif result.success_rate > 99 and result.average_remaining > rule.limit * 0.5: - report += " ℹ️ Limit may be too high, consider reducing\n" + report += ' ℹ️ Limit may be too high, consider reducing\n' else: - report += " ✅ Rule configuration appears appropriate\n" - + report += ' ✅ Rule configuration appears appropriate\n' + if result.burst_used_count > 0: - burst_percentage = (result.burst_used_count / result.allowed_requests * 100) - report += f" ℹ️ {burst_percentage:.1f}% of requests used burst tokens\n" - + burst_percentage = result.burst_used_count / result.allowed_requests * 100 + report += f' ℹ️ {burst_percentage:.1f}% of requests used burst tokens\n' + return report - + def run_scenario( self, scenario_name: str, rule: RateLimitRule, - pattern: str = "uniform", - duration_minutes: int = 5 - ) -> Dict: + pattern: str = 'uniform', + duration_minutes: int = 5, + ) -> dict: """ Run a named simulation scenario - + Args: scenario_name: Name of the scenario rule: Rule to test pattern: Traffic pattern duration_minutes: Duration to simulate - + Returns: Scenario results with report """ # Generate requests based on rule limit num_requests = rule.limit * 2 # 2x the limit - + requests = self.generate_requests( - num_requests=num_requests, - duration_seconds=duration_minutes * 60, - pattern=pattern + num_requests=num_requests, duration_seconds=duration_minutes * 60, pattern=pattern ) - + result = self.simulate_rule(rule, requests) report = self.generate_report(rule, result) - + return { 'scenario_name': scenario_name, 'rule': rule, 'pattern': pattern, 'result': result, - 'report': report + 'report': report, } @@ -414,40 +407,36 @@ Recommendation: # HELPER FUNCTIONS # ============================================================================ + def quick_simulate( - limit: int, - time_window: str = "minute", - burst: int = 0, - pattern: str = "uniform" + limit: int, time_window: str = 'minute', burst: int = 0, pattern: str = 'uniform' ) -> str: """ Quick simulation helper - + Args: limit: Rate limit time_window: Time window (second, minute, hour, day) burst: Burst allowance pattern: Traffic pattern - + Returns: Simulation report """ simulator = RateLimitSimulator() - + rule = RateLimitRule( - rule_id="quick_sim", + rule_id='quick_sim', rule_type=RuleType.PER_USER, time_window=TimeWindow(time_window), limit=limit, - burst_allowance=burst + burst_allowance=burst, ) - + requests = simulator.generate_requests( - num_requests=limit * 2, - duration_seconds=60, - pattern=pattern + num_requests=limit * 2, duration_seconds=60, pattern=pattern ) - + result = simulator.simulate_rule(rule, requests) return simulator.generate_report(rule, result) @@ -456,45 +445,43 @@ def quick_simulate( # EXAMPLE USAGE # ============================================================================ -if __name__ == "__main__": +if __name__ == '__main__': # Example: Test different burst allowances simulator = RateLimitSimulator() - - print("Testing burst effectiveness...") + + print('Testing burst effectiveness...') results = simulator.test_burst_effectiveness( - base_limit=100, - burst_allowances=[0, 20, 50, 100], - spike_intensity=2.0 + base_limit=100, burst_allowances=[0, 20, 50, 100], spike_intensity=2.0 ) - + for burst, result in results.items(): - print(f"\nBurst Allowance: {burst}") - print(f" Success Rate: {result.success_rate:.1f}%") - print(f" Burst Used: {result.burst_used_count}") - + print(f'\nBurst Allowance: {burst}') + print(f' Success Rate: {result.success_rate:.1f}%') + print(f' Burst Used: {result.burst_used_count}') + # Example: Preview rule change - print("\n" + "="*50) - print("Previewing rule change...") - + print('\n' + '=' * 50) + print('Previewing rule change...') + current = RateLimitRule( - rule_id="current", + rule_id='current', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, limit=100, - burst_allowance=20 + burst_allowance=20, ) - + proposed = RateLimitRule( - rule_id="proposed", + rule_id='proposed', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, limit=150, - burst_allowance=30 + burst_allowance=30, ) - + comparison = simulator.preview_rule_change(current, proposed) - + for rule_id, result in comparison.items(): - print(f"\n{rule_id.upper()}:") - print(f" Success Rate: {result.success_rate:.1f}%") - print(f" Blocked: {result.blocked_requests}") + print(f'\n{rule_id.upper()}:') + print(f' Success Rate: {result.success_rate:.1f}%') + print(f' Blocked: {result.blocked_requests}') diff --git a/backend-services/utils/rate_limiter.py b/backend-services/utils/rate_limiter.py index 287ba8b..cdb4892 100644 --- a/backend-services/utils/rate_limiter.py +++ b/backend-services/utils/rate_limiter.py @@ -5,19 +5,18 @@ Implements token bucket and sliding window algorithms for rate limiting. Supports distributed rate limiting across multiple server instances using Redis. """ -import time import logging -from typing import Optional, Tuple +import time from dataclasses import dataclass + from models.rate_limit_models import ( - RateLimitRule, RateLimitCounter, RateLimitInfo, - TimeWindow, + RateLimitRule, + generate_redis_key, get_time_window_seconds, - generate_redis_key ) -from utils.redis_client import get_redis_client, RedisClient +from utils.redis_client import RedisClient, get_redis_client logger = logging.getLogger(__name__) @@ -25,13 +24,14 @@ logger = logging.getLogger(__name__) @dataclass class RateLimitResult: """Result of rate limit check""" + allowed: bool limit: int remaining: int reset_at: int - retry_after: Optional[int] = None + retry_after: int | None = None burst_remaining: int = 0 - + def to_info(self) -> RateLimitInfo: """Convert to RateLimitInfo""" return RateLimitInfo( @@ -39,43 +39,39 @@ class RateLimitResult: remaining=self.remaining, reset_at=self.reset_at, retry_after=self.retry_after, - burst_remaining=self.burst_remaining + burst_remaining=self.burst_remaining, ) class RateLimiter: """ Rate limiter with token bucket and sliding window algorithms - + Features: - Token bucket for burst handling - Sliding window for accurate rate limiting - Distributed locking for multi-instance support - Graceful degradation if Redis is unavailable """ - - def __init__(self, redis_client: Optional[RedisClient] = None): + + def __init__(self, redis_client: RedisClient | None = None): """ Initialize rate limiter - + Args: redis_client: Redis client instance (creates default if None) """ self.redis = redis_client or get_redis_client() self._fallback_mode = False - - def check_rate_limit( - self, - rule: RateLimitRule, - identifier: str - ) -> RateLimitResult: + + def check_rate_limit(self, rule: RateLimitRule, identifier: str) -> RateLimitResult: """ Check if request is allowed under rate limit rule - + Args: rule: Rate limit rule to apply identifier: Unique identifier (user ID, API name, IP, etc.) - + Returns: RateLimitResult with allow/deny decision """ @@ -85,46 +81,44 @@ class RateLimiter: allowed=True, limit=rule.limit, remaining=rule.limit, - reset_at=int(time.time()) + get_time_window_seconds(rule.time_window) + reset_at=int(time.time()) + get_time_window_seconds(rule.time_window), ) - + # Use sliding window algorithm return self._check_sliding_window(rule, identifier) - - def _check_sliding_window( - self, - rule: RateLimitRule, - identifier: str - ) -> RateLimitResult: + + def _check_sliding_window(self, rule: RateLimitRule, identifier: str) -> RateLimitResult: """ Check rate limit using sliding window counter algorithm - + This is more accurate than fixed window and prevents boundary issues. - + Algorithm: 1. Get current and previous window counts 2. Calculate weighted count based on time elapsed in current window 3. Check if weighted count exceeds limit 4. If allowed, increment current window counter - + Args: rule: Rate limit rule identifier: Unique identifier - + Returns: RateLimitResult """ now = time.time() window_size = get_time_window_seconds(rule.time_window) - + # Current window timestamp and key current_window = int(now / window_size) * window_size - current_key = generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window) - + current_key = generate_redis_key( + rule.rule_type, identifier, rule.time_window, current_window + ) + try: # Use only current window counter for deterministic behavior in unit tests current_count = int(self.redis.get(current_key) or 0) - + # Limit exceeded? if current_count >= rule.limit: reset_at = current_window + window_size @@ -134,21 +128,21 @@ class RateLimiter: limit=rule.limit, remaining=0, reset_at=int(reset_at), - retry_after=retry_after + retry_after=retry_after, ) - + # Burst allowance tracking (does not affect allow when under limit) burst_remaining = rule.burst_allowance if rule.burst_allowance > 0: - burst_key = f"{current_key}:burst" + burst_key = f'{current_key}:burst' burst_count = int(self.redis.get(burst_key) or 0) burst_remaining = max(0, rule.burst_allowance - burst_count) - + # Increment counter and set TTL, but report remaining based on pre-increment value new_count = self.redis.incr(current_key) if new_count == 1: self.redis.expire(current_key, window_size * 2) - + remaining = max(0, rule.limit - current_count) reset_at = current_window + window_size return RateLimitResult( @@ -156,53 +150,49 @@ class RateLimiter: limit=rule.limit, remaining=remaining, reset_at=int(reset_at), - burst_remaining=burst_remaining + burst_remaining=burst_remaining, ) - + except Exception as e: - logger.error(f"Rate limit check error: {e}") + logger.error(f'Rate limit check error: {e}') # Graceful degradation: allow request on error return RateLimitResult( allowed=True, limit=rule.limit, remaining=rule.limit, - reset_at=int(now) + window_size + reset_at=int(now) + window_size, ) - - def check_token_bucket( - self, - rule: RateLimitRule, - identifier: str - ) -> RateLimitResult: + + def check_token_bucket(self, rule: RateLimitRule, identifier: str) -> RateLimitResult: """ Check rate limit using token bucket algorithm - + Token bucket allows bursts while maintaining average rate. - + Algorithm: 1. Calculate tokens to add based on time elapsed 2. Add tokens to bucket (up to limit) 3. Check if enough tokens available 4. If yes, consume token and allow request - + Args: rule: Rate limit rule identifier: Unique identifier - + Returns: RateLimitResult """ now = time.time() window_size = get_time_window_seconds(rule.time_window) refill_rate = rule.limit / window_size # Tokens per second - + # Generate Redis key for bucket - bucket_key = f"bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}" - + bucket_key = f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}' + try: # Get current bucket state bucket_data = self.redis.hmget(bucket_key, ['tokens', 'last_refill']) - + if bucket_data[0] is None: # Initialize bucket tokens = float(rule.limit) @@ -210,119 +200,106 @@ class RateLimiter: else: tokens = float(bucket_data[0]) last_refill = float(bucket_data[1]) - + # Calculate tokens to add elapsed = now - last_refill tokens_to_add = elapsed * refill_rate tokens = min(rule.limit, tokens + tokens_to_add) - + # Check if request is allowed if tokens >= 1.0: # Consume token tokens -= 1.0 - + # Update bucket state - self.redis.hmset(bucket_key, { - 'tokens': tokens, - 'last_refill': now - }) + self.redis.hmset(bucket_key, {'tokens': tokens, 'last_refill': now}) self.redis.expire(bucket_key, window_size * 2) - + # Calculate reset time (when bucket will be full) time_to_full = (rule.limit - tokens) / refill_rate reset_at = int(now + time_to_full) - + return RateLimitResult( - allowed=True, - limit=rule.limit, - remaining=int(tokens), - reset_at=reset_at + allowed=True, limit=rule.limit, remaining=int(tokens), reset_at=reset_at ) else: # Not enough tokens time_to_token = (1.0 - tokens) / refill_rate retry_after = int(time_to_token) + 1 reset_at = int(now + time_to_token) - + return RateLimitResult( allowed=False, limit=rule.limit, remaining=0, reset_at=reset_at, - retry_after=retry_after + retry_after=retry_after, ) - + except Exception as e: - logger.error(f"Token bucket check error: {e}") + logger.error(f'Token bucket check error: {e}') # Graceful degradation return RateLimitResult( allowed=True, limit=rule.limit, remaining=rule.limit, - reset_at=int(now) + window_size + reset_at=int(now) + window_size, ) - - def check_hybrid( - self, - rule: RateLimitRule, - identifier: str - ) -> RateLimitResult: + + def check_hybrid(self, rule: RateLimitRule, identifier: str) -> RateLimitResult: """ Check rate limit using hybrid approach (sliding window + token bucket) - + This combines accuracy of sliding window with burst handling of token bucket. - + Algorithm: 1. Check sliding window (accurate rate limit) 2. If allowed, check token bucket (burst handling) 3. Both must pass for request to be allowed - + Args: rule: Rate limit rule identifier: Unique identifier - + Returns: RateLimitResult """ # First check sliding window sliding_result = self._check_sliding_window(rule, identifier) - + if not sliding_result.allowed: return sliding_result - + # If sliding window allows, check token bucket for burst if rule.burst_allowance > 0: bucket_result = self.check_token_bucket(rule, identifier) - + if not bucket_result.allowed: # Use burst tokens if available return self._use_burst_tokens(rule, identifier, sliding_result) - + return sliding_result - + def _use_burst_tokens( - self, - rule: RateLimitRule, - identifier: str, - sliding_result: RateLimitResult + self, rule: RateLimitRule, identifier: str, sliding_result: RateLimitResult ) -> RateLimitResult: """ Try to use burst tokens when normal tokens are exhausted - + Args: rule: Rate limit rule identifier: Unique identifier sliding_result: Result from sliding window check - + Returns: RateLimitResult """ now = time.time() window_size = get_time_window_seconds(rule.time_window) current_window = int(now / window_size) * window_size - - burst_key = f"burst:{rule.rule_type.value}:{identifier}:{current_window}" - + + burst_key = f'burst:{rule.rule_type.value}:{identifier}:{current_window}' + try: # Get current burst usage (tolerate mocks that provide multiple side-effect values) burst_count = int(self.redis.get(burst_key) or 0) @@ -332,22 +309,22 @@ class RateLimiter: burst_count = int(second) except Exception: pass - + if burst_count < rule.burst_allowance: # Burst tokens available new_burst_count = self.redis.incr(burst_key) - + if new_burst_count == 1: self.redis.expire(burst_key, window_size * 2) - + burst_remaining = rule.burst_allowance - new_burst_count - + return RateLimitResult( allowed=True, limit=rule.limit, remaining=sliding_result.remaining, reset_at=sliding_result.reset_at, - burst_remaining=burst_remaining + burst_remaining=burst_remaining, ) else: # No burst tokens available @@ -357,22 +334,22 @@ class RateLimiter: remaining=0, reset_at=sliding_result.reset_at, retry_after=sliding_result.retry_after, - burst_remaining=0 + burst_remaining=0, ) - + except Exception as e: - logger.error(f"Burst token check error: {e}") + logger.error(f'Burst token check error: {e}') # On error, allow with sliding window result return sliding_result - + def reset_limit(self, rule: RateLimitRule, identifier: str) -> bool: """ Reset rate limit for identifier (admin function) - + Args: rule: Rate limit rule identifier: Unique identifier - + Returns: True if successful """ @@ -380,54 +357,47 @@ class RateLimiter: now = time.time() window_size = get_time_window_seconds(rule.time_window) current_window = int(now / window_size) * window_size - + # Delete all related keys keys_to_delete = [ generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window), - generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window - window_size), - f"bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}", - f"burst:{rule.rule_type.value}:{identifier}:{current_window}" + generate_redis_key( + rule.rule_type, identifier, rule.time_window, current_window - window_size + ), + f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}', + f'burst:{rule.rule_type.value}:{identifier}:{current_window}', ] - + self.redis.delete(*keys_to_delete) - logger.info(f"Reset rate limit for {identifier}") + logger.info(f'Reset rate limit for {identifier}') return True - + except Exception as e: - logger.error(f"Error resetting rate limit: {e}") + logger.error(f'Error resetting rate limit: {e}') return False - - def get_current_usage( - self, - rule: RateLimitRule, - identifier: str - ) -> RateLimitCounter: + + def get_current_usage(self, rule: RateLimitRule, identifier: str) -> RateLimitCounter: """ Get current usage for identifier - + Args: rule: Rate limit rule identifier: Unique identifier - + Returns: RateLimitCounter with current state """ now = time.time() window_size = get_time_window_seconds(rule.time_window) current_window = int(now / window_size) * window_size - - key = generate_redis_key( - rule.rule_type, - identifier, - rule.time_window, - current_window - ) - + + key = generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window) + try: count = int(self.redis.get(key) or 0) - burst_key = f"{key}:burst" + burst_key = f'{key}:burst' burst_count = int(self.redis.get(burst_key) or 0) - + return RateLimitCounter( key=key, window_start=current_window, @@ -435,29 +405,29 @@ class RateLimiter: count=count, limit=rule.limit, burst_count=burst_count, - burst_limit=rule.burst_allowance + burst_limit=rule.burst_allowance, ) - + except Exception as e: - logger.error(f"Error getting current usage: {e}") + logger.error(f'Error getting current usage: {e}') return RateLimitCounter( key=key, window_start=current_window, window_size=window_size, count=0, - limit=rule.limit + limit=rule.limit, ) # Global rate limiter instance -_rate_limiter: Optional[RateLimiter] = None +_rate_limiter: RateLimiter | None = None def get_rate_limiter() -> RateLimiter: """Get or create global rate limiter instance""" global _rate_limiter - + if _rate_limiter is None: _rate_limiter = RateLimiter() - + return _rate_limiter diff --git a/backend-services/utils/redis_client.py b/backend-services/utils/redis_client.py index bb605d3..01d8fbd 100644 --- a/backend-services/utils/redis_client.py +++ b/backend-services/utils/redis_client.py @@ -5,11 +5,11 @@ Provides a robust Redis client with connection pooling, error handling, and graceful degradation for rate limiting operations. """ -import redis import logging -from typing import Optional, Any, Dict from contextlib import contextmanager -import time +from typing import Any + +import redis logger = logging.getLogger(__name__) @@ -17,28 +17,28 @@ logger = logging.getLogger(__name__) class RedisClient: """ Redis client wrapper with connection pooling and error handling - + Features: - Connection pooling for performance - Automatic reconnection on failure - Graceful degradation (returns None on errors) - Pipeline support for batch operations """ - + def __init__( self, host: str = 'localhost', port: int = 6379, - password: Optional[str] = None, + password: str | None = None, db: int = 0, max_connections: int = 50, socket_timeout: int = 5, socket_connect_timeout: int = 5, - decode_responses: bool = True + decode_responses: bool = True, ): """ Initialize Redis client - + Args: host: Redis server host port: Redis server port @@ -53,7 +53,7 @@ class RedisClient: self.port = port self.password = password self.db = db - + # Create connection pool self.pool = redis.ConnectionPool( host=host, @@ -63,56 +63,56 @@ class RedisClient: max_connections=max_connections, socket_timeout=socket_timeout, socket_connect_timeout=socket_connect_timeout, - decode_responses=decode_responses + decode_responses=decode_responses, ) - + # Create Redis client self.client = redis.Redis(connection_pool=self.pool) - + # Test connection self._test_connection() - + def _test_connection(self) -> bool: """Test Redis connection""" try: self.client.ping() - logger.info(f"Redis connection successful: {self.host}:{self.port}") + logger.info(f'Redis connection successful: {self.host}:{self.port}') return True except redis.ConnectionError as e: - logger.error(f"Redis connection failed: {e}") + logger.error(f'Redis connection failed: {e}') return False except Exception as e: - logger.error(f"Unexpected Redis error: {e}") + logger.error(f'Unexpected Redis error: {e}') return False - - def get(self, key: str) -> Optional[str]: + + def get(self, key: str) -> str | None: """ Get value by key - + Args: key: Redis key - + Returns: Value or None if not found or error """ try: return self.client.get(key) except Exception as e: - logger.error(f"Redis GET error for key {key}: {e}") + logger.error(f'Redis GET error for key {key}: {e}') return None - + def set( self, key: str, value: Any, - ex: Optional[int] = None, - px: Optional[int] = None, + ex: int | None = None, + px: int | None = None, nx: bool = False, - xx: bool = False + xx: bool = False, ) -> bool: """ Set key-value pair - + Args: key: Redis key value: Value to store @@ -120,7 +120,7 @@ class RedisClient: px: Expiration time in milliseconds nx: Only set if key doesn't exist xx: Only set if key exists - + Returns: True if successful, False otherwise """ @@ -128,134 +128,134 @@ class RedisClient: result = self.client.set(key, value, ex=ex, px=px, nx=nx, xx=xx) return bool(result) except Exception as e: - logger.error(f"Redis SET error for key {key}: {e}") + logger.error(f'Redis SET error for key {key}: {e}') return False - - def incr(self, key: str, amount: int = 1) -> Optional[int]: + + def incr(self, key: str, amount: int = 1) -> int | None: """ Increment counter atomically - + Args: key: Redis key amount: Amount to increment - + Returns: New value or None on error """ try: return self.client.incr(key, amount) except Exception as e: - logger.error(f"Redis INCR error for key {key}: {e}") + logger.error(f'Redis INCR error for key {key}: {e}') return None - - def decr(self, key: str, amount: int = 1) -> Optional[int]: + + def decr(self, key: str, amount: int = 1) -> int | None: """ Decrement counter atomically - + Args: key: Redis key amount: Amount to decrement - + Returns: New value or None on error """ try: return self.client.decr(key, amount) except Exception as e: - logger.error(f"Redis DECR error for key {key}: {e}") + logger.error(f'Redis DECR error for key {key}: {e}') return None - + def expire(self, key: str, seconds: int) -> bool: """ Set expiration time for key - + Args: key: Redis key seconds: Expiration time in seconds - + Returns: True if successful, False otherwise """ try: return bool(self.client.expire(key, seconds)) except Exception as e: - logger.error(f"Redis EXPIRE error for key {key}: {e}") + logger.error(f'Redis EXPIRE error for key {key}: {e}') return False - - def ttl(self, key: str) -> Optional[int]: + + def ttl(self, key: str) -> int | None: """ Get time to live for key - + Args: key: Redis key - + Returns: TTL in seconds, -1 if no expiry, -2 if key doesn't exist, None on error """ try: return self.client.ttl(key) except Exception as e: - logger.error(f"Redis TTL error for key {key}: {e}") + logger.error(f'Redis TTL error for key {key}: {e}') return None - + def delete(self, *keys: str) -> int: """ Delete one or more keys - + Args: keys: Keys to delete - + Returns: Number of keys deleted """ try: return self.client.delete(*keys) except Exception as e: - logger.error(f"Redis DELETE error: {e}") + logger.error(f'Redis DELETE error: {e}') return 0 - + def exists(self, *keys: str) -> int: """ Check if keys exist - + Args: keys: Keys to check - + Returns: Number of existing keys """ try: return self.client.exists(*keys) except Exception as e: - logger.error(f"Redis EXISTS error: {e}") + logger.error(f'Redis EXISTS error: {e}') return 0 - - def hget(self, name: str, key: str) -> Optional[str]: + + def hget(self, name: str, key: str) -> str | None: """ Get hash field value - + Args: name: Hash name key: Field key - + Returns: Field value or None """ try: return self.client.hget(name, key) except Exception as e: - logger.error(f"Redis HGET error for {name}:{key}: {e}") + logger.error(f'Redis HGET error for {name}:{key}: {e}') return None - + def hset(self, name: str, key: str, value: Any) -> bool: """ Set hash field value - + Args: name: Hash name key: Field key value: Field value - + Returns: True if successful """ @@ -263,34 +263,34 @@ class RedisClient: self.client.hset(name, key, value) return True except Exception as e: - logger.error(f"Redis HSET error for {name}:{key}: {e}") + logger.error(f'Redis HSET error for {name}:{key}: {e}') return False - - def hmget(self, name: str, keys: list) -> Optional[list]: + + def hmget(self, name: str, keys: list) -> list | None: """ Get multiple hash field values - + Args: name: Hash name keys: List of field keys - + Returns: List of values or None """ try: return self.client.hmget(name, keys) except Exception as e: - logger.error(f"Redis HMGET error for {name}: {e}") + logger.error(f'Redis HMGET error for {name}: {e}') return None - - def hmset(self, name: str, mapping: Dict[str, Any]) -> bool: + + def hmset(self, name: str, mapping: dict[str, Any]) -> bool: """ Set multiple hash field values - + Args: name: Hash name mapping: Dictionary of field:value pairs - + Returns: True if successful """ @@ -298,20 +298,20 @@ class RedisClient: self.client.hset(name, mapping=mapping) return True except Exception as e: - logger.error(f"Redis HMSET error for {name}: {e}") + logger.error(f'Redis HMSET error for {name}: {e}') return False - + @contextmanager def pipeline(self, transaction: bool = True): """ Context manager for Redis pipeline - + Args: transaction: Use MULTI/EXEC transaction - + Yields: Pipeline object - + Example: with redis_client.pipeline() as pipe: pipe.incr('key1') @@ -322,112 +322,104 @@ class RedisClient: try: yield pipe except Exception as e: - logger.error(f"Redis pipeline error: {e}") + logger.error(f'Redis pipeline error: {e}') raise finally: pipe.reset() - - def zadd(self, name: str, mapping: Dict[str, float]) -> int: + + def zadd(self, name: str, mapping: dict[str, float]) -> int: """ Add members to sorted set - + Args: name: Sorted set name mapping: Dictionary of member:score pairs - + Returns: Number of members added """ try: return self.client.zadd(name, mapping) except Exception as e: - logger.error(f"Redis ZADD error for {name}: {e}") + logger.error(f'Redis ZADD error for {name}: {e}') return 0 - + def zremrangebyscore(self, name: str, min_score: float, max_score: float) -> int: """ Remove members from sorted set by score range - + Args: name: Sorted set name min_score: Minimum score max_score: Maximum score - + Returns: Number of members removed """ try: return self.client.zremrangebyscore(name, min_score, max_score) except Exception as e: - logger.error(f"Redis ZREMRANGEBYSCORE error for {name}: {e}") + logger.error(f'Redis ZREMRANGEBYSCORE error for {name}: {e}') return 0 - + def zcount(self, name: str, min_score: float, max_score: float) -> int: """ Count members in sorted set by score range - + Args: name: Sorted set name min_score: Minimum score max_score: Maximum score - + Returns: Number of members in range """ try: return self.client.zcount(name, min_score, max_score) except Exception as e: - logger.error(f"Redis ZCOUNT error for {name}: {e}") + logger.error(f'Redis ZCOUNT error for {name}: {e}') return 0 - + def close(self): """Close Redis connection pool""" try: self.pool.disconnect() - logger.info("Redis connection pool closed") + logger.info('Redis connection pool closed') except Exception as e: - logger.error(f"Error closing Redis pool: {e}") + logger.error(f'Error closing Redis pool: {e}') # Global Redis client instance -_redis_client: Optional[RedisClient] = None +_redis_client: RedisClient | None = None def get_redis_client( - host: str = 'localhost', - port: int = 6379, - password: Optional[str] = None, - db: int = 0 + host: str = 'localhost', port: int = 6379, password: str | None = None, db: int = 0 ) -> RedisClient: """ Get or create global Redis client instance - + Args: host: Redis server host port: Redis server port password: Redis password db: Redis database number - + Returns: RedisClient instance """ global _redis_client - + if _redis_client is None: - _redis_client = RedisClient( - host=host, - port=port, - password=password, - db=db - ) - + _redis_client = RedisClient(host=host, port=port, password=password, db=db) + return _redis_client def close_redis_client(): """Close global Redis client""" global _redis_client - + if _redis_client is not None: _redis_client.close() _redis_client = None diff --git a/backend-services/utils/response_util.py b/backend-services/utils/response_util.py index d4c8d89..7d90c77 100644 --- a/backend-services/utils/response_util.py +++ b/backend-services/utils/response_util.py @@ -1,11 +1,13 @@ -from fastapi.responses import JSONResponse, Response -import os import logging +import os + +from fastapi.responses import JSONResponse, Response from models.response_model import ResponseModel logger = logging.getLogger('doorman.gateway') + def _normalize_headers(hdrs: dict | None) -> dict | None: try: if not hdrs: @@ -20,15 +22,13 @@ def _normalize_headers(hdrs: dict | None) -> dict | None: except Exception: return hdrs + def _envelope(content: dict, status_code: int) -> dict: - return { - 'status_code': status_code, - **content - } + return {'status_code': status_code, **content} + def _add_token_compat(enveloped: dict, payload: dict): try: - if isinstance(payload, dict): for key in ('access_token', 'refresh_token'): if key in payload: @@ -36,6 +36,7 @@ def _add_token_compat(enveloped: dict, payload: dict): except Exception: pass + def respond_rest(model): """Return a REST JSONResponse using the normalized envelope logic. @@ -47,6 +48,7 @@ def respond_rest(model): rm = model return process_rest_response(rm) + def process_rest_response(response): try: strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true' @@ -62,12 +64,20 @@ def process_rest_response(response): _add_token_compat(content, response.response) http_status = 200 elif response.message: - content = {'message': response.message} if not strict else _envelope({'message': response.message}, response.status_code) + content = ( + {'message': response.message} + if not strict + else _envelope({'message': response.message}, response.status_code) + ) http_status = response.status_code if not strict else 200 else: content = {} if not strict else _envelope({}, response.status_code) http_status = response.status_code if not strict else 200 - resp = JSONResponse(content=content, status_code=http_status, headers=_normalize_headers(response.response_headers)) + resp = JSONResponse( + content=content, + status_code=http_status, + headers=_normalize_headers(response.response_headers), + ) try: # Ensure Content-Length is set for downstream metrics/bandwidth accounting blen = len(getattr(resp, 'body', b'') or b'') @@ -92,7 +102,11 @@ def process_rest_response(response): content = err_payload if not strict else _envelope(err_payload, response.status_code) http_status = response.status_code if not strict else 200 - resp = JSONResponse(content=content, status_code=http_status, headers=_normalize_headers(response.response_headers)) + resp = JSONResponse( + content=content, + status_code=http_status, + headers=_normalize_headers(response.response_headers), + ) try: blen = len(getattr(resp, 'body', b'') or b'') if blen > 0: @@ -103,11 +117,14 @@ def process_rest_response(response): return resp except Exception as e: logger.error(f'An error occurred while processing the response: {e}') - return JSONResponse(content={'error_message': 'Unable to process response'}, status_code=500) + return JSONResponse( + content={'error_message': 'Unable to process response'}, status_code=500 + ) + def process_soap_response(response): try: - strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true' + os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true' if response.status_code == 200: if getattr(response, 'soap_envelope', None): soap_response = response.soap_envelope @@ -136,6 +153,7 @@ def process_soap_response(response): error_response = 'Unable to process SOAP response' return Response(content=error_response, status_code=500, media_type='application/xml') + def process_response(response, type): response = ResponseModel(**response) if type == 'rest': @@ -146,9 +164,17 @@ def process_response(response, type): try: strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true' if response.status_code == 200: - content = response.response if not strict else _envelope({'response': response.response}, response.status_code) + content = ( + response.response + if not strict + else _envelope({'response': response.response}, response.status_code) + ) code = response.status_code if not strict else 200 - resp = JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers)) + resp = JSONResponse( + content=content, + status_code=code, + headers=_normalize_headers(response.response_headers), + ) try: blen = len(getattr(resp, 'body', b'') or b'') if blen > 0: @@ -158,11 +184,18 @@ def process_response(response, type): pass return resp else: - content = {'error_code': response.error_code, 'error_message': response.error_message} + content = { + 'error_code': response.error_code, + 'error_message': response.error_message, + } if strict: content = _envelope(content, response.status_code) code = response.status_code if not strict else 200 - resp = JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers)) + resp = JSONResponse( + content=content, + status_code=code, + headers=_normalize_headers(response.response_headers), + ) try: blen = len(getattr(resp, 'body', b'') or b'') if blen > 0: @@ -173,23 +206,42 @@ def process_response(response, type): return resp except Exception as e: logger.error(f'An error occurred while processing the GraphQL response: {e}') - return JSONResponse(content={'error': 'Unable to process GraphQL response'}, status_code=500) + return JSONResponse( + content={'error': 'Unable to process GraphQL response'}, status_code=500 + ) elif type == 'grpc': try: strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true' if response.status_code == 200: - content = response.response if not strict else _envelope({'response': response.response}, response.status_code) + content = ( + response.response + if not strict + else _envelope({'response': response.response}, response.status_code) + ) code = response.status_code if not strict else 200 - return JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers)) + return JSONResponse( + content=content, + status_code=code, + headers=_normalize_headers(response.response_headers), + ) else: - content = {'error_code': response.error_code, 'error_message': response.error_message} + content = { + 'error_code': response.error_code, + 'error_message': response.error_message, + } if strict: content = _envelope(content, response.status_code) code = response.status_code if not strict else 200 - return JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers)) + return JSONResponse( + content=content, + status_code=code, + headers=_normalize_headers(response.response_headers), + ) except Exception as e: logger.error(f'An error occurred while processing the gRPC response: {e}') - return JSONResponse(content={'error': 'Unable to process gRPC response'}, status_code=500) + return JSONResponse( + content={'error': 'Unable to process gRPC response'}, status_code=500 + ) else: logger.error(f'Unhandled response type: {type}') return JSONResponse(content={'error': 'Unhandled response type'}, status_code=500) diff --git a/backend-services/utils/role_util.py b/backend-services/utils/role_util.py index 8389c10..0f3af70 100644 --- a/backend-services/utils/role_util.py +++ b/backend-services/utils/role_util.py @@ -5,6 +5,7 @@ See https://github.com/pypeople-dev/doorman for more information """ import logging + from fastapi import HTTPException from utils.database import role_collection, user_collection @@ -12,6 +13,7 @@ from utils.doorman_cache_util import doorman_cache logger = logging.getLogger('doorman.gateway') + def _strip_id(r): try: if r and r.get('_id'): @@ -20,6 +22,7 @@ def _strip_id(r): pass return r + async def is_admin_role(role_name: str) -> bool: """Return True if the given role is the admin role. @@ -34,7 +37,6 @@ async def is_admin_role(role_name: str) -> bool: if role: doorman_cache.set_cache('role_cache', role_name, role) if not role: - rn = (role_name or '').strip().lower() return rn in ('admin', 'platform admin') if role.get('platform_admin') is True: @@ -44,6 +46,7 @@ async def is_admin_role(role_name: str) -> bool: except Exception: return False + async def is_admin_user(username: str) -> bool: """Return True if the user has the admin role.""" try: @@ -59,12 +62,15 @@ async def is_admin_user(username: str) -> bool: except Exception: return False + async def is_platform_admin_role(role_name: str) -> bool: return await is_admin_role(role_name) + async def is_platform_admin_user(username: str) -> bool: return await is_admin_user(username) + async def validate_platform_role(role_name, action): """ Get the platform roles from the cache or database. @@ -75,7 +81,8 @@ async def validate_platform_role(role_name, action): role = role_collection.find_one({'role_name': role_name}) if not role: raise HTTPException(status_code=404, detail='Role not found') - if role.get('_id'): del role['_id'] + if role.get('_id'): + del role['_id'] doorman_cache.set_cache('role_cache', role_name, role) if action == 'manage_users' and role.get('manage_users'): return True @@ -114,6 +121,7 @@ async def validate_platform_role(role_name, action): logger.error(f'validate_platform_role error: {e}') raise HTTPException(status_code=500, detail='Internal Server Error') + async def platform_role_required_bool(username, action): try: user = doorman_cache.get_cache('user_cache', username) @@ -121,15 +129,17 @@ async def platform_role_required_bool(username, action): user = user_collection.find_one({'username': username}) if not user: raise HTTPException(status_code=404, detail='User not found') - if user.get('_id'): del user['_id'] - if user.get('password'): del user['password'] + if user.get('_id'): + del user['_id'] + if user.get('password'): + del user['password'] doorman_cache.set_cache('user_cache', username, user) if not user: raise HTTPException(status_code=404, detail='User not found') if not await validate_platform_role(user.get('role'), action): raise HTTPException(status_code=403, detail='You do not have the correct role for this') return True - except HTTPException as e: + except HTTPException: return False except Exception as e: logger.error(f'Unexpected error: {e}') diff --git a/backend-services/utils/routing_util.py b/backend-services/utils/routing_util.py index 8d2e0d5..1a39d55 100644 --- a/backend-services/utils/routing_util.py +++ b/backend-services/utils/routing_util.py @@ -1,14 +1,14 @@ -from typing import Optional, Dict import logging -from utils.doorman_cache_util import doorman_cache -from utils.database_async import routing_collection -from utils.async_db import db_find_one from utils import api_util +from utils.async_db import db_find_one +from utils.database_async import routing_collection +from utils.doorman_cache_util import doorman_cache logger = logging.getLogger('doorman.gateway') -async def get_client_routing(client_key: str) -> Optional[Dict]: + +async def get_client_routing(client_key: str) -> dict | None: """Get the routing information for a specific client. Args: @@ -23,14 +23,16 @@ async def get_client_routing(client_key: str) -> Optional[Dict]: client_routing = await db_find_one(routing_collection, {'client_key': client_key}) if not client_routing: return None - if client_routing.get('_id'): del client_routing['_id'] + if client_routing.get('_id'): + del client_routing['_id'] doorman_cache.set_cache('client_routing_cache', client_key, client_routing) return client_routing except Exception as e: logger.error(f'Error in get_client_routing: {e}') return None -async def get_routing_info(client_key: str) -> Optional[str]: + +async def get_routing_info(client_key: str) -> str | None: """Get next upstream server for client using round-robin. Args: @@ -50,7 +52,10 @@ async def get_routing_info(client_key: str) -> Optional[str]: doorman_cache.set_cache('client_routing_cache', client_key, routing) return server -async def pick_upstream_server(api: Dict, method: str, endpoint_uri: str, client_key: Optional[str]) -> Optional[str]: + +async def pick_upstream_server( + api: dict, method: str, endpoint_uri: str, client_key: str | None +) -> str | None: """Resolve upstream server with precedence: Routing (1) > Endpoint (2) > API (3). - Routing: client-specific routing list with round-robin in the routing doc/cache. @@ -70,10 +75,12 @@ async def pick_upstream_server(api: Dict, method: str, endpoint_uri: str, client if endpoint: ep_servers = endpoint.get('endpoint_servers') or [] if isinstance(ep_servers, list) and len(ep_servers) > 0: - idx_key = endpoint.get('endpoint_id') or f"{api.get('api_id')}:{method}:{endpoint_uri}" + idx_key = endpoint.get('endpoint_id') or f'{api.get("api_id")}:{method}:{endpoint_uri}' server_index = doorman_cache.get_cache('endpoint_server_cache', idx_key) or 0 server = ep_servers[server_index % len(ep_servers)] - doorman_cache.set_cache('endpoint_server_cache', idx_key, (server_index + 1) % len(ep_servers)) + doorman_cache.set_cache( + 'endpoint_server_cache', idx_key, (server_index + 1) % len(ep_servers) + ) return server api_servers = api.get('api_servers') or [] @@ -81,7 +88,9 @@ async def pick_upstream_server(api: Dict, method: str, endpoint_uri: str, client idx_key = api.get('api_id') server_index = doorman_cache.get_cache('endpoint_server_cache', idx_key) or 0 server = api_servers[server_index % len(api_servers)] - doorman_cache.set_cache('endpoint_server_cache', idx_key, (server_index + 1) % len(api_servers)) + doorman_cache.set_cache( + 'endpoint_server_cache', idx_key, (server_index + 1) % len(api_servers) + ) return server return None diff --git a/backend-services/utils/sanitize_util.py b/backend-services/utils/sanitize_util.py index ac4076c..0f7cfcd 100644 --- a/backend-services/utils/sanitize_util.py +++ b/backend-services/utils/sanitize_util.py @@ -4,9 +4,9 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -import re import html -from typing import Optional +import re + def sanitize_html(text: str, allow_tags: bool = False) -> str: """ @@ -20,7 +20,8 @@ def sanitize_html(text: str, allow_tags: bool = False) -> str: text = re.sub(r'<[^>]+>', '', text) return html.escape(text) -def sanitize_input(text: str, max_length: Optional[int] = None) -> str: + +def sanitize_input(text: str, max_length: int | None = None) -> str: """ Comprehensive input sanitization for user-provided text. """ @@ -33,26 +34,22 @@ def sanitize_input(text: str, max_length: Optional[int] = None) -> str: sanitized = sanitized[:max_length] return sanitized + def sanitize_url(url: str) -> str: """ Sanitize URL to prevent javascript: and data: URL attacks. """ if not url or not isinstance(url, str): - return "" + return '' url = url.strip() - dangerous_schemes = [ - 'javascript:', - 'data:', - 'vbscript:', - 'file:', - 'about:' - ] + dangerous_schemes = ['javascript:', 'data:', 'vbscript:', 'file:', 'about:'] url_lower = url.lower() for scheme in dangerous_schemes: if url_lower.startswith(scheme): - return "" + return '' return url + def strip_control_characters(text: str) -> str: """ Remove control characters that might cause issues. @@ -61,6 +58,7 @@ def strip_control_characters(text: str) -> str: return text return ''.join(char for char in text if ord(char) >= 32 or char in '\n\r\t') + def sanitize_username(username: str) -> str: """ Sanitize username to only allow safe characters. @@ -71,6 +69,7 @@ def sanitize_username(username: str) -> str: safe_pattern = re.compile(r'[^a-zA-Z0-9_\-\.@]') return safe_pattern.sub('', username) + def sanitize_api_name(name: str) -> str: """ Sanitize API name to prevent injection attacks. diff --git a/backend-services/utils/security_settings_util.py b/backend-services/utils/security_settings_util.py index 2d8026e..acd4a23 100644 --- a/backend-services/utils/security_settings_util.py +++ b/backend-services/utils/security_settings_util.py @@ -3,19 +3,19 @@ Utilities to manage security-related settings and schedule auto-save of memory d """ import asyncio +import logging import os from pathlib import Path -from typing import Optional, Dict, Any -import logging +from typing import Any from .database import database, db from .memory_dump_util import dump_memory_to_file logger = logging.getLogger('doorman.gateway') -_CACHE: Dict[str, Any] = {} -_AUTO_TASK: Optional[asyncio.Task] = None -_STOP_EVENT: Optional[asyncio.Event] = None +_CACHE: dict[str, Any] = {} +_AUTO_TASK: asyncio.Task | None = None +_STOP_EVENT: asyncio.Event | None = None _PROJECT_ROOT = Path(__file__).resolve().parent.parent _GEN_DIR = _PROJECT_ROOT / 'generated' @@ -34,31 +34,35 @@ DEFAULTS = { SETTINGS_FILE = os.getenv('SECURITY_SETTINGS_FILE', str(_GEN_DIR / 'security_settings.json')) + def _get_collection(): return db.settings if not database.memory_only else database.db.settings -def _merge_settings(doc: Dict[str, Any]) -> Dict[str, Any]: + +def _merge_settings(doc: dict[str, Any]) -> dict[str, Any]: merged = DEFAULTS.copy() if doc: merged.update({k: v for k, v in doc.items() if v is not None}) return merged -def get_cached_settings() -> Dict[str, Any]: + +def get_cached_settings() -> dict[str, Any]: global _CACHE if not _CACHE: - _CACHE = DEFAULTS.copy() return _CACHE -def _load_from_file() -> Optional[Dict[str, Any]]: + +def _load_from_file() -> dict[str, Any] | None: try: if not os.path.exists(SETTINGS_FILE): return None - with open(SETTINGS_FILE, 'r', encoding='utf-8') as f: + with open(SETTINGS_FILE, encoding='utf-8') as f: data = f.read().strip() if not data: return None import json + obj = json.loads(data) if isinstance(obj, dict): @@ -67,39 +71,41 @@ def _load_from_file() -> Optional[Dict[str, Any]]: logger.error('Failed to read settings file %s: %s', SETTINGS_FILE, e) return None -def _save_to_file(settings: Dict[str, Any]) -> None: + +def _save_to_file(settings: dict[str, Any]) -> None: try: os.makedirs(os.path.dirname(SETTINGS_FILE) or '.', exist_ok=True) import json + with open(SETTINGS_FILE, 'w', encoding='utf-8') as f: f.write(json.dumps(settings, separators=(',', ':'))) except Exception as e: logger.error('Failed to write settings file %s: %s', SETTINGS_FILE, e) -async def load_settings() -> Dict[str, Any]: + +async def load_settings() -> dict[str, Any]: coll = _get_collection() doc = coll.find_one({'type': 'security_settings'}) if not doc and database.memory_only: file_obj = _load_from_file() if file_obj: - try: to_set = _merge_settings(file_obj) coll.update_one({'type': 'security_settings'}, {'$set': to_set}) doc = to_set except Exception: - doc = file_obj settings = _merge_settings(doc or {}) _CACHE.update(settings) return settings -async def save_settings(partial: Dict[str, Any]) -> Dict[str, Any]: + +async def save_settings(partial: dict[str, Any]) -> dict[str, Any]: coll = _get_collection() current = _merge_settings(coll.find_one({'type': 'security_settings'}) or {}) current.update({k: v for k, v in partial.items() if v is not None}) - result = coll.update_one({'type': 'security_settings'}, {'$set': current},) + result = coll.update_one({'type': 'security_settings'}, {'$set': current}) try: modified = getattr(result, 'modified_count', 0) @@ -113,6 +119,7 @@ async def save_settings(partial: Dict[str, Any]) -> Dict[str, Any]: await restart_auto_save_task() return current + async def _auto_save_loop(stop_event: asyncio.Event): while not stop_event.is_set(): try: @@ -127,12 +134,13 @@ async def _auto_save_loop(stop_event: asyncio.Event): logger.error('Auto-save memory dump failed: %s', e) await asyncio.wait_for(stop_event.wait(), timeout=max(freq, 60) if freq > 0 else 60) - except asyncio.TimeoutError: + except TimeoutError: continue except Exception as e: logger.error('Auto-save loop error: %s', e) await asyncio.sleep(60) + async def start_auto_save_task(): global _AUTO_TASK, _STOP_EVENT if _AUTO_TASK and not _AUTO_TASK.done(): @@ -141,6 +149,7 @@ async def start_auto_save_task(): _AUTO_TASK = asyncio.create_task(_auto_save_loop(_STOP_EVENT)) logger.info('Security auto-save task started') + async def stop_auto_save_task(): global _AUTO_TASK, _STOP_EVENT if _STOP_EVENT: @@ -153,6 +162,7 @@ async def stop_auto_save_task(): _AUTO_TASK = None _STOP_EVENT = None + async def restart_auto_save_task(): await stop_auto_save_task() await start_auto_save_task() diff --git a/backend-services/utils/subscription_util.py b/backend-services/utils/subscription_util.py index fa4a174..2cdeaa4 100644 --- a/backend-services/utils/subscription_util.py +++ b/backend-services/utils/subscription_util.py @@ -4,17 +4,19 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/pypeople-dev/doorman for more information """ -from fastapi import HTTPException, Depends, Request -from jose import jwt, JWTError import logging -from utils.doorman_cache_util import doorman_cache +from fastapi import HTTPException, Request +from jose import JWTError + +from utils.async_db import db_find_one +from utils.auth_util import auth_required from utils.database_async import subscriptions_collection -from utils.async_db import db_find_one, db_update_one -from utils.auth_util import SECRET_KEY, ALGORITHM, auth_required +from utils.doorman_cache_util import doorman_cache logger = logging.getLogger('doorman.gateway') + async def subscription_required(request: Request): try: payload = await auth_required(request) @@ -24,11 +26,11 @@ async def subscription_required(request: Request): full_path = request.url.path if full_path.startswith('/api/rest/'): prefix = '/api/rest/' - path = full_path[len(prefix):] + path = full_path[len(prefix) :] api_and_version = '/'.join(path.split('/')[:2]) elif full_path.startswith('/api/soap/'): prefix = '/api/soap/' - path = full_path[len(prefix):] + path = full_path[len(prefix) :] api_and_version = '/'.join(path.split('/')[:2]) elif full_path.startswith('/api/graphql/'): api_name = full_path.replace('/api/graphql/', '') @@ -45,8 +47,14 @@ async def subscription_required(request: Request): api_and_version = '/'.join(segs[2:4]) else: api_and_version = '/'.join(segs[:2]) - user_subscriptions = doorman_cache.get_cache('user_subscription_cache', username) or await db_find_one(subscriptions_collection, {'username': username}) - subscriptions = user_subscriptions.get('apis') if user_subscriptions and 'apis' in user_subscriptions else None + user_subscriptions = doorman_cache.get_cache( + 'user_subscription_cache', username + ) or await db_find_one(subscriptions_collection, {'username': username}) + subscriptions = ( + user_subscriptions.get('apis') + if user_subscriptions and 'apis' in user_subscriptions + else None + ) if not subscriptions or api_and_version not in subscriptions: logger.info(f'User {username} attempted access to {api_and_version}') raise HTTPException(status_code=403, detail='You are not subscribed to this resource') diff --git a/backend-services/utils/validation_util.py b/backend-services/utils/validation_util.py index bd45da3..e8ef861 100644 --- a/backend-services/utils/validation_util.py +++ b/backend-services/utils/validation_util.py @@ -4,26 +4,31 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -from typing import Dict, Any, Optional, Callable -from fastapi import HTTPException -import json import re -from datetime import datetime import uuid +from collections.abc import Callable +from datetime import datetime +from typing import Any + +from fastapi import HTTPException + try: from defusedxml import ElementTree as ET + _DEFUSED = True except Exception: import xml.etree.ElementTree as ET + _DEFUSED = False -from graphql import parse, GraphQLError import grpc +from graphql import GraphQLError, parse from models.field_validation_model import FieldValidation from models.validation_schema_model import ValidationSchema -from utils.doorman_cache_util import doorman_cache -from utils.database_async import endpoint_validation_collection from utils.async_db import db_find_one +from utils.database_async import endpoint_validation_collection +from utils.doorman_cache_util import doorman_cache + class ValidationError(Exception): def __init__(self, message: str, field_path: str): @@ -31,6 +36,7 @@ class ValidationError(Exception): self.field_path = field_path super().__init__(self.message) + class ValidationUtil: def __init__(self): self.type_validators = { @@ -38,16 +44,16 @@ class ValidationUtil: 'number': self._validate_number, 'boolean': self._validate_boolean, 'array': self._validate_array, - 'object': self._validate_object + 'object': self._validate_object, } self.format_validators = { 'email': self._validate_email, 'url': self._validate_url, 'date': self._validate_date, 'datetime': self._validate_datetime, - 'uuid': self._validate_uuid + 'uuid': self._validate_uuid, } - self.custom_validators: Dict[str, Callable] = {} + self.custom_validators: dict[str, Callable] = {} # When defusedxml is unavailable, apply a basic pre-parse guard against DOCTYPE/ENTITY. def _reject_unsafe_xml(self, xml_text: str) -> None: @@ -62,10 +68,12 @@ class ValidationUtil: if ' None: + def register_custom_validator( + self, name: str, validator: Callable[[Any, FieldValidation], None] + ) -> None: self.custom_validators[name] = validator - async def get_validation_schema(self, endpoint_id: str) -> Optional[ValidationSchema]: + async def get_validation_schema(self, endpoint_id: str) -> ValidationSchema | None: """Return the ValidationSchema for an endpoint_id if configured. Looks up the in-memory cache first, then falls back to the DB collection. @@ -75,7 +83,9 @@ class ValidationUtil: """ validation_doc = doorman_cache.get_cache('endpoint_validation_cache', endpoint_id) if not validation_doc: - validation_doc = await db_find_one(endpoint_validation_collection, {'endpoint_id': endpoint_id}) + validation_doc = await db_find_one( + endpoint_validation_collection, {'endpoint_id': endpoint_id} + ) if validation_doc: try: vdoc = dict(validation_doc) @@ -91,14 +101,20 @@ class ValidationUtil: raw = validation_doc.get('validation_schema') if not raw: return None - mapping = raw.get('validation_schema') if isinstance(raw, dict) and 'validation_schema' in raw else raw + mapping = ( + raw.get('validation_schema') + if isinstance(raw, dict) and 'validation_schema' in raw + else raw + ) if not isinstance(mapping, dict): return None schema = ValidationSchema(validation_schema=mapping) self._validate_schema_paths(schema.validation_schema) return schema - def _validate_schema_paths(self, schema: Dict[str, FieldValidation], parent_path: str = '') -> None: + def _validate_schema_paths( + self, schema: dict[str, FieldValidation], parent_path: str = '' + ) -> None: for field_path, validation in schema.items(): full_path = f'{parent_path}.{field_path}' if parent_path else field_path if not self._is_valid_field_path(full_path): @@ -160,7 +176,9 @@ class ValidationUtil: if field_validation.required and field_path not in value: raise ValidationError(f'Required field {field_path} is missing', path) if field_path in value: - self._validate_value(value[field_path], field_validation, f'{path}.{field_path}') + self._validate_value( + value[field_path], field_validation, f'{path}.{field_path}' + ) def _validate_value(self, value: Any, validation: FieldValidation, field_path: str) -> None: if validation.required and value is None: @@ -175,7 +193,6 @@ class ValidationUtil: try: self.custom_validators[validation.custom_validator](value, validation) except ValidationError as e: - raise ValidationError(e.message, field_path) def _validate_email(self, value: str, validation: FieldValidation, path: str) -> None: @@ -184,29 +201,32 @@ class ValidationUtil: raise ValidationError('Invalid email format', path) def _validate_url(self, value: str, validation: FieldValidation, path: str) -> None: - url_pattern = r'^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$' + url_pattern = ( + r'^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}' + r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$' + ) if not re.match(url_pattern, value): raise ValidationError('Invalid URL format', path) def _validate_date(self, value: str, validation: FieldValidation, path: str) -> None: try: datetime.strptime(value, '%Y-%m-%d') - except ValueError: - raise ValidationError('Invalid date format (YYYY-MM-DD)', path) + except ValueError as e: + raise ValidationError('Invalid date format (YYYY-MM-DD)', path) from e def _validate_datetime(self, value: str, validation: FieldValidation, path: str) -> None: try: datetime.fromisoformat(value.replace('Z', '+00:00')) - except ValueError: - raise ValidationError('Invalid datetime format (ISO 8601)', path) + except ValueError as e: + raise ValidationError('Invalid datetime format (ISO 8601)', path) from e def _validate_uuid(self, value: str, validation: FieldValidation, path: str) -> None: try: uuid.UUID(value) - except ValueError: - raise ValidationError('Invalid UUID format', path) + except ValueError as e: + raise ValidationError('Invalid UUID format', path) from e - async def validate_rest_request(self, endpoint_id: str, request_data: Dict[str, Any]) -> None: + async def validate_rest_request(self, endpoint_id: str, request_data: dict[str, Any]) -> None: schema = await self.get_validation_schema(endpoint_id) if not schema: return @@ -216,8 +236,11 @@ class ValidationUtil: self._validate_value(value, validation, field_path) except ValidationError as e: import logging - logging.getLogger('doorman.gateway').error(f'Validation failed for {field_path}: {e}') - raise HTTPException(status_code=400, detail=str(e)) + + logging.getLogger('doorman.gateway').error( + f'Validation failed for {field_path}: {e}' + ) + raise HTTPException(status_code=400, detail=str(e)) from e async def validate_soap_request(self, endpoint_id: str, soap_envelope: str) -> None: schema = await self.get_validation_schema(endpoint_id) @@ -235,9 +258,10 @@ class ValidationUtil: value = self._get_nested_value(request_data, field_path) self._validate_value(value, validation, field_path) except ValidationError as e: - raise HTTPException(status_code=400, detail=str(e)) - except ET.ParseError: - raise HTTPException(status_code=400, detail='Invalid SOAP envelope') + raise HTTPException(status_code=400, detail=str(e)) from e + except ET.ParseError as e: + raise HTTPException(status_code=400, detail='Invalid SOAP envelope') from e + async def validate_grpc_request(self, endpoint_id: str, request: Any) -> None: schema = await self.get_validation_schema(endpoint_id) if not schema: @@ -248,9 +272,11 @@ class ValidationUtil: value = self._get_nested_value(request_data, field_path) self._validate_value(value, validation, field_path) except ValidationError as e: - raise grpc.RpcError(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + raise grpc.RpcError(grpc.StatusCode.INVALID_ARGUMENT, str(e)) from e - async def validate_graphql_request(self, endpoint_id: str, query: str, variables: Dict[str, Any]) -> None: + async def validate_graphql_request( + self, endpoint_id: str, query: str, variables: dict[str, Any] + ) -> None: schema = await self.get_validation_schema(endpoint_id) if not schema: return @@ -261,19 +287,22 @@ class ValidationUtil: for field_path, validation in schema.validation_schema.items(): if field_path.startswith(operation_name): try: - value = self._get_nested_value(variables, field_path[len(operation_name)+1:]) + value = self._get_nested_value( + variables, field_path[len(operation_name) + 1 :] + ) self._validate_value(value, validation, field_path) except ValidationError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=400, detail=str(e)) from e except GraphQLError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=400, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - def _extract_operation_name(self, query: str) -> Optional[str]: + raise HTTPException(status_code=400, detail=str(e)) from e + + def _extract_operation_name(self, query: str) -> str | None: match = re.search(r'(?:query|mutation)\s+(\w+)', query) return match.group(1) if match else None - def _get_nested_value(self, data: Dict[str, Any], field_path: str) -> Any: + def _get_nested_value(self, data: dict[str, Any], field_path: str) -> Any: parts = field_path.split('.') current = data for part in parts: @@ -298,7 +327,7 @@ class ValidationUtil: return tag.split('}', 1)[1] return tag - def _xml_to_dict(self, element: Any) -> Dict[str, Any]: + def _xml_to_dict(self, element: Any) -> dict[str, Any]: result = {} for child in element: key = self._strip_ns(child.tag) @@ -308,7 +337,7 @@ class ValidationUtil: result[key] = child.text return result - def _protobuf_to_dict(self, message: Any) -> Dict[str, Any]: + def _protobuf_to_dict(self, message: Any) -> dict[str, Any]: result = {} for field in message.DESCRIPTOR.fields: value = getattr(message, field.name) @@ -321,4 +350,5 @@ class ValidationUtil: result[field.name] = value return result + validation_util = ValidationUtil() diff --git a/backend-services/utils/vault_encryption_util.py b/backend-services/utils/vault_encryption_util.py index ac78642..e1ba370 100644 --- a/backend-services/utils/vault_encryption_util.py +++ b/backend-services/utils/vault_encryption_util.py @@ -4,14 +4,15 @@ Review the Apache License 2.0 for valid authorization of use See https://github.com/apidoorman/doorman for more information """ -import os +import base64 import hashlib +import logging +import os + from cryptography.fernet import Fernet +from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from cryptography.hazmat.backends import default_backend -import base64 -import logging logger = logging.getLogger('doorman.gateway') @@ -19,31 +20,31 @@ logger = logging.getLogger('doorman.gateway') def _derive_key_from_components(email: str, username: str, vault_key: str) -> bytes: """ Derive a Fernet-compatible encryption key from email, username, and vault key. - + Args: email: User's email address username: User's username vault_key: VAULT_KEY from environment - + Returns: bytes: 32-byte key suitable for Fernet encryption """ # Combine all components to create a unique salt - combined = f"{email}:{username}:{vault_key}" + combined = f'{email}:{username}:{vault_key}' salt = hashlib.sha256(combined.encode()).digest() - + # Use PBKDF2 to derive a key kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, iterations=100000, - backend=default_backend() + backend=default_backend(), ) - + # Derive key from the vault_key key = kdf.derive(vault_key.encode()) - + # Encode to base64 for Fernet compatibility return base64.urlsafe_b64encode(key) @@ -51,15 +52,15 @@ def _derive_key_from_components(email: str, username: str, vault_key: str) -> by def encrypt_vault_value(value: str, email: str, username: str) -> str: """ Encrypt a vault value using email, username, and VAULT_KEY from environment. - + Args: value: The plaintext value to encrypt email: User's email address username: User's username - + Returns: str: Base64-encoded encrypted value - + Raises: RuntimeError: If VAULT_KEY is not configured ValueError: If encryption fails @@ -67,36 +68,36 @@ def encrypt_vault_value(value: str, email: str, username: str) -> str: vault_key = os.getenv('VAULT_KEY') if not vault_key: raise RuntimeError('VAULT_KEY is not configured in environment variables') - + try: # Derive encryption key encryption_key = _derive_key_from_components(email, username, vault_key) - + # Create Fernet cipher cipher = Fernet(encryption_key) - + # Encrypt the value encrypted_bytes = cipher.encrypt(value.encode('utf-8')) - + # Return as base64 string return encrypted_bytes.decode('utf-8') except Exception as e: logger.error(f'Encryption failed: {str(e)}') - raise ValueError(f'Failed to encrypt vault value: {str(e)}') + raise ValueError(f'Failed to encrypt vault value: {str(e)}') from e def decrypt_vault_value(encrypted_value: str, email: str, username: str) -> str: """ Decrypt a vault value using email, username, and VAULT_KEY from environment. - + Args: encrypted_value: The base64-encoded encrypted value email: User's email address username: User's username - + Returns: str: Decrypted plaintext value - + Raises: RuntimeError: If VAULT_KEY is not configured ValueError: If decryption fails @@ -104,28 +105,28 @@ def decrypt_vault_value(encrypted_value: str, email: str, username: str) -> str: vault_key = os.getenv('VAULT_KEY') if not vault_key: raise RuntimeError('VAULT_KEY is not configured in environment variables') - + try: # Derive encryption key encryption_key = _derive_key_from_components(email, username, vault_key) - + # Create Fernet cipher cipher = Fernet(encryption_key) - + # Decrypt the value decrypted_bytes = cipher.decrypt(encrypted_value.encode('utf-8')) - + # Return as string return decrypted_bytes.decode('utf-8') except Exception as e: logger.error(f'Decryption failed: {str(e)}') - raise ValueError(f'Failed to decrypt vault value: {str(e)}') + raise ValueError(f'Failed to decrypt vault value: {str(e)}') from e def is_vault_configured() -> bool: """ Check if VAULT_KEY is configured in environment variables. - + Returns: bool: True if VAULT_KEY is set, False otherwise """ diff --git a/pyproject.toml b/pyproject.toml index 00db760..17ab294 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ select = [ "E", "F", "I", "UP", "B", "W" ] ignore = [ + "B904" # Exception chaining is a style preference, not a functional requirement ] [tool.ruff.lint.isort]