diff --git a/backend-services/doorman.py b/backend-services/doorman.py
index d579bc7..ff3f04d 100755
--- a/backend-services/doorman.py
+++ b/backend-services/doorman.py
@@ -4,51 +4,55 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from datetime import datetime, timedelta
+import asyncio
+import json
+import logging
+import multiprocessing
+import os
+import re
+import shutil
+import signal
+import subprocess
+import sys
+import time
+import uuid
+from contextlib import asynccontextmanager
+from datetime import timedelta
from logging.handlers import RotatingFileHandler
+from pathlib import Path
+
+import uvicorn
+from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
-from jose import jwt, JWTError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from starlette.middleware.base import BaseHTTPMiddleware
-from fastapi.responses import Response
-from contextlib import asynccontextmanager
-from redis.asyncio import Redis
+from jose import JWTError
from pydantic import BaseSettings
-from dotenv import load_dotenv
-import multiprocessing
-import logging
-import json
-import re
-import os
-import sys
-import subprocess
-import signal
-import uvicorn
-import time
-import asyncio
-import uuid
-import shutil
-from pathlib import Path
+from redis.asyncio import Redis
+from starlette.middleware.base import BaseHTTPMiddleware
try:
if sys.version_info >= (3, 13):
try:
- from importlib.metadata import version, PackageNotFoundError
+ from importlib.metadata import PackageNotFoundError, version
except Exception:
version = None
PackageNotFoundError = Exception
if version is not None:
try:
v = version('aiohttp')
- parts = [int(p) for p in (v.split('.')[:3] + ['0', '0'])[:3] if p.isdigit() or p.isnumeric()]
+ parts = [
+ int(p)
+ for p in (v.split('.')[:3] + ['0', '0'])[:3]
+ if p.isdigit() or p.isnumeric()
+ ]
while len(parts) < 3:
parts.append(0)
if tuple(parts) < (3, 10, 10):
raise SystemExit(
- f"Incompatible aiohttp {v} detected on Python {sys.version.split()[0]}. "
- "Please upgrade to aiohttp>=3.10.10 (pip install -U aiohttp) or run with Python 3.11."
+ f'Incompatible aiohttp {v} detected on Python {sys.version.split()[0]}. '
+ 'Please upgrade to aiohttp>=3.10.10 (pip install -U aiohttp) or run with Python 3.11.'
)
except PackageNotFoundError:
pass
@@ -58,47 +62,58 @@ except Exception:
pass
from models.response_model import ResponseModel
-from utils.cache_manager_util import cache_manager
-from utils.auth_blacklist import purge_expired_tokens
-from utils.doorman_cache_util import doorman_cache
-from utils.hot_reload_config import hot_config
-from routes.authorization_routes import authorization_router
-from routes.group_routes import group_router
-from routes.role_routes import role_router
-from routes.subscription_routes import subscription_router
-from routes.user_routes import user_router
+from routes.analytics_routes import analytics_router
from routes.api_routes import api_router
+from routes.authorization_routes import authorization_router
+from routes.config_hot_reload_routes import config_hot_reload_router
+from routes.config_routes import config_router
+from routes.credit_routes import credit_router
+from routes.dashboard_routes import dashboard_router
+from routes.demo_routes import demo_router
from routes.endpoint_routes import endpoint_router
from routes.gateway_routes import gateway_router
-from routes.routing_routes import routing_router
-from routes.proto_routes import proto_router
+from routes.group_routes import group_router
from routes.logging_routes import logging_router
-from routes.dashboard_routes import dashboard_router
from routes.memory_routes import memory_router
-from routes.security_routes import security_router
-from routes.credit_routes import credit_router
-from routes.demo_routes import demo_router
from routes.monitor_routes import monitor_router
-from routes.config_routes import config_router
-from routes.tools_routes import tools_router
-from routes.config_hot_reload_routes import config_hot_reload_router
-from routes.vault_routes import vault_router
-from routes.analytics_routes import analytics_router
-from routes.tier_routes import tier_router
-from routes.rate_limit_rule_routes import rate_limit_rule_router
+from routes.proto_routes import proto_router
from routes.quota_routes import quota_router
-from utils.security_settings_util import load_settings, start_auto_save_task, stop_auto_save_task, get_cached_settings
-from utils.memory_dump_util import dump_memory_to_file, restore_memory_from_file, find_latest_dump_path
-from utils.metrics_util import metrics_store
-from utils.database import database
-from utils.response_util import process_response
+from routes.rate_limit_rule_routes import rate_limit_rule_router
+from routes.role_routes import role_router
+from routes.routing_routes import routing_router
+from routes.security_routes import security_router
+from routes.subscription_routes import subscription_router
+from routes.tier_routes import tier_router
+from routes.tools_routes import tools_router
+from routes.user_routes import user_router
+from routes.vault_routes import vault_router
from utils.audit_util import audit
-from utils.ip_policy_util import _get_client_ip as _policy_get_client_ip, _ip_in_list as _policy_ip_in_list, _is_loopback as _policy_is_loopback
+from utils.auth_blacklist import purge_expired_tokens
+from utils.cache_manager_util import cache_manager
+from utils.database import database
+from utils.hot_reload_config import hot_config
+from utils.ip_policy_util import _get_client_ip as _policy_get_client_ip
+from utils.ip_policy_util import _ip_in_list as _policy_ip_in_list
+from utils.ip_policy_util import _is_loopback as _policy_is_loopback
+from utils.memory_dump_util import (
+ dump_memory_to_file,
+ find_latest_dump_path,
+ restore_memory_from_file,
+)
+from utils.metrics_util import metrics_store
+from utils.response_util import process_response
+from utils.security_settings_util import (
+ get_cached_settings,
+ load_settings,
+ start_auto_save_task,
+ stop_auto_save_task,
+)
load_dotenv()
PID_FILE = 'doorman.pid'
+
def _migrate_generated_directory() -> None:
"""Migrate legacy root-level 'generated/' into backend-services/generated.
@@ -114,7 +129,7 @@ def _migrate_generated_directory() -> None:
return
if not src.exists() or not src.is_dir():
dst.mkdir(exist_ok=True)
- gateway_logger.info(f"Generated dir: {dst} (no migration needed)")
+ gateway_logger.info(f'Generated dir: {dst} (no migration needed)')
return
dst.mkdir(parents=True, exist_ok=True)
moved_count = 0
@@ -137,37 +152,37 @@ def _migrate_generated_directory() -> None:
shutil.rmtree(src)
except Exception:
pass
- gateway_logger.info(f"Generated dir migrated: {moved_count} file(s) moved to {dst}")
+ gateway_logger.info(f'Generated dir migrated: {moved_count} file(s) moved to {dst}')
except Exception as e:
try:
- gateway_logger.warning(f"Generated dir migration skipped: {e}")
+ gateway_logger.warning(f'Generated dir migration skipped: {e}')
except Exception:
pass
+
async def validate_database_connections():
"""Validate database connections on startup with retry logic"""
- gateway_logger.info("Validating database connections...")
+ gateway_logger.info('Validating database connections...')
max_retries = 3
for attempt in range(max_retries):
try:
from utils.database import user_collection
+
await user_collection.find_one({})
- gateway_logger.info("✓ MongoDB connection verified")
+ gateway_logger.info('✓ MongoDB connection verified')
break
except Exception as e:
if attempt < max_retries - 1:
- wait = 2 ** attempt
+ wait = 2**attempt
gateway_logger.warning(
- f"MongoDB connection attempt {attempt + 1}/{max_retries} failed: {e}"
+ f'MongoDB connection attempt {attempt + 1}/{max_retries} failed: {e}'
)
- gateway_logger.info(f"Retrying in {wait} seconds...")
+ gateway_logger.info(f'Retrying in {wait} seconds...')
await asyncio.sleep(wait)
else:
- gateway_logger.error(f"MongoDB connection failed after {max_retries} attempts")
- raise RuntimeError(
- f"Cannot connect to MongoDB: {e}"
- ) from e
+ gateway_logger.error(f'MongoDB connection failed after {max_retries} attempts')
+ raise RuntimeError(f'Cannot connect to MongoDB: {e}') from e
redis_host = os.getenv('REDIS_HOST')
mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM')
@@ -176,32 +191,32 @@ async def validate_database_connections():
for attempt in range(max_retries):
try:
import redis.asyncio as redis
- redis_url = f"redis://{redis_host}:{os.getenv('REDIS_PORT', '6379')}"
+
+ redis_url = f'redis://{redis_host}:{os.getenv("REDIS_PORT", "6379")}'
if os.getenv('REDIS_PASSWORD'):
- redis_url = f"redis://:{os.getenv('REDIS_PASSWORD')}@{redis_host}:{os.getenv('REDIS_PORT', '6379')}"
+ redis_url = f'redis://:{os.getenv("REDIS_PASSWORD")}@{redis_host}:{os.getenv("REDIS_PORT", "6379")}'
r = redis.from_url(redis_url)
await r.ping()
await r.close()
- gateway_logger.info("✓ Redis connection verified")
+ gateway_logger.info('✓ Redis connection verified')
break
except Exception as e:
if attempt < max_retries - 1:
- wait = 2 ** attempt
+ wait = 2**attempt
gateway_logger.warning(
- f"Redis connection attempt {attempt + 1}/{max_retries} failed: {e}"
+ f'Redis connection attempt {attempt + 1}/{max_retries} failed: {e}'
)
- gateway_logger.info(f"Retrying in {wait} seconds...")
+ gateway_logger.info(f'Retrying in {wait} seconds...')
await asyncio.sleep(wait)
else:
- gateway_logger.error(f"Redis connection failed after {max_retries} attempts")
- raise RuntimeError(
- f"Cannot connect to Redis: {e}"
- ) from e
+ gateway_logger.error(f'Redis connection failed after {max_retries} attempts')
+ raise RuntimeError(f'Cannot connect to Redis: {e}') from e
- gateway_logger.info("All database connections validated successfully")
+ gateway_logger.info('All database connections validated successfully')
-def validate_token_revocation_config():
+
+def validate_token_revocation_config() -> None:
"""
Validate token revocation is safe for multi-worker deployments.
"""
@@ -209,21 +224,20 @@ def validate_token_revocation_config():
mem_mode = os.getenv('MEM_OR_EXTERNAL', 'MEM')
if threads > 1 and mem_mode == 'MEM':
gateway_logger.error(
- "CRITICAL: Multi-worker mode (THREADS > 1) with in-memory storage "
- "does not provide consistent token revocation across workers. "
- f"Current config: THREADS={threads}, MEM_OR_EXTERNAL={mem_mode}"
+ 'CRITICAL: Multi-worker mode (THREADS > 1) with in-memory storage '
+ 'does not provide consistent token revocation across workers. '
+ f'Current config: THREADS={threads}, MEM_OR_EXTERNAL={mem_mode}'
)
gateway_logger.error(
- "Token revocation requires Redis in multi-worker mode. "
- "Either set MEM_OR_EXTERNAL=REDIS or set THREADS=1"
+ 'Token revocation requires Redis in multi-worker mode. '
+ 'Either set MEM_OR_EXTERNAL=REDIS or set THREADS=1'
)
raise RuntimeError(
- "Token revocation requires Redis in multi-worker mode (THREADS > 1). "
- "Set MEM_OR_EXTERNAL=REDIS or THREADS=1"
+ 'Token revocation requires Redis in multi-worker mode (THREADS > 1). '
+ 'Set MEM_OR_EXTERNAL=REDIS or THREADS=1'
)
- gateway_logger.info(
- f"Token revocation mode: {mem_mode} with {threads} worker(s)"
- )
+ gateway_logger.info(f'Token revocation mode: {mem_mode} with {threads} worker(s)')
+
@asynccontextmanager
async def app_lifespan(app: FastAPI):
@@ -250,7 +264,12 @@ async def app_lifespan(app: FastAPI):
)
jwt_secret = os.getenv('JWT_SECRET_KEY', '')
- if jwt_secret in ('please-change-me', 'test-secret-key', 'test-secret-key-please-change', ''):
+ if jwt_secret in (
+ 'please-change-me',
+ 'test-secret-key',
+ 'test-secret-key-please-change',
+ '',
+ ):
raise RuntimeError(
'In production (ENV=production), JWT_SECRET_KEY must be changed from default value. '
'Generate a strong random secret (32+ characters).'
@@ -294,7 +313,7 @@ async def app_lifespan(app: FastAPI):
'Memory dumps contain sensitive data and must be encrypted. '
'Generate a strong random key: openssl rand -hex 32'
)
- except Exception as e:
+ except Exception:
raise
mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM').upper()
@@ -337,6 +356,7 @@ async def app_lifespan(app: FastAPI):
break
except Exception:
pass
+
try:
app.state._metrics_save_task = asyncio.create_task(_metrics_autosave(60))
except Exception:
@@ -350,8 +370,12 @@ async def app_lifespan(app: FastAPI):
try:
settings = get_cached_settings()
- if bool(settings.get('trust_x_forwarded_for')) and not (settings.get('xff_trusted_proxies') or []):
- gateway_logger.warning('Security: trust_x_forwarded_for enabled but xff_trusted_proxies is empty; header spoofing risk. Configure trusted proxy IPs/CIDRs.')
+ if bool(settings.get('trust_x_forwarded_for')) and not (
+ settings.get('xff_trusted_proxies') or []
+ ):
+ gateway_logger.warning(
+ 'Security: trust_x_forwarded_for enabled but xff_trusted_proxies is empty; header spoofing risk. Configure trusted proxy IPs/CIDRs.'
+ )
if os.getenv('ENV', '').lower() == 'production':
raise RuntimeError(
@@ -364,7 +388,7 @@ async def app_lifespan(app: FastAPI):
gateway_logger.debug(f'Startup security checks skipped: {e}')
try:
- spec = app.openapi()
+ app.openapi()
problems = []
for route in app.routes:
path = getattr(route, 'path', '')
@@ -390,7 +414,9 @@ async def app_lifespan(app: FastAPI):
latest_path = find_latest_dump_path(hint)
if latest_path and os.path.exists(latest_path):
info = restore_memory_from_file(latest_path)
- gateway_logger.info(f"Memory mode: restored from dump {latest_path} (created_at={info.get('created_at')})")
+ gateway_logger.info(
+ f'Memory mode: restored from dump {latest_path} (created_at={info.get("created_at")})'
+ )
else:
gateway_logger.info('Memory mode: no existing dump found to restore')
except Exception as e:
@@ -406,7 +432,9 @@ async def app_lifespan(app: FastAPI):
gateway_logger.info('SIGUSR1 ignored: not in memory-only mode')
return
if not os.getenv('MEM_ENCRYPTION_KEY'):
- gateway_logger.error('SIGUSR1 dump skipped: MEM_ENCRYPTION_KEY not configured')
+ gateway_logger.error(
+ 'SIGUSR1 dump skipped: MEM_ENCRYPTION_KEY not configured'
+ )
return
settings = get_cached_settings()
path_hint = settings.get('dump_path')
@@ -418,7 +446,6 @@ async def app_lifespan(app: FastAPI):
loop.add_signal_handler(signal.SIGUSR1, lambda: asyncio.create_task(_sigusr1_dump()))
gateway_logger.info('SIGUSR1 handler registered for on-demand memory dumps')
except NotImplementedError:
-
pass
try:
@@ -451,9 +478,9 @@ async def app_lifespan(app: FastAPI):
try:
yield
finally:
- gateway_logger.info("Starting graceful shutdown...")
+ gateway_logger.info('Starting graceful shutdown...')
app.state.shutting_down = True
- gateway_logger.info("Waiting for in-flight requests to complete (5s grace period)...")
+ gateway_logger.info('Waiting for in-flight requests to complete (5s grace period)...')
await asyncio.sleep(5)
try:
await stop_auto_save_task()
@@ -474,19 +501,21 @@ async def app_lifespan(app: FastAPI):
except Exception:
pass
try:
- gateway_logger.info("Closing database connections...")
+ gateway_logger.info('Closing database connections...')
from utils.database import close_database_connections
+
close_database_connections()
except Exception as e:
- gateway_logger.error(f"Error closing database connections: {e}")
+ gateway_logger.error(f'Error closing database connections: {e}')
try:
- gateway_logger.info("Closing HTTP clients...")
+ gateway_logger.info('Closing HTTP clients...')
from services.gateway_service import GatewayService
+
if hasattr(GatewayService, '_http_client') and GatewayService._http_client:
await GatewayService._http_client.aclose()
- gateway_logger.info("HTTP client closed")
+ gateway_logger.info('HTTP client closed')
except Exception as e:
- gateway_logger.error(f"Error closing HTTP client: {e}")
+ gateway_logger.error(f'Error closing HTTP client: {e}')
try:
METRICS_FILE = os.path.join(LOGS_DIR, 'metrics.json')
@@ -494,7 +523,7 @@ async def app_lifespan(app: FastAPI):
except Exception:
pass
- gateway_logger.info("Graceful shutdown complete")
+ gateway_logger.info('Graceful shutdown complete')
try:
t = getattr(app.state, '_metrics_save_task', None)
if t:
@@ -504,9 +533,11 @@ async def app_lifespan(app: FastAPI):
try:
from services.gateway_service import GatewayService as _GS
+
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() != 'false':
try:
import asyncio as _asyncio
+
if _asyncio.iscoroutinefunction(_GS.aclose_http_client):
await _GS.aclose_http_client()
except Exception:
@@ -514,15 +545,17 @@ async def app_lifespan(app: FastAPI):
except Exception:
pass
-def _generate_unique_id(route):
+
+def _generate_unique_id(route) -> str:
try:
name = getattr(route, 'name', 'op') or 'op'
path = getattr(route, 'path', '').replace('/', '_').replace('{', '').replace('}', '')
methods = '_'.join(sorted(list(getattr(route, 'methods', []) or [])))
- return f"{name}_{methods}_{path}".lower()
+ return f'{name}_{methods}_{path}'.lower()
except Exception:
return (getattr(route, 'name', 'op') or 'op').lower()
+
doorman = FastAPI(
title='doorman',
description="A lightweight API gateway for AI, REST, SOAP, GraphQL, gRPC, and WebSocket APIs — fully managed with built-in RESTful APIs for configuration and control. This is your application's gateway to the world.",
@@ -534,13 +567,13 @@ doorman = FastAPI(
# Add CORS middleware
# Starlette CORS middleware is disabled by default because platform and per-API
# CORS are enforced explicitly below. Enable only if requested via env.
-if os.getenv('ENABLE_STARLETTE_CORS', 'false').lower() in ('1','true','yes','on'):
+if os.getenv('ENABLE_STARLETTE_CORS', 'false').lower() in ('1', 'true', 'yes', 'on'):
doorman.add_middleware(
CORSMiddleware,
- allow_origins=["*"], # In production, replace with specific origins
+ allow_origins=['*'], # In production, replace with specific origins
allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"], # This will include X-Requested-With
+ allow_methods=['*'],
+ allow_headers=['*'], # This will include X-Requested-With
expose_headers=[],
max_age=600,
)
@@ -550,7 +583,8 @@ domain = os.getenv('COOKIE_DOMAIN', 'localhost')
# - API gateway routes (/api/*): CORS controlled per-API in gateway routes/services
-def _platform_cors_config():
+
+def _platform_cors_config() -> dict:
"""Compute platform CORS config from environment.
Env vars:
@@ -561,16 +595,35 @@ def _platform_cors_config():
- CORS_STRICT: true/false (when true, do not echo wildcard origins with credentials)
"""
import os as _os
- strict = _os.getenv('CORS_STRICT', 'false').lower() in ('1','true','yes','on')
- allowed_origins = [o.strip() for o in (_os.getenv('ALLOWED_ORIGINS') or '').split(',') if o.strip()] or ['*']
- allow_methods = [m.strip() for m in (_os.getenv('ALLOW_METHODS') or 'GET,POST,PUT,DELETE,OPTIONS,PATCH,HEAD').split(',') if m.strip()]
+
+ strict = _os.getenv('CORS_STRICT', 'false').lower() in ('1', 'true', 'yes', 'on')
+ allowed_origins = [
+ o.strip() for o in (_os.getenv('ALLOWED_ORIGINS') or '').split(',') if o.strip()
+ ] or ['*']
+ allow_methods = [
+ m.strip()
+ for m in (_os.getenv('ALLOW_METHODS') or 'GET,POST,PUT,DELETE,OPTIONS,PATCH,HEAD').split(
+ ','
+ )
+ if m.strip()
+ ]
allow_headers_env = _os.getenv('ALLOW_HEADERS') or ''
if allow_headers_env.strip() == '*':
# Default to a known, minimal safe list when wildcard requested
- allow_headers = ['Accept','Content-Type','X-CSRF-Token','Authorization']
+ allow_headers = ['Accept', 'Content-Type', 'X-CSRF-Token', 'Authorization']
else:
- allow_headers = [h.strip() for h in allow_headers_env.split(',') if h.strip()] or ['Accept','Content-Type','X-CSRF-Token','Authorization']
- allow_credentials = _os.getenv('ALLOW_CREDENTIALS', 'false').lower() in ('1','true','yes','on')
+ allow_headers = [h.strip() for h in allow_headers_env.split(',') if h.strip()] or [
+ 'Accept',
+ 'Content-Type',
+ 'X-CSRF-Token',
+ 'Authorization',
+ ]
+ allow_credentials = _os.getenv('ALLOW_CREDENTIALS', 'false').lower() in (
+ '1',
+ 'true',
+ 'yes',
+ 'on',
+ )
return {
'strict': strict,
@@ -596,10 +649,10 @@ async def platform_cors(request: Request, call_next):
try:
lo = origin.lower()
origin_allowed = (
- lo.startswith('http://localhost') or
- lo.startswith('https://localhost') or
- lo.startswith('http://127.0.0.1') or
- lo.startswith('https://127.0.0.1')
+ lo.startswith('http://localhost')
+ or lo.startswith('https://localhost')
+ or lo.startswith('http://127.0.0.1')
+ or lo.startswith('https://127.0.0.1')
)
except Exception:
origin_allowed = False
@@ -610,6 +663,7 @@ async def platform_cors(request: Request, call_next):
if request.method.upper() == 'OPTIONS':
from fastapi.responses import Response as _Resp
+
headers = {}
if origin_allowed:
headers['Access-Control-Allow-Origin'] = origin
@@ -641,8 +695,10 @@ async def platform_cors(request: Request, call_next):
pass
return await call_next(request)
+
MAX_BODY_SIZE = int(os.getenv('MAX_BODY_SIZE_BYTES', 1_048_576))
+
def _get_max_body_size() -> int:
try:
v = os.getenv('MAX_BODY_SIZE_BYTES')
@@ -652,6 +708,7 @@ def _get_max_body_size() -> int:
except Exception:
return MAX_BODY_SIZE
+
class LimitedStreamReader:
"""
Wrapper around ASGI receive channel that enforces size limits on chunked requests.
@@ -659,6 +716,7 @@ class LimitedStreamReader:
Prevents Transfer-Encoding: chunked bypass by tracking accumulated size
and rejecting streams that exceed the limit.
"""
+
def __init__(self, receive, max_size: int):
self.receive = receive
self.max_size = max_size
@@ -681,6 +739,7 @@ class LimitedStreamReader:
return message
+
@doorman.middleware('http')
async def body_size_limit(request: Request, call_next):
"""Enforce request body size limits to prevent DoS attacks.
@@ -700,7 +759,7 @@ async def body_size_limit(request: Request, call_next):
- /api/grpc/*: Enforce on gRPC JSON payloads
"""
try:
- if os.getenv('DISABLE_BODY_SIZE_LIMIT', 'false').lower() in ('1','true','yes','on'):
+ if os.getenv('DISABLE_BODY_SIZE_LIMIT', 'false').lower() in ('1', 'true', 'yes', 'on'):
return await call_next(request)
path = str(request.url.path)
@@ -708,7 +767,9 @@ async def body_size_limit(request: Request, call_next):
raw_excludes = os.getenv('BODY_LIMIT_EXCLUDE_PATHS', '')
if raw_excludes:
excludes = [p.strip() for p in raw_excludes.split(',') if p.strip()]
- if any(path == p or (p.endswith('*') and path.startswith(p[:-1])) for p in excludes):
+ if any(
+ path == p or (p.endswith('*') and path.startswith(p[:-1])) for p in excludes
+ ):
return await call_next(request)
except Exception:
pass
@@ -725,10 +786,13 @@ async def body_size_limit(request: Request, call_next):
try:
from models.response_model import ResponseModel as _RM
from utils.response_util import process_response as _pr
- return _pr(_RM(
- status_code=200,
- message='Settings updated (middleware bypass)'
- ).dict(), 'rest')
+
+ return _pr(
+ _RM(
+ status_code=200, message='Settings updated (middleware bypass)'
+ ).dict(),
+ 'rest',
+ )
except Exception:
pass
raise
@@ -768,6 +832,7 @@ async def body_size_limit(request: Request, call_next):
if content_length > limit:
try:
from utils.audit_util import audit
+
audit(
request,
actor=None,
@@ -778,17 +843,20 @@ async def body_size_limit(request: Request, call_next):
'content_length': content_length,
'limit': limit,
'content_type': request.headers.get('content-type'),
- 'transfer_encoding': transfer_encoding or None
- }
+ 'transfer_encoding': transfer_encoding or None,
+ },
)
except Exception:
pass
- return process_response(ResponseModel(
- status_code=413,
- error_code='REQ001',
- error_message=f'Request entity too large (max: {limit} bytes)'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=413,
+ error_code='REQ001',
+ error_message=f'Request entity too large (max: {limit} bytes)',
+ ).dict(),
+ 'rest',
+ )
except (ValueError, TypeError):
pass
@@ -798,7 +866,7 @@ async def body_size_limit(request: Request, call_next):
try:
env_flag = os.getenv('DISABLE_PLATFORM_CHUNKED_WRAP')
if isinstance(env_flag, str) and env_flag.strip() != '':
- if env_flag.strip().lower() in ('1','true','yes','on'):
+ if env_flag.strip().lower() in ('1', 'true', 'yes', 'on'):
wrap_allowed = False
if str(path) == '/platform/authorization':
wrap_allowed = True
@@ -814,9 +882,12 @@ async def body_size_limit(request: Request, call_next):
response = await call_next(request)
try:
- if wrap_allowed and (limited_reader.over_limit or limited_reader.bytes_received > limit):
+ if wrap_allowed and (
+ limited_reader.over_limit or limited_reader.bytes_received > limit
+ ):
try:
from utils.audit_util import audit
+
audit(
request,
actor=None,
@@ -827,29 +898,37 @@ async def body_size_limit(request: Request, call_next):
'bytes_received': limited_reader.bytes_received,
'limit': limit,
'content_type': request.headers.get('content-type'),
- 'transfer_encoding': transfer_encoding or 'chunked'
- }
+ 'transfer_encoding': transfer_encoding or 'chunked',
+ },
)
except Exception:
pass
- return process_response(ResponseModel(
- status_code=413,
- error_code='REQ001',
- error_message=f'Request entity too large (max: {limit} bytes)'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=413,
+ error_code='REQ001',
+ error_message=f'Request entity too large (max: {limit} bytes)',
+ ).dict(),
+ 'rest',
+ )
except Exception:
pass
return response
- except Exception as e:
+ except Exception:
try:
- if wrap_allowed and (limited_reader.over_limit or limited_reader.bytes_received > limit):
- return process_response(ResponseModel(
- status_code=413,
- error_code='REQ001',
- error_message=f'Request entity too large (max: {limit} bytes)'
- ).dict(), 'rest')
+ if wrap_allowed and (
+ limited_reader.over_limit or limited_reader.bytes_received > limit
+ ):
+ return process_response(
+ ResponseModel(
+ status_code=413,
+ error_code='REQ001',
+ error_message=f'Request entity too large (max: {limit} bytes)',
+ ).dict(),
+ 'rest',
+ )
except Exception:
pass
raise
@@ -873,6 +952,7 @@ async def body_size_limit(request: Request, call_next):
else:
try:
import anyio
+
if isinstance(e, getattr(anyio, 'EndOfStream', tuple())):
swallow = True
except Exception:
@@ -882,16 +962,20 @@ async def body_size_limit(request: Request, call_next):
if swallow and _RM and _pr:
try:
- return _pr(_RM(
- status_code=500,
- error_code='GTW998',
- error_message='Upstream handler failed to produce a response'
- ).dict(), 'rest')
+ return _pr(
+ _RM(
+ status_code=500,
+ error_code='GTW998',
+ error_message='Upstream handler failed to produce a response',
+ ).dict(),
+ 'rest',
+ )
except Exception:
pass
raise
+
class PlatformCORSMiddleware:
"""ASGI-level CORS for /platform/* routes only.
@@ -899,13 +983,19 @@ class PlatformCORSMiddleware:
interfere with /api/* paths. It also respects DISABLE_PLATFORM_CORS_ASGI:
when set to true, this middleware becomes a no-op.
"""
+
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
# If explicitly disabled, act as a passthrough
try:
- if os.getenv('DISABLE_PLATFORM_CORS_ASGI', 'false').lower() in ('1','true','yes','on'):
+ if os.getenv('DISABLE_PLATFORM_CORS_ASGI', 'false').lower() in (
+ '1',
+ 'true',
+ 'yes',
+ 'on',
+ ):
return await self.app(scope, receive, send)
except Exception:
pass
@@ -921,7 +1011,7 @@ class PlatformCORSMiddleware:
cfg = _platform_cors_config()
hdrs = {}
try:
- for k, v in (scope.get('headers') or []):
+ for k, v in scope.get('headers') or []:
hdrs[k.decode('latin1').lower()] = v.decode('latin1')
except Exception:
pass
@@ -934,10 +1024,10 @@ class PlatformCORSMiddleware:
if cfg['strict'] and cfg['credentials']:
lo = origin.lower()
origin_allowed = (
- lo.startswith('http://localhost') or
- lo.startswith('https://localhost') or
- lo.startswith('http://127.0.0.1') or
- lo.startswith('https://127.0.0.1')
+ lo.startswith('http://localhost')
+ or lo.startswith('https://localhost')
+ or lo.startswith('http://127.0.0.1')
+ or lo.startswith('https://127.0.0.1')
)
else:
origin_allowed = True
@@ -949,8 +1039,12 @@ class PlatformCORSMiddleware:
if origin_allowed and origin:
headers.append((b'access-control-allow-origin', origin.encode('latin1')))
headers.append((b'vary', b'Origin'))
- headers.append((b'access-control-allow-methods', ', '.join(cfg['methods']).encode('latin1')))
- headers.append((b'access-control-allow-headers', ', '.join(cfg['headers']).encode('latin1')))
+ headers.append(
+ (b'access-control-allow-methods', ', '.join(cfg['methods']).encode('latin1'))
+ )
+ headers.append(
+ (b'access-control-allow-headers', ', '.join(cfg['headers']).encode('latin1'))
+ )
if cfg['credentials']:
headers.append((b'access-control-allow-credentials', b'true'))
rid = hdrs.get('x-request-id')
@@ -969,15 +1063,20 @@ class PlatformCORSMiddleware:
except Exception:
return await self.app(scope, receive, send)
+
doorman.add_middleware(PlatformCORSMiddleware)
# Add tier-based rate limiting middleware
try:
from middleware.tier_rate_limit_middleware import TierRateLimitMiddleware
+
doorman.add_middleware(TierRateLimitMiddleware)
- logging.getLogger('doorman.gateway').info("Tier-based rate limiting middleware enabled")
+ logging.getLogger('doorman.gateway').info('Tier-based rate limiting middleware enabled')
except Exception as e:
- logging.getLogger('doorman.gateway').warning(f"Failed to enable tier rate limiting middleware: {e}")
+ logging.getLogger('doorman.gateway').warning(
+ f'Failed to enable tier rate limiting middleware: {e}'
+ )
+
@doorman.middleware('http')
async def request_id_middleware(request: Request, call_next):
@@ -997,6 +1096,7 @@ async def request_id_middleware(request: Request, call_next):
try:
from utils.correlation_util import set_correlation_id
+
set_correlation_id(rid)
except Exception:
pass
@@ -1005,7 +1105,9 @@ async def request_id_middleware(request: Request, call_next):
trust_xff = bool(settings.get('trust_x_forwarded_for'))
direct_ip = getattr(getattr(request, 'client', None), 'host', None)
effective_ip = _policy_get_client_ip(request, trust_xff)
- gateway_logger.info(f"{rid} | Entry: client_ip={direct_ip} effective_ip={effective_ip} method={request.method} path={str(request.url.path)}")
+ gateway_logger.info(
+ f'{rid} | Entry: client_ip={direct_ip} effective_ip={effective_ip} method={request.method} path={str(request.url.path)}'
+ )
except Exception:
pass
response = await call_next(request)
@@ -1019,6 +1121,7 @@ async def request_id_middleware(request: Request, call_next):
gateway_logger.error(f'Request ID middleware error: {str(e)}', exc_info=True)
raise
+
@doorman.middleware('http')
async def security_headers(request: Request, call_next):
response = await call_next(request)
@@ -1026,29 +1129,33 @@ async def security_headers(request: Request, call_next):
response.headers.setdefault('X-Content-Type-Options', 'nosniff')
response.headers.setdefault('X-Frame-Options', 'DENY')
response.headers.setdefault('Referrer-Policy', 'no-referrer')
- response.headers.setdefault('Permissions-Policy', 'geolocation=(), microphone=(), camera=()')
+ response.headers.setdefault(
+ 'Permissions-Policy', 'geolocation=(), microphone=(), camera=()'
+ )
try:
csp = os.getenv('CONTENT_SECURITY_POLICY')
if csp is None or not csp.strip():
-
- csp =\
- "default-src 'none'; "\
- "frame-ancestors 'none'; "\
- "base-uri 'none'; "\
- "form-action 'self'; "\
- "img-src 'self' data:; "\
+ csp = (
+ "default-src 'none'; "
+ "frame-ancestors 'none'; "
+ "base-uri 'none'; "
+ "form-action 'self'; "
+ "img-src 'self' data:; "
"connect-src 'self';"
+ )
response.headers.setdefault('Content-Security-Policy', csp)
except Exception:
pass
if os.getenv('HTTPS_ONLY', 'false').lower() == 'true':
-
- response.headers.setdefault('Strict-Transport-Security', 'max-age=15552000; includeSubDomains; preload')
+ response.headers.setdefault(
+ 'Strict-Transport-Security', 'max-age=15552000; includeSubDomains; preload'
+ )
except Exception:
pass
return response
+
"""Logging configuration
Prefer file logging to LOGS_DIR/doorman.log when writable; otherwise, fall back
@@ -1058,7 +1165,10 @@ Respects LOG_FORMAT=json|plain.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_env_logs_dir = os.getenv('LOGS_DIR')
-LOGS_DIR = os.path.abspath(_env_logs_dir) if _env_logs_dir else os.path.join(BASE_DIR, 'platform-logs')
+LOGS_DIR = (
+ os.path.abspath(_env_logs_dir) if _env_logs_dir else os.path.join(BASE_DIR, 'platform-logs')
+)
+
# Build formatters
class JSONFormatter(logging.Formatter):
@@ -1074,6 +1184,7 @@ class JSONFormatter(logging.Formatter):
except Exception:
return f'{payload}'
+
_fmt_is_json = os.getenv('LOG_FORMAT', 'plain').lower() == 'json'
_file_handler = None
try:
@@ -1082,21 +1193,29 @@ try:
filename=os.path.join(LOGS_DIR, 'doorman.log'),
maxBytes=10 * 1024 * 1024,
backupCount=5,
- encoding='utf-8'
+ encoding='utf-8',
+ )
+ _file_handler.setFormatter(
+ JSONFormatter()
+ if _fmt_is_json
+ else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
- _file_handler.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
except Exception as _e:
- logging.getLogger('doorman.gateway').warning(f'File logging disabled ({_e}); using console logging only')
+ logging.getLogger('doorman.gateway').warning(
+ f'File logging disabled ({_e}); using console logging only'
+ )
_file_handler = None
+
# Configure all doorman loggers to use the same handler and prevent propagation
-def configure_logger(logger_name):
+def configure_logger(logger_name: str) -> logging.Logger:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
logger.propagate = False
for handler in logger.handlers[:]:
logger.removeHandler(handler)
+
class RedactFilter(logging.Filter):
"""Comprehensive logging redaction filter for sensitive data.
@@ -1111,30 +1230,25 @@ def configure_logger(logger_name):
PATTERNS = [
re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
-
re.compile(r'(?i)(x-api-key\s*[:=]\s*)([^;\r\n]+)'),
re.compile(r'(?i)(api[_-]?key\s*[:=]\s*)([^;\r\n]+)'),
re.compile(r'(?i)(api[_-]?secret\s*[:=]\s*)([^;\r\n]+)'),
-
re.compile(r'(?i)(access[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
re.compile(r'(?i)(refresh[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
re.compile(r'(?i)(token\s*["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9_\-\.]{20,})(["\']?)'),
-
re.compile(r'(?i)(password\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n]+)(["\']?)'),
re.compile(r'(?i)(secret\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
re.compile(r'(?i)(client[_-]?secret\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
-
re.compile(r'(?i)(cookie\s*[:=]\s*)([^;\r\n]+)'),
re.compile(r'(?i)(set-cookie\s*[:=]\s*)([^;\r\n]+)'),
-
re.compile(r'(?i)(x-csrf-token\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
re.compile(r'(?i)(csrf[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
-
re.compile(r'\b(eyJ[a-zA-Z0-9_\-]+\.eyJ[a-zA-Z0-9_\-]+\.[a-zA-Z0-9_\-]+)\b'),
-
re.compile(r'(?i)(session[_-]?id\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
-
- re.compile(r'(-----BEGIN[A-Z\s]+PRIVATE KEY-----)(.*?)(-----END[A-Z\s]+PRIVATE KEY-----)', re.DOTALL),
+ re.compile(
+ r'(-----BEGIN[A-Z\s]+PRIVATE KEY-----)(.*?)(-----END[A-Z\s]+PRIVATE KEY-----)',
+ re.DOTALL,
+ ),
]
def filter(self, record: logging.LogRecord) -> bool:
@@ -1146,11 +1260,14 @@ def configure_logger(logger_name):
if pat.groups == 3 and pat.flags & re.DOTALL:
red = pat.sub(r'\1[REDACTED]\3', red)
elif pat.groups >= 2:
- red = pat.sub(lambda m: (
- m.group(1) +
- '[REDACTED]' +
- (m.group(3) if m.lastindex and m.lastindex >= 3 else '')
- ), red)
+ red = pat.sub(
+ lambda m: (
+ m.group(1)
+ + '[REDACTED]'
+ + (m.group(3) if m.lastindex and m.lastindex >= 3 else '')
+ ),
+ red,
+ )
else:
red = pat.sub('[REDACTED]', red)
@@ -1159,7 +1276,15 @@ def configure_logger(logger_name):
if hasattr(record, 'args') and record.args:
try:
if isinstance(record.args, dict):
- record.args = {k: '[REDACTED]' if 'token' in str(k).lower() or 'password' in str(k).lower() or 'secret' in str(k).lower() or 'authorization' in str(k).lower() else v for k, v in record.args.items()}
+ record.args = {
+ k: '[REDACTED]'
+ if 'token' in str(k).lower()
+ or 'password' in str(k).lower()
+ or 'secret' in str(k).lower()
+ or 'authorization' in str(k).lower()
+ else v
+ for k, v in record.args.items()
+ }
except Exception:
pass
except Exception:
@@ -1168,17 +1293,23 @@ def configure_logger(logger_name):
console = logging.StreamHandler(stream=sys.stdout)
console.setLevel(logging.INFO)
- console.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+ console.setFormatter(
+ JSONFormatter()
+ if _fmt_is_json
+ else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ )
console.addFilter(RedactFilter())
logger.addHandler(console)
if _file_handler is not None:
-
- if not any(isinstance(f, logging.Filter) and hasattr(f, 'PATTERNS') for f in _file_handler.filters):
+ if not any(
+ isinstance(f, logging.Filter) and hasattr(f, 'PATTERNS') for f in _file_handler.filters
+ ):
_file_handler.addFilter(RedactFilter())
logger.addHandler(_file_handler)
return logger
+
gateway_logger = configure_logger('doorman.gateway')
logging_logger = configure_logger('doorman.logging')
@@ -1198,9 +1329,7 @@ try:
compression_level = 1
doorman.add_middleware(
- GZipMiddleware,
- minimum_size=compression_minimum_size,
- compresslevel=compression_level
+ GZipMiddleware, minimum_size=compression_minimum_size, compresslevel=compression_level
)
gateway_logger.info(
f'Response compression enabled: level={compression_level}, '
@@ -1211,6 +1340,7 @@ try:
except Exception as e:
gateway_logger.warning(f'Failed to configure response compression: {e}. Compression disabled.')
+
# Ensure platform responses set Vary=Origin (and not Accept-Encoding) for CORS tests.
class _VaryOriginMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
@@ -1228,6 +1358,7 @@ class _VaryOriginMiddleware(BaseHTTPMiddleware):
pass
return response
+
doorman.add_middleware(_VaryOriginMiddleware)
# Now that logging is configured, attempt to migrate any legacy 'generated/' dir
@@ -1248,9 +1379,13 @@ try:
filename=os.path.join(LOGS_DIR, 'doorman-trail.log'),
maxBytes=10 * 1024 * 1024,
backupCount=5,
- encoding='utf-8'
+ encoding='utf-8',
+ )
+ _audit_file.setFormatter(
+ JSONFormatter()
+ if _fmt_is_json
+ else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
- _audit_file.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
try:
for eh in gateway_logger.handlers:
for f in getattr(eh, 'filters', []):
@@ -1259,10 +1394,13 @@ try:
pass
audit_logger.addHandler(_audit_file)
except Exception as _e:
-
console = logging.StreamHandler(stream=sys.stdout)
console.setLevel(logging.INFO)
- console.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+ console.setFormatter(
+ JSONFormatter()
+ if _fmt_is_json
+ else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ )
try:
for eh in gateway_logger.handlers:
for f in getattr(eh, 'filters', []):
@@ -1271,12 +1409,18 @@ except Exception as _e:
pass
audit_logger.addHandler(console)
+
class Settings(BaseSettings):
mongo_db_uri: str = os.getenv('MONGO_DB_URI')
jwt_secret_key: str = os.getenv('JWT_SECRET_KEY')
jwt_algorithm: str = 'HS256'
- jwt_access_token_expires: timedelta = timedelta(minutes=int(os.getenv('ACCESS_TOKEN_EXPIRES_MINUTES', 15)))
- jwt_refresh_token_expires: timedelta = timedelta(days=int(os.getenv('REFRESH_TOKEN_EXPIRES_DAYS', 30)))
+ jwt_access_token_expires: timedelta = timedelta(
+ minutes=int(os.getenv('ACCESS_TOKEN_EXPIRES_MINUTES', 15))
+ )
+ jwt_refresh_token_expires: timedelta = timedelta(
+ days=int(os.getenv('REFRESH_TOKEN_EXPIRES_DAYS', 30))
+ )
+
@doorman.middleware('http')
async def ip_filter_middleware(request: Request, call_next):
@@ -1294,12 +1438,29 @@ async def ip_filter_middleware(request: Request, call_next):
try:
import os
+
settings = get_cached_settings()
env_flag = os.getenv('LOCAL_HOST_IP_BYPASS')
- allow_local = (env_flag.lower() == 'true') if isinstance(env_flag, str) and env_flag.strip() != '' else bool(settings.get('allow_localhost_bypass'))
+ allow_local = (
+ (env_flag.lower() == 'true')
+ if isinstance(env_flag, str) and env_flag.strip() != ''
+ else bool(settings.get('allow_localhost_bypass'))
+ )
if allow_local:
direct_ip = getattr(getattr(request, 'client', None), 'host', None)
- has_forward = any(request.headers.get(h) for h in ('x-forwarded-for','X-Forwarded-For','x-real-ip','X-Real-IP','cf-connecting-ip','CF-Connecting-IP','forwarded','Forwarded'))
+ has_forward = any(
+ request.headers.get(h)
+ for h in (
+ 'x-forwarded-for',
+ 'X-Forwarded-For',
+ 'x-real-ip',
+ 'X-Real-IP',
+ 'cf-connecting-ip',
+ 'CF-Connecting-IP',
+ 'forwarded',
+ 'Forwarded',
+ )
+ )
if direct_ip and _policy_is_loopback(direct_ip) and not has_forward:
return await call_next(request)
except Exception:
@@ -1308,37 +1469,77 @@ async def ip_filter_middleware(request: Request, call_next):
if client_ip:
if wl and not _policy_ip_in_list(client_ip, wl):
try:
- audit(request, actor=None, action='ip.global_deny', target=client_ip, status='blocked', details={'reason': 'not_in_whitelist', 'xff': xff_hdr, 'source_ip': getattr(getattr(request, 'client', None), 'host', None)})
+ audit(
+ request,
+ actor=None,
+ action='ip.global_deny',
+ target=client_ip,
+ status='blocked',
+ details={
+ 'reason': 'not_in_whitelist',
+ 'xff': xff_hdr,
+ 'source_ip': getattr(getattr(request, 'client', None), 'host', None),
+ },
+ )
except Exception:
pass
from fastapi.responses import JSONResponse
- return JSONResponse(status_code=403, content={'status_code': 403, 'error_code': 'SEC010', 'error_message': 'IP not allowed'})
+
+ return JSONResponse(
+ status_code=403,
+ content={
+ 'status_code': 403,
+ 'error_code': 'SEC010',
+ 'error_message': 'IP not allowed',
+ },
+ )
if bl and _policy_ip_in_list(client_ip, bl):
try:
- audit(request, actor=None, action='ip.global_deny', target=client_ip, status='blocked', details={'reason': 'blacklisted', 'xff': xff_hdr, 'source_ip': getattr(getattr(request, 'client', None), 'host', None)})
+ audit(
+ request,
+ actor=None,
+ action='ip.global_deny',
+ target=client_ip,
+ status='blocked',
+ details={
+ 'reason': 'blacklisted',
+ 'xff': xff_hdr,
+ 'source_ip': getattr(getattr(request, 'client', None), 'host', None),
+ },
+ )
except Exception:
pass
from fastapi.responses import JSONResponse
- return JSONResponse(status_code=403, content={'status_code': 403, 'error_code': 'SEC011', 'error_message': 'IP blocked'})
+
+ return JSONResponse(
+ status_code=403,
+ content={
+ 'status_code': 403,
+ 'error_code': 'SEC011',
+ 'error_message': 'IP blocked',
+ },
+ )
except Exception:
pass
return await call_next(request)
+
@doorman.middleware('http')
async def metrics_middleware(request: Request, call_next):
start = asyncio.get_event_loop().time()
+
def _parse_len(val: str | None) -> int:
try:
return int(val) if val is not None else 0
except Exception:
return 0
+
bytes_in = _parse_len(request.headers.get('content-length'))
response = None
try:
response = await call_next(request)
return response
finally:
-
try:
if str(request.url.path).startswith('/api/'):
end = asyncio.get_event_loop().time()
@@ -1349,6 +1550,7 @@ async def metrics_middleware(request: Request, call_next):
try:
from utils.auth_util import auth_required as _auth_required
+
payload = await _auth_required(request)
username = payload.get('sub') if isinstance(payload, dict) else None
except Exception:
@@ -1359,11 +1561,14 @@ async def metrics_middleware(request: Request, call_next):
parts = p.split('/')
try:
idx = parts.index('rest')
- api_key = f'rest:{parts[idx+1]}' if len(parts) > idx+1 and parts[idx+1] else 'rest:unknown'
+ api_key = (
+ f'rest:{parts[idx + 1]}'
+ if len(parts) > idx + 1 and parts[idx + 1]
+ else 'rest:unknown'
+ )
except ValueError:
api_key = 'rest:unknown'
elif p.startswith('/api/graphql/'):
-
seg = p.rsplit('/', 1)[-1] or 'unknown'
api_key = f'graphql:{seg}'
elif p.startswith('/api/soap/'):
@@ -1383,13 +1588,29 @@ async def metrics_middleware(request: Request, call_next):
except Exception:
clen = 0
- metrics_store.record(status=status, duration_ms=duration_ms, username=username, api_key=api_key, bytes_in=bytes_in, bytes_out=clen)
+ metrics_store.record(
+ status=status,
+ duration_ms=duration_ms,
+ username=username,
+ api_key=api_key,
+ bytes_in=bytes_in,
+ bytes_out=clen,
+ )
try:
if username:
- from utils.bandwidth_util import add_usage, _get_user
+ from utils.bandwidth_util import _get_user, add_usage
+
u = _get_user(username)
- if u and u.get('bandwidth_limit_bytes') and u.get('bandwidth_limit_enabled') is not False:
- add_usage(username, int(bytes_in) + int(clen), u.get('bandwidth_limit_window') or 'day')
+ if (
+ u
+ and u.get('bandwidth_limit_bytes')
+ and u.get('bandwidth_limit_enabled') is not False
+ ):
+ add_usage(
+ username,
+ int(bytes_in) + int(clen),
+ u.get('bandwidth_limit_window') or 'day',
+ )
except Exception:
pass
try:
@@ -1405,45 +1626,51 @@ async def metrics_middleware(request: Request, call_next):
except Exception:
pass
except Exception:
-
pass
+
async def automatic_purger(interval_seconds):
while True:
await asyncio.sleep(interval_seconds)
await purge_expired_tokens()
gateway_logger.info('Expired JWTs purged from blacklist.')
+
@doorman.exception_handler(JWTError)
async def jwt_exception_handler(exc: JWTError):
- return process_response(ResponseModel(
- status_code=401,
- error_code='JWT001',
- error_message='Invalid token'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(status_code=401, error_code='JWT001', error_message='Invalid token').dict(),
+ 'rest',
+ )
+
@doorman.exception_handler(500)
async def internal_server_error_handler(request: Request, exc: Exception):
- return process_response(ResponseModel(
- status_code=500,
- error_code='ISE001',
- error_message='Internal Server Error'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='ISE001', error_message='Internal Server Error'
+ ).dict(),
+ 'rest',
+ )
+
@doorman.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
# DEBUG: Log validation errors
import logging
+
log = logging.getLogger('doorman.gateway')
- log.error(f"Validation error on {request.method} {request.url.path}")
- log.error(f"Validation errors: {exc.errors()}")
- log.error(f"Request body: {await request.body()}")
-
- return process_response(ResponseModel(
- status_code=422,
- error_code='VAL001',
- error_message='Validation Error'
- ).dict(), 'rest')
+ log.error(f'Validation error on {request.method} {request.url.path}')
+ log.error(f'Validation errors: {exc.errors()}')
+ log.error(f'Request body: {await request.body()}')
+
+ return process_response(
+ ResponseModel(
+ status_code=422, error_code='VAL001', error_message='Validation Error'
+ ).dict(),
+ 'rest',
+ )
+
cache_manager.init_app(doorman)
@@ -1473,29 +1700,35 @@ doorman.include_router(tier_router, prefix='/platform/tiers', tags=['Tiers'])
doorman.include_router(rate_limit_rule_router, prefix='/platform/rate-limits', tags=['Rate Limits'])
doorman.include_router(quota_router, prefix='/platform/quota', tags=['Quota'])
-def start():
+
+def start() -> None:
if os.path.exists(PID_FILE):
gateway_logger.info('doorman is already running!')
sys.exit(0)
if os.name == 'nt':
- process = subprocess.Popen([sys.executable, __file__, 'run'],
- creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL)
+ process = subprocess.Popen(
+ [sys.executable, __file__, 'run'],
+ creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ )
else:
- process = subprocess.Popen([sys.executable, __file__, 'run'],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- preexec_fn=os.setsid)
+ process = subprocess.Popen(
+ [sys.executable, __file__, 'run'],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ preexec_fn=os.setsid,
+ )
with open(PID_FILE, 'w') as f:
f.write(str(process.pid))
gateway_logger.info(f'Starting doorman with PID {process.pid}.')
-def stop():
+
+def stop() -> None:
if not os.path.exists(PID_FILE):
gateway_logger.info('No running instance found')
return
- with open(PID_FILE, 'r') as f:
+ with open(PID_FILE) as f:
pid = int(f.read())
try:
if os.name == 'nt':
@@ -1506,7 +1739,6 @@ def stop():
deadline = time.time() + 15
while time.time() < deadline:
try:
-
os.kill(pid, 0)
time.sleep(0.5)
except ProcessLookupError:
@@ -1518,7 +1750,8 @@ def stop():
if os.path.exists(PID_FILE):
os.remove(PID_FILE)
-def restart():
+
+def restart() -> None:
"""Restart the doorman process using PID-based supervisor.
This function is intended to be invoked from a detached helper process.
"""
@@ -1533,7 +1766,8 @@ def restart():
except Exception as e:
gateway_logger.error(f'Error during start phase of restart: {e}')
-def run():
+
+def run() -> None:
server_port = int(os.getenv('PORT', 5001))
max_threads = multiprocessing.cpu_count()
env_threads = int(os.getenv('THREADS', max_threads))
@@ -1548,7 +1782,9 @@ def run():
'Set THREADS=1 for single-process memory mode or switch to MEM_OR_EXTERNAL=REDIS for multi-worker.'
)
gateway_logger.info(f'Started doorman with {num_threads} threads on port {server_port}')
- gateway_logger.info('TLS termination should be handled at reverse proxy (Nginx, Traefik, ALB, etc.)')
+ gateway_logger.info(
+ 'TLS termination should be handled at reverse proxy (Nginx, Traefik, ALB, etc.)'
+ )
uvicorn.run(
'doorman:doorman',
host='0.0.0.0',
@@ -1556,10 +1792,11 @@ def run():
reload=os.getenv('DEV_RELOAD', 'false').lower() == 'true',
reload_excludes=['venv/*', 'logs/*'],
workers=num_threads,
- log_level='info'
+ log_level='info',
)
-def main():
+
+def main() -> None:
host = os.getenv('HOST', '0.0.0.0')
port = int(os.getenv('PORT', '8000'))
try:
@@ -1567,39 +1804,55 @@ def main():
'doorman:doorman',
host=host,
port=port,
- reload=os.getenv('DEBUG', 'false').lower() == 'true'
+ reload=os.getenv('DEBUG', 'false').lower() == 'true',
)
except Exception as e:
gateway_logger.error(f'Failed to start server: {str(e)}')
raise
-def seed_command():
+
+def seed_command() -> None:
"""Run the demo seeder from command line"""
import argparse
+
from utils.demo_seed_util import run_seed
-
+
parser = argparse.ArgumentParser(description='Seed the database with demo data')
- parser.add_argument('--users', type=int, default=60, help='Number of users to create (default: 60)')
- parser.add_argument('--apis', type=int, default=20, help='Number of APIs to create (default: 20)')
- parser.add_argument('--endpoints', type=int, default=6, help='Number of endpoints per API (default: 6)')
- parser.add_argument('--groups', type=int, default=10, help='Number of groups to create (default: 10)')
- parser.add_argument('--protos', type=int, default=6, help='Number of proto files to create (default: 6)')
- parser.add_argument('--logs', type=int, default=2000, help='Number of log entries to create (default: 2000)')
- parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility (optional)')
-
+ parser.add_argument(
+ '--users', type=int, default=60, help='Number of users to create (default: 60)'
+ )
+ parser.add_argument(
+ '--apis', type=int, default=20, help='Number of APIs to create (default: 20)'
+ )
+ parser.add_argument(
+ '--endpoints', type=int, default=6, help='Number of endpoints per API (default: 6)'
+ )
+ parser.add_argument(
+ '--groups', type=int, default=10, help='Number of groups to create (default: 10)'
+ )
+ parser.add_argument(
+ '--protos', type=int, default=6, help='Number of proto files to create (default: 6)'
+ )
+ parser.add_argument(
+ '--logs', type=int, default=2000, help='Number of log entries to create (default: 2000)'
+ )
+ parser.add_argument(
+ '--seed', type=int, default=None, help='Random seed for reproducibility (optional)'
+ )
+
args = parser.parse_args(sys.argv[2:]) # Skip 'doorman.py' and 'seed'
-
- print(f"Starting demo seed with:")
- print(f" Users: {args.users}")
- print(f" APIs: {args.apis}")
- print(f" Endpoints per API: {args.endpoints}")
- print(f" Groups: {args.groups}")
- print(f" Protos: {args.protos}")
- print(f" Logs: {args.logs}")
+
+ print('Starting demo seed with:')
+ print(f' Users: {args.users}')
+ print(f' APIs: {args.apis}')
+ print(f' Endpoints per API: {args.endpoints}')
+ print(f' Groups: {args.groups}')
+ print(f' Protos: {args.protos}')
+ print(f' Logs: {args.logs}')
if args.seed is not None:
- print(f" Random Seed: {args.seed}")
+ print(f' Random Seed: {args.seed}')
print()
-
+
try:
result = run_seed(
users=args.users,
@@ -1608,16 +1861,18 @@ def seed_command():
groups=args.groups,
protos=args.protos,
logs=args.logs,
- seed=args.seed
+ seed=args.seed,
)
- print("\n✓ Seeding completed successfully!")
- print(f"Result: {result}")
+ print('\n✓ Seeding completed successfully!')
+ print(f'Result: {result}')
except Exception as e:
- print(f"\n✗ Seeding failed: {str(e)}")
+ print(f'\n✗ Seeding failed: {str(e)}')
import traceback
+
traceback.print_exc()
sys.exit(1)
+
if __name__ == '__main__':
if len(sys.argv) > 1 and sys.argv[1] == 'stop':
stop()
diff --git a/backend-services/live-tests/client.py b/backend-services/live-tests/client.py
index 1853d6e..77f5459 100644
--- a/backend-services/live-tests/client.py
+++ b/backend-services/live-tests/client.py
@@ -1,7 +1,10 @@
from __future__ import annotations
-import requests
+
from urllib.parse import urljoin
+import requests
+
+
class LiveClient:
def __init__(self, base_url: str):
self.base_url = base_url.rstrip('/') + '/'
@@ -30,7 +33,9 @@ class LiveClient:
def post(self, path: str, json=None, data=None, files=None, headers=None, **kwargs):
url = urljoin(self.base_url, path.lstrip('/'))
hdrs = self._headers_with_csrf(headers)
- return self.sess.post(url, json=json, data=data, files=files, headers=hdrs, allow_redirects=False, **kwargs)
+ return self.sess.post(
+ url, json=json, data=data, files=files, headers=hdrs, allow_redirects=False, **kwargs
+ )
def put(self, path: str, json=None, headers=None, **kwargs):
url = urljoin(self.base_url, path.lstrip('/'))
diff --git a/backend-services/live-tests/config.py b/backend-services/live-tests/config.py
index b9849c1..bd7eaba 100644
--- a/backend-services/live-tests/config.py
+++ b/backend-services/live-tests/config.py
@@ -2,15 +2,24 @@ import os
BASE_URL = os.getenv('DOORMAN_BASE_URL', 'http://localhost:3001').rstrip('/')
ADMIN_EMAIL = os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev')
+
+# Resolve admin password from environment or the repo root .env.
+# Search order:
+# 1) Environment variable DOORMAN_ADMIN_PASSWORD
+# 2) Repo root .env (two levels up from live-tests)
+# 3) Default test password
ADMIN_PASSWORD = os.getenv('DOORMAN_ADMIN_PASSWORD')
if not ADMIN_PASSWORD:
- env_file = os.path.join(os.path.dirname(__file__), '..', '.env')
- if os.path.exists(env_file):
- with open(env_file) as f:
- for line in f:
- if line.startswith('DOORMAN_ADMIN_PASSWORD='):
- ADMIN_PASSWORD = line.split('=', 1)[1].strip()
- break
+ env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
+ try:
+ if os.path.exists(env_path):
+ with open(env_path, encoding='utf-8') as f:
+ for line in f:
+ if line.startswith('DOORMAN_ADMIN_PASSWORD='):
+ ADMIN_PASSWORD = line.split('=', 1)[1].strip()
+ break
+ except Exception:
+ pass
if not ADMIN_PASSWORD:
ADMIN_PASSWORD = 'test-only-password-12chars'
@@ -18,6 +27,7 @@ ENABLE_GRAPHQL = True
ENABLE_GRPC = True
STRICT_HEALTH = True
+
def require_env():
missing = []
if not BASE_URL:
@@ -25,4 +35,4 @@ def require_env():
if not ADMIN_EMAIL:
missing.append('DOORMAN_ADMIN_EMAIL')
if missing:
- raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")
+ raise RuntimeError(f'Missing required env vars: {", ".join(missing)}')
diff --git a/backend-services/live-tests/conftest.py b/backend-services/live-tests/conftest.py
index 206f5b1..7445fd7 100644
--- a/backend-services/live-tests/conftest.py
+++ b/backend-services/live-tests/conftest.py
@@ -1,19 +1,21 @@
import os
import sys
import time
+
import pytest
-import requests
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from config import BASE_URL, ADMIN_EMAIL, ADMIN_PASSWORD, require_env, STRICT_HEALTH
from client import LiveClient
+from config import ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL, STRICT_HEALTH, require_env
+
@pytest.fixture(scope='session')
def base_url() -> str:
require_env()
return BASE_URL
+
@pytest.fixture(scope='session')
def client(base_url) -> LiveClient:
c = LiveClient(base_url)
@@ -32,7 +34,7 @@ def client(base_url) -> LiveClient:
ok = data.get('status') in ('online', 'healthy')
if ok:
break
- last_err = f"status json={j}"
+ last_err = f'status json={j}'
except Exception as e:
last_err = f'json parse error: {e}'
else:
@@ -48,6 +50,7 @@ def client(base_url) -> LiveClient:
assert 'access_token' in auth.get('response', auth), 'login did not return access_token'
return c
+
@pytest.fixture(autouse=True)
def ensure_session_and_relaxed_limits(client: LiveClient):
"""Per-test guard: ensure we're authenticated and not rate-limited.
@@ -59,23 +62,29 @@ def ensure_session_and_relaxed_limits(client: LiveClient):
r = client.get('/platform/authorization/status')
if r.status_code not in (200, 204):
from config import ADMIN_EMAIL, ADMIN_PASSWORD
+
client.login(ADMIN_EMAIL, ADMIN_PASSWORD)
except Exception:
from config import ADMIN_EMAIL, ADMIN_PASSWORD
+
client.login(ADMIN_EMAIL, ADMIN_PASSWORD)
try:
- client.put('/platform/user/admin', json={
- 'rate_limit_duration': 1000000,
- 'rate_limit_duration_type': 'second',
- 'throttle_duration': 1000000,
- 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 1000000,
- 'throttle_wait_duration': 0,
- 'throttle_wait_duration_type': 'second'
- })
+ client.put(
+ '/platform/user/admin',
+ json={
+ 'rate_limit_duration': 1000000,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 1000000,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 1000000,
+ 'throttle_wait_duration': 0,
+ 'throttle_wait_duration_type': 'second',
+ },
+ )
except Exception:
pass
+
def pytest_addoption(parser):
parser.addoption('--graph', action='store_true', default=False, help='Force GraphQL tests')
parser.addoption('--grpc', action='store_true', default=False, help='Force gRPC tests')
diff --git a/backend-services/live-tests/servers.py b/backend-services/live-tests/servers.py
index abbcc81..58fc538 100644
--- a/backend-services/live-tests/servers.py
+++ b/backend-services/live-tests/servers.py
@@ -1,9 +1,11 @@
from __future__ import annotations
-import threading
-import socket
+
import json
+import socket
+import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
+
def _find_free_port() -> int:
s = socket.socket()
s.bind(('0.0.0.0', 0))
@@ -11,6 +13,7 @@ def _find_free_port() -> int:
s.close()
return port
+
def _get_host_from_container() -> str:
"""Get the hostname to use when referring to the host machine from a Docker container.
@@ -38,6 +41,7 @@ def _get_host_from_container() -> str:
# This is the most common development setup
return '127.0.0.1'
+
class _ThreadedHTTPServer:
def __init__(self, handler_cls, host='0.0.0.0', port=None):
self.bind_host = host
@@ -61,6 +65,7 @@ class _ThreadedHTTPServer:
def url(self):
return f'http://{self.host}:{self.port}'
+
def start_rest_echo_server():
class Handler(BaseHTTPRequestHandler):
def _json(self, status=200, payload=None):
@@ -76,7 +81,7 @@ def start_rest_echo_server():
'method': 'GET',
'path': self.path,
'headers': {k: v for k, v in self.headers.items()},
- 'query': self.path.split('?', 1)[1] if '?' in self.path else ''
+ 'query': self.path.split('?', 1)[1] if '?' in self.path else '',
}
self._json(200, payload)
@@ -91,7 +96,7 @@ def start_rest_echo_server():
'method': 'POST',
'path': self.path,
'headers': {k: v for k, v in self.headers.items()},
- 'json': parsed
+ 'json': parsed,
}
self._json(200, payload)
@@ -106,7 +111,7 @@ def start_rest_echo_server():
'method': 'PUT',
'path': self.path,
'headers': {k: v for k, v in self.headers.items()},
- 'json': parsed
+ 'json': parsed,
}
self._json(200, payload)
@@ -120,6 +125,7 @@ def start_rest_echo_server():
return _ThreadedHTTPServer(Handler).start()
+
def start_soap_echo_server():
class Handler(BaseHTTPRequestHandler):
def _xml(self, status=200, content=''):
@@ -134,12 +140,11 @@ def start_soap_echo_server():
content_length = int(self.headers.get('Content-Length', '0') or '0')
_ = self.rfile.read(content_length) if content_length else b''
resp = (
- ""
- ""
- " ok"
- ""
+ ''
+ ''
+ ' ok'
+ ''
)
self._xml(200, resp)
return _ThreadedHTTPServer(Handler).start()
-
diff --git a/backend-services/live-tests/test_00_health_auth.py b/backend-services/live-tests/test_00_health_auth.py
index 6236983..5ad39bc 100644
--- a/backend-services/live-tests/test_00_health_auth.py
+++ b/backend-services/live-tests/test_00_health_auth.py
@@ -1,5 +1,6 @@
import pytest
+
def test_status_ok(client):
r = client.get('/api/health')
assert r.status_code == 200
@@ -12,6 +13,7 @@ def test_status_ok(client):
else:
assert 'error_code' in (j or {})
+
def test_auth_status_me(client):
r = client.get('/platform/authorization/status')
assert r.status_code in (200, 204)
@@ -21,5 +23,6 @@ def test_auth_status_me(client):
me = r.json().get('response', r.json())
assert me.get('username') == 'admin'
assert me.get('ui_access') is True
-import pytest
+
+
pytestmark = [pytest.mark.health, pytest.mark.auth]
diff --git a/backend-services/live-tests/test_10_user_onboarding.py b/backend-services/live-tests/test_10_user_onboarding.py
index 5ade1fa..3c6ec94 100644
--- a/backend-services/live-tests/test_10_user_onboarding.py
+++ b/backend-services/live-tests/test_10_user_onboarding.py
@@ -1,7 +1,7 @@
-import os
-import time
import random
import string
+import time
+
def _strong_password() -> str:
upp = random.choice(string.ascii_uppercase)
@@ -12,9 +12,10 @@ def _strong_password() -> str:
raw = upp + low + dig + spc + tail
return ''.join(random.sample(raw, len(raw)))
+
def test_user_onboarding_lifecycle(client):
- username = f"user_{int(time.time())}_{random.randint(1000,9999)}"
- email = f"{username}@example.com"
+ username = f'user_{int(time.time())}_{random.randint(1000, 9999)}'
+ email = f'{username}@example.com'
pwd = _strong_password()
payload = {
@@ -23,7 +24,7 @@ def test_user_onboarding_lifecycle(client):
'password': pwd,
'role': 'developer',
'groups': ['ALL'],
- 'ui_access': False
+ 'ui_access': False,
}
r = client.post('/platform/user', json=payload)
assert r.status_code in (200, 201), r.text
@@ -38,13 +39,14 @@ def test_user_onboarding_lifecycle(client):
assert r.status_code in (200, 204), r.text
new_pwd = _strong_password()
- r = client.put(f'/platform/user/{username}/update-password', json={
- 'old_password': pwd,
- 'new_password': new_pwd
- })
+ r = client.put(
+ f'/platform/user/{username}/update-password',
+ json={'old_password': pwd, 'new_password': new_pwd},
+ )
assert r.status_code in (200, 204, 400), r.text
from client import LiveClient
+
user_client = LiveClient(client.base_url)
auth = user_client.login(email, new_pwd if r.status_code in (200, 204) else pwd)
assert 'access_token' in auth.get('response', auth)
@@ -56,5 +58,8 @@ def test_user_onboarding_lifecycle(client):
r = client.delete(f'/platform/user/{username}')
assert r.status_code in (200, 204), r.text
+
+
import pytest
+
pytestmark = [pytest.mark.users, pytest.mark.auth]
diff --git a/backend-services/live-tests/test_20_credit_defs.py b/backend-services/live-tests/test_20_credit_defs.py
index af2e403..0184c0b 100644
--- a/backend-services/live-tests/test_20_credit_defs.py
+++ b/backend-services/live-tests/test_20_credit_defs.py
@@ -1,29 +1,31 @@
import time
+
def test_credit_def_create_and_assign(client):
- group = f"credits_{int(time.time())}"
+ group = f'credits_{int(time.time())}'
api_key = 'TEST_API_KEY_123456789'
payload = {
'api_credit_group': group,
'api_key': api_key,
'api_key_header': 'x-api-key',
'credit_tiers': [
- { 'tier_name': 'default', 'credits': 5, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly' }
- ]
+ {
+ 'tier_name': 'default',
+ 'credits': 5,
+ 'input_limit': 0,
+ 'output_limit': 0,
+ 'reset_frequency': 'monthly',
+ }
+ ],
}
r = client.post('/platform/credit', json=payload)
assert r.status_code in (200, 201), r.text
payload2 = {
'username': 'admin',
- 'users_credits': {
- group: {
- 'tier_name': 'default',
- 'available_credits': 5
- }
- }
+ 'users_credits': {group: {'tier_name': 'default', 'available_credits': 5}},
}
- r = client.post(f'/platform/credit/admin', json=payload2)
+ r = client.post('/platform/credit/admin', json=payload2)
assert r.status_code in (200, 201), r.text
r = client.get(f'/platform/credit/defs/{group}')
@@ -32,5 +34,8 @@ def test_credit_def_create_and_assign(client):
assert r.status_code == 200
data = r.json().get('response', r.json())
assert group in (data.get('users_credits') or {})
+
+
import pytest
+
pytestmark = [pytest.mark.credits]
diff --git a/backend-services/live-tests/test_21_subscription_flows.py b/backend-services/live-tests/test_21_subscription_flows.py
index 67dd2ca..ccde063 100644
--- a/backend-services/live-tests/test_21_subscription_flows.py
+++ b/backend-services/live-tests/test_21_subscription_flows.py
@@ -1,28 +1,39 @@
import time
+
import pytest
pytestmark = [pytest.mark.auth]
+
def test_subscribe_list_unsubscribe(client):
- api_name = f"subs-{int(time.time())}"
+ api_name = f'subs-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'subs',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:9'],
- 'api_type': 'REST',
- 'active': True
- })
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'subs',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201), r.text
r = client.get('/platform/subscription/subscriptions')
assert r.status_code == 200
payload = r.json().get('response', r.json())
apis = payload.get('apis') or []
- assert any(f"{api_name}/{api_version}" == a for a in apis)
- r = client.post('/platform/subscription/unsubscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ assert any(f'{api_name}/{api_version}' == a for a in apis)
+ r = client.post(
+ '/platform/subscription/unsubscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201)
client.delete(f'/platform/api/{api_name}/{api_version}')
diff --git a/backend-services/live-tests/test_30_rest_gateway.py b/backend-services/live-tests/test_30_rest_gateway.py
index 239a3fb..38b9ded 100644
--- a/backend-services/live-tests/test_30_rest_gateway.py
+++ b/backend-services/live-tests/test_30_rest_gateway.py
@@ -1,6 +1,8 @@
import time
+
from servers import start_rest_echo_server
+
def test_rest_gateway_basic_flow(client):
srv = start_rest_echo_server()
try:
@@ -17,7 +19,7 @@ def test_rest_gateway_basic_flow(client):
'api_type': 'REST',
'api_allowed_retry_count': 0,
'active': True,
- 'api_cors_allow_origins': ['*']
+ 'api_cors_allow_origins': ['*'],
}
r = client.post('/platform/api', json=api_payload)
assert r.status_code in (200, 201), r.text
@@ -27,7 +29,7 @@ def test_rest_gateway_basic_flow(client):
'api_version': api_version,
'endpoint_method': 'GET',
'endpoint_uri': '/status',
- 'endpoint_description': 'status'
+ 'endpoint_description': 'status',
}
r = client.post('/platform/endpoint', json=ep_payload)
assert r.status_code in (200, 201), r.text
@@ -49,6 +51,7 @@ def test_rest_gateway_basic_flow(client):
finally:
srv.stop()
+
def test_rest_gateway_with_credits_and_header_injection(client):
srv = start_rest_echo_server()
try:
@@ -58,45 +61,68 @@ def test_rest_gateway_with_credits_and_header_injection(client):
credit_group = f'cg-{ts}'
api_key_val = 'DUMMY_API_KEY_ABC'
- r = client.post('/platform/credit', json={
- 'api_credit_group': credit_group,
- 'api_key': api_key_val,
- 'api_key_header': 'x-api-key',
- 'credit_tiers': [{ 'tier_name': 'default', 'credits': 2, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly' }]
- })
+ r = client.post(
+ '/platform/credit',
+ json={
+ 'api_credit_group': credit_group,
+ 'api_key': api_key_val,
+ 'api_key_header': 'x-api-key',
+ 'credit_tiers': [
+ {
+ 'tier_name': 'default',
+ 'credits': 2,
+ 'input_limit': 0,
+ 'output_limit': 0,
+ 'reset_frequency': 'monthly',
+ }
+ ],
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/credit/admin', json={
- 'username': 'admin',
- 'users_credits': { credit_group: { 'tier_name': 'default', 'available_credits': 2 } }
- })
+ r = client.post(
+ '/platform/credit/admin',
+ json={
+ 'username': 'admin',
+ 'users_credits': {credit_group: {'tier_name': 'default', 'available_credits': 2}},
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'REST with credits',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'active': True,
- 'api_credits_enabled': True,
- 'api_credit_group': credit_group
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'REST with credits',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'active': True,
+ 'api_credits_enabled': True,
+ 'api_credit_group': credit_group,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/echo',
- 'endpoint_description': 'echo with header'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/echo',
+ 'endpoint_description': 'echo with header',
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201), r.text
r = client.post(f'/api/rest/{api_name}/{api_version}/echo', json={'ping': 'pong'})
@@ -120,5 +146,8 @@ def test_rest_gateway_with_credits_and_header_injection(client):
except Exception:
pass
srv.stop()
+
+
import pytest
+
pytestmark = [pytest.mark.rest, pytest.mark.gateway]
diff --git a/backend-services/live-tests/test_31_endpoints_crud.py b/backend-services/live-tests/test_31_endpoints_crud.py
index 39ee0b7..88ffc82 100644
--- a/backend-services/live-tests/test_31_endpoints_crud.py
+++ b/backend-services/live-tests/test_31_endpoints_crud.py
@@ -1,19 +1,39 @@
import time
+
import pytest
pytestmark = [pytest.mark.rest]
+
def test_endpoints_update_list_delete(client):
- api_name = f"epcrud-{int(time.time())}"
+ api_name = f'epcrud-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'ep',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/z', 'endpoint_description': 'z'
- })
- r = client.put(f'/platform/endpoint/GET/{api_name}/{api_version}/z', json={'endpoint_description': 'zzz'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'ep',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/z',
+ 'endpoint_description': 'z',
+ },
+ )
+ r = client.put(
+ f'/platform/endpoint/GET/{api_name}/{api_version}/z', json={'endpoint_description': 'zzz'}
+ )
assert r.status_code in (200, 204)
r = client.get(f'/platform/endpoint/{api_name}/{api_version}')
assert r.status_code == 200
diff --git a/backend-services/live-tests/test_32_user_credit_override.py b/backend-services/live-tests/test_32_user_credit_override.py
index b7f5c70..dd4d8fc 100644
--- a/backend-services/live-tests/test_32_user_credit_override.py
+++ b/backend-services/live-tests/test_32_user_credit_override.py
@@ -1,6 +1,8 @@
import time
+
from servers import start_rest_echo_server
+
def test_user_specific_credit_api_key_overrides_group_key(client):
srv = start_rest_echo_server()
try:
@@ -11,42 +13,71 @@ def test_user_specific_credit_api_key_overrides_group_key(client):
group_key = 'GROUP_KEY_ABC'
user_key = 'USER_KEY_DEF'
- r = client.post('/platform/credit', json={
- 'api_credit_group': group,
- 'api_key': group_key,
- 'api_key_header': 'x-api-key',
- 'credit_tiers': [{ 'tier_name': 'default', 'credits': 3, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly' }]
- })
+ r = client.post(
+ '/platform/credit',
+ json={
+ 'api_credit_group': group,
+ 'api_key': group_key,
+ 'api_key_header': 'x-api-key',
+ 'credit_tiers': [
+ {
+ 'tier_name': 'default',
+ 'credits': 3,
+ 'input_limit': 0,
+ 'output_limit': 0,
+ 'reset_frequency': 'monthly',
+ }
+ ],
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/credit/admin', json={
- 'username': 'admin',
- 'users_credits': { group: { 'tier_name': 'default', 'available_credits': 3, 'user_api_key': user_key } }
- })
+ r = client.post(
+ '/platform/credit/admin',
+ json={
+ 'username': 'admin',
+ 'users_credits': {
+ group: {
+ 'tier_name': 'default',
+ 'available_credits': 3,
+ 'user_api_key': user_key,
+ }
+ },
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'credit user override',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True,
- 'api_credits_enabled': True,
- 'api_credit_group': group
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'credit user override',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_credits_enabled': True,
+ 'api_credit_group': group,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/whoami',
- 'endpoint_description': 'whoami'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/whoami',
+ 'endpoint_description': 'whoami',
+ },
+ )
assert r.status_code in (200, 201), r.text
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
r = client.get(f'/api/rest/{api_name}/{api_version}/whoami')
assert r.status_code == 200
diff --git a/backend-services/live-tests/test_33_rate_limit_and_throttle.py b/backend-services/live-tests/test_33_rate_limit_and_throttle.py
index 80b75d2..301e84a 100644
--- a/backend-services/live-tests/test_33_rate_limit_and_throttle.py
+++ b/backend-services/live-tests/test_33_rate_limit_and_throttle.py
@@ -1,39 +1,53 @@
import time
+
from servers import start_rest_echo_server
+
def test_rate_limiting_blocks_excess_requests(client):
srv = start_rest_echo_server()
try:
api_name = f'rl-{int(time.time())}'
api_version = 'v1'
- client.put('/platform/user/admin', json={
- 'rate_limit_duration': 1,
- 'rate_limit_duration_type': 'second',
- 'throttle_duration': 999,
- 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 999,
- 'throttle_wait_duration': 0,
- 'throttle_wait_duration_type': 'second'
- })
+ client.put(
+ '/platform/user/admin',
+ json={
+ 'rate_limit_duration': 1,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 999,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 999,
+ 'throttle_wait_duration': 0,
+ 'throttle_wait_duration_type': 'second',
+ },
+ )
- client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'rl test',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/hit',
- 'endpoint_description': 'hit'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'rl test',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/hit',
+ 'endpoint_description': 'hit',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
time.sleep(1.1)
@@ -55,14 +69,17 @@ def test_rate_limiting_blocks_excess_requests(client):
pass
srv.stop()
try:
- client.put('/platform/user/admin', json={
- 'rate_limit_duration': 1000000,
- 'rate_limit_duration_type': 'second',
- 'throttle_duration': 1000000,
- 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 1000000,
- 'throttle_wait_duration': 0,
- 'throttle_wait_duration_type': 'second'
- })
+ client.put(
+ '/platform/user/admin',
+ json={
+ 'rate_limit_duration': 1000000,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 1000000,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 1000000,
+ 'throttle_wait_duration': 0,
+ 'throttle_wait_duration_type': 'second',
+ },
+ )
except Exception:
pass
diff --git a/backend-services/live-tests/test_35_public_bulk_onboarding.py b/backend-services/live-tests/test_35_public_bulk_onboarding.py
index 423207a..aadb886 100644
--- a/backend-services/live-tests/test_35_public_bulk_onboarding.py
+++ b/backend-services/live-tests/test_35_public_bulk_onboarding.py
@@ -1,13 +1,14 @@
-import time
-import threading
-import socket
-import requests
-import pytest
import os
import platform
+import socket
+import threading
+import time
+import pytest
+import requests
+from config import ENABLE_GRAPHQL
from servers import start_rest_echo_server, start_soap_echo_server
-from config import ENABLE_GRAPHQL, ENABLE_GRPC
+
def _find_port():
s = socket.socket()
@@ -16,6 +17,7 @@ def _find_port():
s.close()
return p
+
def _get_host_from_container():
"""Get the hostname to use when referring to the host machine from a Docker container."""
docker_env = os.getenv('DOORMAN_IN_DOCKER', '').lower()
@@ -27,6 +29,7 @@ def _get_host_from_container():
return '172.17.0.1'
return '127.0.0.1'
+
def test_bulk_public_rest_crud(client):
srv = start_rest_echo_server()
try:
@@ -35,29 +38,40 @@ def test_bulk_public_rest_crud(client):
for i in range(3):
api_name = f'pub-rest-{ts}-{i}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'public rest',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': True
- })
- assert r.status_code in (200, 201), r.text
- for m, uri in [('GET', '/items'), ('POST', '/items'), ('PUT', '/items'), ('DELETE', '/items')]:
- r = client.post('/platform/endpoint', json={
+ r = client.post(
+ '/platform/api',
+ json={
'api_name': api_name,
'api_version': api_version,
- 'endpoint_method': m,
- 'endpoint_uri': uri,
- 'endpoint_description': f'{m} {uri}'
- })
+ 'api_description': 'public rest',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': True,
+ },
+ )
+ assert r.status_code in (200, 201), r.text
+ for m, uri in [
+ ('GET', '/items'),
+ ('POST', '/items'),
+ ('PUT', '/items'),
+ ('DELETE', '/items'),
+ ]:
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': m,
+ 'endpoint_uri': uri,
+ 'endpoint_description': f'{m} {uri}',
+ },
+ )
assert r.status_code in (200, 201), r.text
s = requests.Session()
- url = f"{base}/api/rest/{api_name}/{api_version}/items"
+ url = f'{base}/api/rest/{api_name}/{api_version}/items'
assert s.get(url).status_code == 200
assert s.post(url, json={'name': 'x'}).status_code == 200
assert s.put(url, json={'name': 'y'}).status_code == 200
@@ -65,69 +79,78 @@ def test_bulk_public_rest_crud(client):
finally:
srv.stop()
+
def test_bulk_public_soap_crud(client):
srv = start_soap_echo_server()
try:
base = client.base_url.rstrip('/')
ts = int(time.time())
envelope = (
- ""
- ""
- " "
- ""
+ ''
+ ''
+ ' '
+ ''
)
for i in range(3):
api_name = f'pub-soap-{ts}-{i}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'public soap',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': True
- })
- assert r.status_code in (200, 201), r.text
- for uri in ['/create', '/read', '/update', '/delete']:
- r = client.post('/platform/endpoint', json={
+ r = client.post(
+ '/platform/api',
+ json={
'api_name': api_name,
'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': uri,
- 'endpoint_description': f'SOAP {uri}'
- })
+ 'api_description': 'public soap',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': True,
+ },
+ )
+ assert r.status_code in (200, 201), r.text
+ for uri in ['/create', '/read', '/update', '/delete']:
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': uri,
+ 'endpoint_description': f'SOAP {uri}',
+ },
+ )
assert r.status_code in (200, 201), r.text
s = requests.Session()
headers = {'Content-Type': 'text/xml'}
for uri in ['create', 'read', 'update', 'delete']:
- url = f"{base}/api/soap/{api_name}/{api_version}/{uri}"
+ url = f'{base}/api/soap/{api_name}/{api_version}/{uri}'
resp = s.post(url, data=envelope, headers=headers)
assert resp.status_code == 200
finally:
srv.stop()
+
@pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL disabled')
def test_bulk_public_graphql_crud(client):
try:
import uvicorn
- from ariadne import gql as _gql, make_executable_schema, MutationType, QueryType
+ from ariadne import MutationType, QueryType, make_executable_schema
+ from ariadne import gql as _gql
from ariadne.asgi import GraphQL
except Exception as e:
pytest.skip(f'Missing GraphQL deps: {e}')
def start_gql_server():
data_store = {'items': {}, 'seq': 0}
- type_defs = _gql('''
+ type_defs = _gql("""
type Query { read(id: Int!): String! }
type Mutation {
create(name: String!): String!
update(id: Int!, name: String!): String!
delete(id: Int!): Boolean!
}
- ''')
+ """)
query = QueryType()
mutation = MutationType()
@@ -168,30 +191,53 @@ def test_bulk_public_graphql_crud(client):
api_name = f'pub-gql-{ts}-{i}'
api_version = 'v1'
try:
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'public gql',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [f'http://{host}:{port}'],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'public gql',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [f'http://{host}:{port}'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/graphql', 'endpoint_description': 'graphql'})
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'graphql',
+ },
+ )
assert r.status_code in (200, 201), r.text
s = requests.Session()
- url = f"{base}/api/graphql/{api_name}"
+ url = f'{base}/api/graphql/{api_name}'
q_create = {'query': 'mutation { create(name:"A") }'}
- assert s.post(url, json=q_create, headers={'X-API-Version': api_version}).status_code == 200
+ assert (
+ s.post(url, json=q_create, headers={'X-API-Version': api_version}).status_code
+ == 200
+ )
q_update = {'query': 'mutation { update(id:1, name:"B") }'}
- assert s.post(url, json=q_update, headers={'X-API-Version': api_version}).status_code == 200
+ assert (
+ s.post(url, json=q_update, headers={'X-API-Version': api_version}).status_code
+ == 200
+ )
q_read = {'query': '{ read(id:1) }'}
- assert s.post(url, json=q_read, headers={'X-API-Version': api_version}).status_code == 200
+ assert (
+ s.post(url, json=q_read, headers={'X-API-Version': api_version}).status_code == 200
+ )
q_delete = {'query': 'mutation { delete(id:1) }'}
- assert s.post(url, json=q_delete, headers={'X-API-Version': api_version}).status_code == 200
+ assert (
+ s.post(url, json=q_delete, headers={'X-API-Version': api_version}).status_code
+ == 200
+ )
finally:
try:
client.delete(f'/platform/endpoint/POST/{api_name}/{api_version}/graphql')
@@ -206,19 +252,29 @@ def test_bulk_public_graphql_crud(client):
except Exception:
pass
+
import os as _os
-_RUN_LIVE = _os.getenv('DOORMAN_RUN_LIVE', '0') in ('1','true','True')
-@pytest.mark.skipif(not _RUN_LIVE, reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+
+_RUN_LIVE = _os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
+
+
+@pytest.mark.skipif(
+ not _RUN_LIVE, reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+)
def test_bulk_public_grpc_crud(client):
try:
+ import importlib
+ import pathlib
+ import sys
+ import tempfile
+ from concurrent import futures
+
import grpc
from grpc_tools import protoc
- from concurrent import futures
- import tempfile, pathlib, importlib, sys
except Exception as e:
pytest.skip(f'Missing gRPC deps: {e}')
- PROTO = '''
+ PROTO = """
syntax = "proto3";
package {pkg};
service Resource {
@@ -235,9 +291,9 @@ message UpdateRequest { int32 id = 1; string name = 2; }
message UpdateReply { string message = 1; }
message DeleteRequest { int32 id = 1; }
message DeleteReply { bool ok = 1; }
-'''
+"""
- base = client.base_url.rstrip('/')
+ client.base_url.rstrip('/')
ts = int(time.time())
for i in range(0):
api_name = f'pub-grpc-{ts}-{i}'
@@ -248,7 +304,15 @@ message DeleteReply { bool ok = 1; }
(tmp / 'svc.proto').write_text(PROTO.replace('{pkg}', pkg))
out = tmp / 'gen'
out.mkdir()
- code = protoc.main(['protoc', f'--proto_path={td}', f'--python_out={out}', f'--grpc_python_out={out}', str(tmp / 'svc.proto')])
+ code = protoc.main(
+ [
+ 'protoc',
+ f'--proto_path={td}',
+ f'--python_out={out}',
+ f'--grpc_python_out={out}',
+ str(tmp / 'svc.proto'),
+ ]
+ )
assert code == 0
(out / '__init__.py').write_text('')
sys.path.insert(0, str(out))
@@ -258,35 +322,51 @@ message DeleteReply { bool ok = 1; }
class Resource(pb2_grpc.ResourceServicer):
def Create(self, request, context):
return pb2.CreateReply(message=f'created {request.name}')
+
def Read(self, request, context):
return pb2.ReadReply(message=f'read {request.id}')
+
def Update(self, request, context):
return pb2.UpdateReply(message=f'updated {request.id}:{request.name}')
+
def Delete(self, request, context):
return pb2.DeleteReply(ok=True)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
pb2_grpc.add_ResourceServicer_to_server(Resource(), server)
- s = socket.socket(); s.bind(('127.0.0.1', 0)); port = s.getsockname()[1]; s.close()
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ port = s.getsockname()[1]
+ s.close()
server.add_insecure_port(f'127.0.0.1:{port}')
server.start()
try:
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'public grpc',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [f'grpc://127.0.0.1:{port}'],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'public grpc',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [f'grpc://127.0.0.1:{port}'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/grpc', 'endpoint_description': 'grpc'})
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r.status_code in (200, 201), r.text
- url = f"{base}/api/grpc/{api_name}"
- hdr = {'X-API-Version': api_version}
pass
finally:
try:
@@ -299,4 +379,5 @@ message DeleteReply { bool ok = 1; }
pass
server.stop(0)
+
pytestmark = [pytest.mark.gateway]
diff --git a/backend-services/live-tests/test_36_public_grpc_proto.py b/backend-services/live-tests/test_36_public_grpc_proto.py
index 0e8169a..be26988 100644
--- a/backend-services/live-tests/test_36_public_grpc_proto.py
+++ b/backend-services/live-tests/test_36_public_grpc_proto.py
@@ -1,10 +1,11 @@
-import time
import socket
-import requests
-import pytest
+import time
+import pytest
+import requests
from config import ENABLE_GRPC
+
def _find_port() -> int:
s = socket.socket()
s.bind(('0.0.0.0', 0))
@@ -12,6 +13,7 @@ def _find_port() -> int:
s.close()
return p
+
def _get_host_from_container() -> str:
"""Get the hostname to use when referring to the host machine from a Docker container.
@@ -39,17 +41,22 @@ def _get_host_from_container() -> str:
# This is the most common development setup
return '127.0.0.1'
+
@pytest.mark.skipif(not ENABLE_GRPC, reason='gRPC disabled')
def test_public_grpc_with_proto_upload(client):
try:
+ import importlib
+ import pathlib
+ import sys
+ import tempfile
+ from concurrent import futures
+
import grpc
from grpc_tools import protoc
- from concurrent import futures
- import tempfile, pathlib, importlib, sys
except Exception as e:
pytest.skip(f'Missing gRPC deps: {e}')
- PROTO = '''
+ PROTO = """
syntax = "proto3";
package {pkg};
service Resource {
@@ -66,7 +73,7 @@ message UpdateRequest { int32 id = 1; string name = 2; }
message UpdateReply { string message = 1; }
message DeleteRequest { int32 id = 1; }
message DeleteReply { bool ok = 1; }
-'''
+"""
base = client.base_url.rstrip('/')
ts = int(time.time())
@@ -79,7 +86,15 @@ message DeleteReply { bool ok = 1; }
(tmp / 'svc.proto').write_text(PROTO.replace('{pkg}', pkg))
out = tmp / 'gen'
out.mkdir()
- code = protoc.main(['protoc', f'--proto_path={td}', f'--python_out={out}', f'--grpc_python_out={out}', str(tmp / 'svc.proto')])
+ code = protoc.main(
+ [
+ 'protoc',
+ f'--proto_path={td}',
+ f'--python_out={out}',
+ f'--grpc_python_out={out}',
+ str(tmp / 'svc.proto'),
+ ]
+ )
assert code == 0
(out / '__init__.py').write_text('')
sys.path.insert(0, str(out))
@@ -112,35 +127,63 @@ message DeleteReply { bool ok = 1; }
r_up = client.post(f'/platform/proto/{api_name}/{api_version}', files=files)
assert r_up.status_code in (200, 201), r_up.text
- r_api = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'public grpc with uploaded proto',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [f'grpc://{host_ref}:{port}'],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': True,
- 'api_grpc_package': pkg
- })
+ r_api = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'public grpc with uploaded proto',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [f'grpc://{host_ref}:{port}'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': True,
+ 'api_grpc_package': pkg,
+ },
+ )
assert r_api.status_code in (200, 201), r_api.text
- r_ep = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ r_ep = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r_ep.status_code in (200, 201), r_ep.text
- url = f"{base}/api/grpc/{api_name}"
+ url = f'{base}/api/grpc/{api_name}'
hdr = {'X-API-Version': api_version}
- assert requests.post(url, json={'method': 'Resource.Create', 'message': {'name': 'A'}}, headers=hdr).status_code == 200
- assert requests.post(url, json={'method': 'Resource.Read', 'message': {'id': 1}}, headers=hdr).status_code == 200
- assert requests.post(url, json={'method': 'Resource.Update', 'message': {'id': 1, 'name': 'B'}}, headers=hdr).status_code == 200
- assert requests.post(url, json={'method': 'Resource.Delete', 'message': {'id': 1}}, headers=hdr).status_code == 200
+ assert (
+ requests.post(
+ url, json={'method': 'Resource.Create', 'message': {'name': 'A'}}, headers=hdr
+ ).status_code
+ == 200
+ )
+ assert (
+ requests.post(
+ url, json={'method': 'Resource.Read', 'message': {'id': 1}}, headers=hdr
+ ).status_code
+ == 200
+ )
+ assert (
+ requests.post(
+ url,
+ json={'method': 'Resource.Update', 'message': {'id': 1, 'name': 'B'}},
+ headers=hdr,
+ ).status_code
+ == 200
+ )
+ assert (
+ requests.post(
+ url, json={'method': 'Resource.Delete', 'message': {'id': 1}}, headers=hdr
+ ).status_code
+ == 200
+ )
finally:
try:
client.delete(f'/platform/endpoint/POST/{api_name}/{api_version}/grpc')
@@ -151,4 +194,3 @@ message DeleteReply { bool ok = 1; }
except Exception:
pass
server.stop(0)
-
diff --git a/backend-services/live-tests/test_40_soap_gateway.py b/backend-services/live-tests/test_40_soap_gateway.py
index ce71308..be04e45 100644
--- a/backend-services/live-tests/test_40_soap_gateway.py
+++ b/backend-services/live-tests/test_40_soap_gateway.py
@@ -1,35 +1,46 @@
import time
+
from servers import start_soap_echo_server
+
def test_soap_gateway_basic_flow(client):
srv = start_soap_echo_server()
try:
api_name = f'soap-demo-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'SOAP demo',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'active': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'SOAP demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'active': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/soap',
- 'endpoint_description': 'soap'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/soap',
+ 'endpoint_description': 'soap',
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201)
body = """
@@ -40,7 +51,11 @@ def test_soap_gateway_basic_flow(client):
""".strip()
- r = client.post(f'/api/soap/{api_name}/{api_version}/soap', data=body, headers={'Content-Type': 'text/xml'})
+ r = client.post(
+ f'/api/soap/{api_name}/{api_version}/soap',
+ data=body,
+ headers={'Content-Type': 'text/xml'},
+ )
assert r.status_code == 200, r.text
finally:
try:
@@ -52,5 +67,8 @@ def test_soap_gateway_basic_flow(client):
except Exception:
pass
srv.stop()
+
+
import pytest
+
pytestmark = [pytest.mark.soap, pytest.mark.gateway]
diff --git a/backend-services/live-tests/test_41_soap_validation.py b/backend-services/live-tests/test_41_soap_validation.py
index c0d910c..73b36a0 100644
--- a/backend-services/live-tests/test_41_soap_validation.py
+++ b/backend-services/live-tests/test_41_soap_validation.py
@@ -1,40 +1,54 @@
import time
+
from servers import start_soap_echo_server
+
def test_soap_validation_blocks_missing_field(client):
srv = start_soap_echo_server()
try:
api_name = f'soapval-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'soap val',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/soap',
- 'endpoint_description': 'soap'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'soap val',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/soap',
+ 'endpoint_description': 'soap',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/soap')
- ep = r.json().get('response', r.json()); endpoint_id = ep.get('endpoint_id'); assert endpoint_id
- schema = {
- 'validation_schema': {
- 'message': { 'required': True, 'type': 'string', 'min': 2 }
- }
- }
- r = client.post('/platform/endpoint/endpoint/validation', json={
- 'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema
- })
+ ep = r.json().get('response', r.json())
+ endpoint_id = ep.get('endpoint_id')
+ assert endpoint_id
+ schema = {'validation_schema': {'message': {'required': True, 'type': 'string', 'min': 2}}}
+ r = client.post(
+ '/platform/endpoint/endpoint/validation',
+ json={
+ 'endpoint_id': endpoint_id,
+ 'validation_enabled': True,
+ 'validation_schema': schema,
+ },
+ )
assert r.status_code in (200, 201)
xml = """
@@ -45,11 +59,19 @@ def test_soap_validation_blocks_missing_field(client):
""".strip()
- r = client.post(f'/api/soap/{api_name}/{api_version}/soap', data=xml, headers={'Content-Type': 'text/xml'})
+ r = client.post(
+ f'/api/soap/{api_name}/{api_version}/soap',
+ data=xml,
+ headers={'Content-Type': 'text/xml'},
+ )
assert r.status_code == 400
xml2 = xml.replace('>A<', '>AB<')
- r = client.post(f'/api/soap/{api_name}/{api_version}/soap', data=xml2, headers={'Content-Type': 'text/xml'})
+ r = client.post(
+ f'/api/soap/{api_name}/{api_version}/soap',
+ data=xml2,
+ headers={'Content-Type': 'text/xml'},
+ )
assert r.status_code == 200
finally:
try:
diff --git a/backend-services/live-tests/test_45_soap_preflight.py b/backend-services/live-tests/test_45_soap_preflight.py
index 7f95d03..0886c81 100644
--- a/backend-services/live-tests/test_45_soap_preflight.py
+++ b/backend-services/live-tests/test_45_soap_preflight.py
@@ -1,22 +1,50 @@
import time
+
import pytest
pytestmark = [pytest.mark.soap]
+
def test_soap_cors_preflight(client):
api_name = f'soap-pre-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'soap pre',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True,
- 'api_cors_allow_origins': ['http://example.com'], 'api_cors_allow_methods': ['POST'], 'api_cors_allow_headers': ['Content-Type']
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/soap', 'endpoint_description': 's'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'soap pre',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_cors_allow_origins': ['http://example.com'],
+ 'api_cors_allow_methods': ['POST'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/soap',
+ 'endpoint_description': 's',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
- r = client.options(f'/api/soap/{api_name}/{api_version}/soap', headers={
- 'Origin': 'http://example.com', 'Access-Control-Request-Method': 'POST', 'Access-Control-Request-Headers': 'Content-Type'
- })
+ r = client.options(
+ f'/api/soap/{api_name}/{api_version}/soap',
+ headers={
+ 'Origin': 'http://example.com',
+ 'Access-Control-Request-Method': 'POST',
+ 'Access-Control-Request-Headers': 'Content-Type',
+ },
+ )
assert r.status_code in (200, 204)
diff --git a/backend-services/live-tests/test_46_gateway_errors.py b/backend-services/live-tests/test_46_gateway_errors.py
index a5b46a5..44b62fa 100644
--- a/backend-services/live-tests/test_46_gateway_errors.py
+++ b/backend-services/live-tests/test_46_gateway_errors.py
@@ -1,16 +1,30 @@
import time
+
import pytest
pytestmark = [pytest.mark.rest]
+
def test_nonexistent_endpoint_returns_gw_error(client):
api_name = f'gwerr-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'gw',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'gw',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
r = client.get(f'/api/rest/{api_name}/{api_version}/nope')
assert r.status_code in (404, 400, 500)
data = r.json()
diff --git a/backend-services/live-tests/test_50_graphql_gateway.py b/backend-services/live-tests/test_50_graphql_gateway.py
index 839062e..893733d 100644
--- a/backend-services/live-tests/test_50_graphql_gateway.py
+++ b/backend-services/live-tests/test_50_graphql_gateway.py
@@ -1,35 +1,38 @@
import os
import time
-import pytest
+import pytest
from config import ENABLE_GRAPHQL
-pytestmark = pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL test disabled (set DOORMAN_TEST_GRAPHQL=1 to enable)')
+pytestmark = pytest.mark.skipif(
+ not ENABLE_GRAPHQL, reason='GraphQL test disabled (set DOORMAN_TEST_GRAPHQL=1 to enable)'
+)
+
def test_graphql_gateway_basic_flow(client):
try:
import uvicorn
- from ariadne import gql, make_executable_schema, QueryType
+ from ariadne import QueryType, gql, make_executable_schema
from ariadne.asgi import GraphQL
except Exception as e:
pytest.skip(f'Missing GraphQL deps: {e}')
- type_defs = gql('''
+ type_defs = gql("""
type Query {
hello(name: String): String!
}
- ''')
+ """)
query = QueryType()
@query.field('hello')
def resolve_hello(*_, name=None):
- return f"Hello, {name or 'world'}!"
+ return f'Hello, {name or "world"}!'
schema = make_executable_schema(type_defs, query)
app = GraphQL(schema, debug=True)
- import threading
import socket
+ import threading
def _find_port():
s = socket.socket()
@@ -41,6 +44,7 @@ def test_graphql_gateway_basic_flow(client):
def _get_host_from_container():
"""Get the hostname to use when referring to the host machine from a Docker container."""
import platform
+
docker_env = os.getenv('DOORMAN_IN_DOCKER', '').lower()
if docker_env in ('1', 'true', 'yes'):
system = platform.system()
@@ -62,29 +66,38 @@ def test_graphql_gateway_basic_flow(client):
api_name = f'gql-demo-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'GraphQL demo',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [f'http://{host}:{port}'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'active': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'GraphQL demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [f'http://{host}:{port}'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'active': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/graphql',
- 'endpoint_description': 'graphql'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'graphql',
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201), r.text
q = {'query': '{ hello(name:"Doorman") }'}
@@ -97,5 +110,8 @@ def test_graphql_gateway_basic_flow(client):
client.delete(f'/platform/endpoint/POST/{api_name}/{api_version}/graphql')
client.delete(f'/platform/api/{api_name}/{api_version}')
+
+
import pytest
+
pytestmark = [pytest.mark.graphql, pytest.mark.gateway]
diff --git a/backend-services/live-tests/test_52_graphql_validation.py b/backend-services/live-tests/test_52_graphql_validation.py
index 0d20f0e..272d48e 100644
--- a/backend-services/live-tests/test_52_graphql_validation.py
+++ b/backend-services/live-tests/test_52_graphql_validation.py
@@ -1,33 +1,47 @@
import time
-import pytest
+import pytest
from config import ENABLE_GRAPHQL
-pytestmark = pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL validation test disabled (set DOORMAN_TEST_GRAPHQL=1)')
+pytestmark = pytest.mark.skipif(
+ not ENABLE_GRAPHQL, reason='GraphQL validation test disabled (set DOORMAN_TEST_GRAPHQL=1)'
+)
+
def test_graphql_validation_blocks_invalid_variables(client):
try:
import uvicorn
- from ariadne import gql, make_executable_schema, QueryType
+ from ariadne import QueryType, gql, make_executable_schema
from ariadne.asgi import GraphQL
except Exception as e:
pytest.skip(f'Missing GraphQL deps: {e}')
- type_defs = gql('''
+ type_defs = gql("""
type Query { hello(name: String!): String! }
- ''')
+ """)
query = QueryType()
@query.field('hello')
def resolve_hello(*_, name):
- return f"Hello, {name}!"
+ return f'Hello, {name}!'
schema = make_executable_schema(type_defs, query)
- import threading, socket, uvicorn, platform
+ import platform
+ import socket
+ import threading
+
+ import uvicorn
+
def _free_port():
- s = socket.socket(); s.bind(('0.0.0.0', 0)); p = s.getsockname()[1]; s.close(); return p
+ s = socket.socket()
+ s.bind(('0.0.0.0', 0))
+ p = s.getsockname()[1]
+ s.close()
+ return p
+
def _get_host_from_container():
import os
+
docker_env = os.getenv('DOORMAN_IN_DOCKER', '').lower()
if docker_env in ('1', 'true', 'yes'):
system = platform.system()
@@ -36,36 +50,56 @@ def test_graphql_validation_blocks_invalid_variables(client):
else:
return '172.17.0.1'
return '127.0.0.1'
+
port = _free_port()
host = _get_host_from_container()
- server = uvicorn.Server(uvicorn.Config(GraphQL(schema), host='0.0.0.0', port=port, log_level='warning'))
- t = threading.Thread(target=server.run, daemon=True); t.start()
+ server = uvicorn.Server(
+ uvicorn.Config(GraphQL(schema), host='0.0.0.0', port=port, log_level='warning')
+ )
+ t = threading.Thread(target=server.run, daemon=True)
+ t.start()
time.sleep(0.4)
api_name = f'gqlval-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version,
- 'api_description': 'gql val', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': [f'http://{host}:{port}'], 'api_type': 'REST', 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version,
- 'endpoint_method': 'POST','endpoint_uri': '/graphql','endpoint_description': 'gql'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'gql val',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [f'http://{host}:{port}'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'gql',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/graphql')
- ep = r.json().get('response', r.json()); endpoint_id = ep.get('endpoint_id'); assert endpoint_id
+ ep = r.json().get('response', r.json())
+ endpoint_id = ep.get('endpoint_id')
+ assert endpoint_id
- schema = {
- 'validation_schema': {
- 'HelloOp.x': { 'required': True, 'type': 'string', 'min': 2 }
- }
- }
- r = client.post('/platform/endpoint/endpoint/validation', json={
- 'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema
- })
+ schema = {'validation_schema': {'HelloOp.x': {'required': True, 'type': 'string', 'min': 2}}}
+ r = client.post(
+ '/platform/endpoint/endpoint/validation',
+ json={'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema},
+ )
assert r.status_code in (200, 201)
q = {'query': 'query HelloOp($x: String!) { hello(name: $x) }', 'variables': {'x': 'A'}}
diff --git a/backend-services/live-tests/test_59_grpc_invalid_method.py b/backend-services/live-tests/test_59_grpc_invalid_method.py
index 16d1684..a5cb89b 100644
--- a/backend-services/live-tests/test_59_grpc_invalid_method.py
+++ b/backend-services/live-tests/test_59_grpc_invalid_method.py
@@ -1,14 +1,16 @@
import time
+
import pytest
from config import ENABLE_GRPC
pytestmark = [pytest.mark.grpc]
+
def test_grpc_invalid_method_returns_error(client):
if not ENABLE_GRPC:
pytest.skip('gRPC disabled')
try:
- import grpc_tools
+ pass
except Exception as e:
pytest.skip(f'Missing gRPC deps: {e}')
@@ -19,15 +21,41 @@ syntax = "proto3";
package {pkg};
service Greeter {}
""".replace('{pkg}', f'{api_name}_{api_version}')
- r = client.post(f'/platform/proto/{api_name}/{api_version}', files={'file': ('s.proto', proto.encode('utf-8'))})
+ r = client.post(
+ f'/platform/proto/{api_name}/{api_version}',
+ files={'file': ('s.proto', proto.encode('utf-8'))},
+ )
assert r.status_code == 200
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'g', 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'], 'api_servers': ['grpc://127.0.0.1:9'], 'api_type': 'REST', 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/grpc', 'endpoint_description': 'g'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
- r = client.post(f'/api/grpc/{api_name}', json={'method': 'Nope.Do', 'message': {}}, headers={'X-API-Version': api_version})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'g',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['grpc://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'g',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
+ r = client.post(
+ f'/api/grpc/{api_name}',
+ json={'method': 'Nope.Do', 'message': {}},
+ headers={'X-API-Version': api_version},
+ )
assert r.status_code in (404, 500)
diff --git a/backend-services/live-tests/test_60_grpc_gateway.py b/backend-services/live-tests/test_60_grpc_gateway.py
index 736e92a..969d6da 100644
--- a/backend-services/live-tests/test_60_grpc_gateway.py
+++ b/backend-services/live-tests/test_60_grpc_gateway.py
@@ -1,11 +1,11 @@
-import io
-import os
import time
-import pytest
+import pytest
from config import ENABLE_GRPC
-pytestmark = pytest.mark.skipif(not ENABLE_GRPC, reason='gRPC test disabled (set DOORMAN_TEST_GRPC=1 to enable)')
+pytestmark = pytest.mark.skipif(
+ not ENABLE_GRPC, reason='gRPC test disabled (set DOORMAN_TEST_GRPC=1 to enable)'
+)
PROTO_TEMPLATE = """
syntax = "proto3";
@@ -25,22 +25,23 @@ message HelloReply {
}
"""
+
def _start_grpc_server(port: int):
- import grpc
from concurrent import futures
+ import grpc
+
class GreeterServicer:
def Hello(self, request, context):
- from google.protobuf import struct_pb2
pass
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
return server
+
def test_grpc_gateway_basic_flow(client):
try:
import grpc
- import grpc_tools
except Exception as e:
pytest.skip(f'Missing gRPC deps: {e}')
@@ -54,26 +55,38 @@ def test_grpc_gateway_basic_flow(client):
assert r.status_code == 200, r.text
try:
+ import importlib
+ import pathlib
+ import tempfile
+
from grpc_tools import protoc
- import tempfile, pathlib, importlib
+
with tempfile.TemporaryDirectory() as td:
tmp = pathlib.Path(td)
(tmp / 'svc.proto').write_text(proto)
out = tmp / 'gen'
out.mkdir()
- code = protoc.main([
- 'protoc', f'--proto_path={td}', f'--python_out={out}', f'--grpc_python_out={out}', str(tmp / 'svc.proto')
- ])
+ code = protoc.main(
+ [
+ 'protoc',
+ f'--proto_path={td}',
+ f'--python_out={out}',
+ f'--grpc_python_out={out}',
+ str(tmp / 'svc.proto'),
+ ]
+ )
assert code == 0
(out / '__init__.py').write_text('')
import sys
+
sys.path.insert(0, str(out))
pb2 = importlib.import_module('svc_pb2')
pb2_grpc = importlib.import_module('svc_pb2_grpc')
- import grpc
from concurrent import futures
+ import grpc
+
class Greeter(pb2_grpc.GreeterServicer):
def Hello(self, request, context):
return pb2.HelloReply(message=f'Hello, {request.name}!')
@@ -81,41 +94,53 @@ def test_grpc_gateway_basic_flow(client):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
import socket
- s = socket.socket(); s.bind(('127.0.0.1', 0)); port = s.getsockname()[1]; s.close()
+
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ port = s.getsockname()[1]
+ s.close()
server.add_insecure_port(f'127.0.0.1:{port}')
server.start()
try:
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'gRPC demo',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [f'grpc://127.0.0.1:{port}'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'active': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'gRPC demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [f'grpc://127.0.0.1:{port}'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'active': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201)
- body = {
- 'method': 'Greeter.Hello',
- 'message': {'name': 'Doorman'}
- }
- r = client.post(f'/api/grpc/{api_name}', json=body, headers={'X-API-Version': api_version})
+ body = {'method': 'Greeter.Hello', 'message': {'name': 'Doorman'}}
+ r = client.post(
+ f'/api/grpc/{api_name}', json=body, headers={'X-API-Version': api_version}
+ )
assert r.status_code == 200, r.text
data = r.json().get('response', r.json())
assert data.get('message') == 'Hello, Doorman!'
@@ -131,5 +156,8 @@ def test_grpc_gateway_basic_flow(client):
server.stop(0)
except Exception as e:
pytest.skip(f'Skipping gRPC due to setup failure: {e}')
+
+
import pytest
+
pytestmark = [pytest.mark.grpc, pytest.mark.gateway]
diff --git a/backend-services/live-tests/test_61_ip_policy.py b/backend-services/live-tests/test_61_ip_policy.py
index 4c1fe25..fb2a4b1 100644
--- a/backend-services/live-tests/test_61_ip_policy.py
+++ b/backend-services/live-tests/test_61_ip_policy.py
@@ -1,31 +1,36 @@
-import pytest
-
from types import SimpleNamespace
-from utils.ip_policy_util import _ip_in_list, _get_client_ip, enforce_api_ip_policy
+import pytest
+
+from utils.ip_policy_util import _get_client_ip, _ip_in_list, enforce_api_ip_policy
+
@pytest.fixture(autouse=True, scope='session')
def ensure_session_and_relaxed_limits():
yield
+
def make_request(host: str | None = None, headers: dict | None = None):
client = SimpleNamespace(host=host, port=None)
return SimpleNamespace(client=client, headers=headers or {}, url=SimpleNamespace(path='/'))
+
def test_ip_in_list_ipv4_exact_and_cidr():
assert _ip_in_list('192.168.1.10', ['192.168.1.10'])
assert _ip_in_list('10.1.2.3', ['10.0.0.0/8'])
assert not _ip_in_list('11.1.2.3', ['10.0.0.0/8'])
+
def test_ip_in_list_ipv6_exact_and_cidr():
assert _ip_in_list('2001:db8::1', ['2001:db8::1'])
assert _ip_in_list('2001:db8::abcd', ['2001:db8::/32'])
assert not _ip_in_list('2001:db9::1', ['2001:db8::/32'])
+
def test_get_client_ip_trusted_proxy(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'xff_trusted_proxies': ['10.0.0.0/8']
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings', lambda: {'xff_trusted_proxies': ['10.0.0.0/8']}
+ )
req1 = make_request('10.1.2.3', {'X-Forwarded-For': '1.2.3.4, 10.1.2.3'})
assert _get_client_ip(req1, True) == '1.2.3.4'
@@ -33,17 +38,21 @@ def test_get_client_ip_trusted_proxy(monkeypatch):
req2 = make_request('8.8.8.8', {'X-Forwarded-For': '1.2.3.4'})
assert _get_client_ip(req2, True) == '8.8.8.8'
+
def test_enforce_api_policy_never_blocks_localhost(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'trust_x_forwarded_for': False,
- 'xff_trusted_proxies': [],
- 'allow_localhost_bypass': True,
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ 'allow_localhost_bypass': True,
+ },
+ )
api = {
'api_ip_mode': 'whitelist',
'api_ip_whitelist': ['203.0.113.0/24'],
- 'api_ip_blacklist': ['0.0.0.0/0']
+ 'api_ip_blacklist': ['0.0.0.0/0'],
}
req_local_v4 = make_request('127.0.0.1', {})
@@ -52,38 +61,46 @@ def test_enforce_api_policy_never_blocks_localhost(monkeypatch):
req_local_v6 = make_request('::1', {})
enforce_api_ip_policy(req_local_v6, api)
+
def test_get_client_ip_secure_default_no_trust_when_empty_list(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'trust_x_forwarded_for': True,
- 'xff_trusted_proxies': []
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {'trust_x_forwarded_for': True, 'xff_trusted_proxies': []},
+ )
req = make_request('10.0.0.5', {'X-Forwarded-For': '203.0.113.9'})
assert _get_client_ip(req, True) == '10.0.0.5'
+
def test_get_client_ip_x_real_ip_and_cf_connecting(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'trust_x_forwarded_for': True,
- 'xff_trusted_proxies': ['10.0.0.0/8']
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {'trust_x_forwarded_for': True, 'xff_trusted_proxies': ['10.0.0.0/8']},
+ )
req1 = make_request('10.2.3.4', {'X-Real-IP': '198.51.100.7'})
assert _get_client_ip(req1, True) == '198.51.100.7'
req2 = make_request('10.2.3.4', {'CF-Connecting-IP': '2001:db8::2'})
assert _get_client_ip(req2, True) == '2001:db8::2'
+
def test_get_client_ip_ignores_headers_when_trust_disabled(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'trust_x_forwarded_for': False,
- 'xff_trusted_proxies': ['10.0.0.0/8']
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {'trust_x_forwarded_for': False, 'xff_trusted_proxies': ['10.0.0.0/8']},
+ )
req = make_request('10.2.3.4', {'X-Forwarded-For': '198.51.100.7'})
assert _get_client_ip(req, False) == '10.2.3.4'
+
def test_enforce_api_policy_whitelist_and_blacklist(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'trust_x_forwarded_for': False,
- 'xff_trusted_proxies': []
- })
- api = {'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24'], 'api_ip_blacklist': []}
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {'trust_x_forwarded_for': False, 'xff_trusted_proxies': []},
+ )
+ api = {
+ 'api_ip_mode': 'whitelist',
+ 'api_ip_whitelist': ['203.0.113.0/24'],
+ 'api_ip_blacklist': [],
+ }
req = make_request('198.51.100.10', {})
raised = False
try:
@@ -92,7 +109,11 @@ def test_enforce_api_policy_whitelist_and_blacklist(monkeypatch):
raised = True
assert raised
- api2 = {'api_ip_mode': 'allow_all', 'api_ip_whitelist': [], 'api_ip_blacklist': ['198.51.100.0/24']}
+ api2 = {
+ 'api_ip_mode': 'allow_all',
+ 'api_ip_whitelist': [],
+ 'api_ip_blacklist': ['198.51.100.0/24'],
+ }
req2 = make_request('198.51.100.10', {})
raised2 = False
try:
@@ -101,12 +122,16 @@ def test_enforce_api_policy_whitelist_and_blacklist(monkeypatch):
raised2 = True
assert raised2
+
def test_localhost_bypass_requires_no_forwarding_headers(monkeypatch):
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'allow_localhost_bypass': True,
- 'trust_x_forwarded_for': False,
- 'xff_trusted_proxies': []
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {
+ 'allow_localhost_bypass': True,
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ },
+ )
api = {'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24']}
req = make_request('::1', {'X-Forwarded-For': '1.2.3.4'})
raised = False
@@ -116,13 +141,17 @@ def test_localhost_bypass_requires_no_forwarding_headers(monkeypatch):
raised = True
assert raised, 'Expected enforcement when forwarding header present'
+
def test_env_overrides_localhost_bypass(monkeypatch):
monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'true')
- monkeypatch.setattr('utils.ip_policy_util.get_cached_settings', lambda: {
- 'allow_localhost_bypass': False,
- 'trust_x_forwarded_for': False,
- 'xff_trusted_proxies': []
- })
+ monkeypatch.setattr(
+ 'utils.ip_policy_util.get_cached_settings',
+ lambda: {
+ 'allow_localhost_bypass': False,
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ },
+ )
api = {'api_ip_mode': 'whitelist', 'api_ip_whitelist': ['203.0.113.0/24']}
req = make_request('127.0.0.1', {})
enforce_api_ip_policy(req, api)
diff --git a/backend-services/live-tests/test_66_graphql_missing_version_header.py b/backend-services/live-tests/test_66_graphql_missing_version_header.py
index fb6c956..5b27b2d 100644
--- a/backend-services/live-tests/test_66_graphql_missing_version_header.py
+++ b/backend-services/live-tests/test_66_graphql_missing_version_header.py
@@ -1,41 +1,74 @@
import time
+
import pytest
from config import ENABLE_GRAPHQL
pytestmark = [pytest.mark.graphql]
+
def test_graphql_missing_version_header_returns_400(client):
if not ENABLE_GRAPHQL:
pytest.skip('GraphQL disabled')
try:
- from ariadne import gql, make_executable_schema, QueryType
- from ariadne.asgi import GraphQL
import uvicorn
+ from ariadne import QueryType, gql, make_executable_schema
+ from ariadne.asgi import GraphQL
except Exception as e:
pytest.skip(f'Missing deps: {e}')
type_defs = gql('type Query { ok: String! }')
query = QueryType()
+
@query.field('ok')
def resolve_ok(*_):
return 'ok'
+
schema = make_executable_schema(type_defs, query)
- import threading, socket
- s = socket.socket(); s.bind(('127.0.0.1', 0)); port = s.getsockname()[1]; s.close()
- server = uvicorn.Server(uvicorn.Config(GraphQL(schema), host='127.0.0.1', port=port, log_level='warning'))
- t = threading.Thread(target=server.run, daemon=True); t.start()
- import time as _t; _t.sleep(0.4)
+ import socket
+ import threading
+
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ port = s.getsockname()[1]
+ s.close()
+ server = uvicorn.Server(
+ uvicorn.Config(GraphQL(schema), host='127.0.0.1', port=port, log_level='warning')
+ )
+ t = threading.Thread(target=server.run, daemon=True)
+ t.start()
+ import time as _t
+
+ _t.sleep(0.4)
api_name = f'gql-novh-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'gql',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': [f'http://127.0.0.1:{port}'], 'api_type': 'REST', 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/graphql', 'endpoint_description': 'gql'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'gql',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [f'http://127.0.0.1:{port}'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'gql',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
r = client.post(f'/api/graphql/{api_name}', json={'query': '{ ok }'})
assert r.status_code == 400
diff --git a/backend-services/live-tests/test_70_combinations.py b/backend-services/live-tests/test_70_combinations.py
index dfbd104..75e8090 100644
--- a/backend-services/live-tests/test_70_combinations.py
+++ b/backend-services/live-tests/test_70_combinations.py
@@ -1,6 +1,8 @@
import time
+
from servers import start_rest_echo_server
+
def test_endpoint_level_servers_override(client):
srv_api = start_rest_echo_server()
srv_ep = start_rest_echo_server()
@@ -8,29 +10,38 @@ def test_endpoint_level_servers_override(client):
api_name = f'combo-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'combo demo',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv_api.url],
- 'api_type': 'REST',
- 'active': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'combo demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv_api.url],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/who',
- 'endpoint_description': 'who am i',
- 'endpoint_servers': [srv_ep.url]
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/who',
+ 'endpoint_description': 'who am i',
+ 'endpoint_servers': [srv_ep.url],
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201)
r = client.get(f'/api/rest/{api_name}/{api_version}/who')
@@ -51,5 +62,8 @@ def test_endpoint_level_servers_override(client):
pass
srv_api.stop()
srv_ep.stop()
+
+
import pytest
+
pytestmark = [pytest.mark.gateway, pytest.mark.routing]
diff --git a/backend-services/live-tests/test_80_roles_groups_permissions.py b/backend-services/live-tests/test_80_roles_groups_permissions.py
index 4f9dbbe..c959fc9 100644
--- a/backend-services/live-tests/test_80_roles_groups_permissions.py
+++ b/backend-services/live-tests/test_80_roles_groups_permissions.py
@@ -1,11 +1,12 @@
-import time
import random
import string
+import time
+
def _rand_user() -> tuple[str, str, str]:
ts = int(time.time())
- uname = f"usr_{ts}_{random.randint(1000,9999)}"
- email = f"{uname}@example.com"
+ uname = f'usr_{ts}_{random.randint(1000, 9999)}'
+ email = f'{uname}@example.com'
upp = random.choice(string.ascii_uppercase)
low = ''.join(random.choices(string.ascii_lowercase, k=8))
dig = ''.join(random.choices(string.digits, k=4))
@@ -14,40 +15,47 @@ def _rand_user() -> tuple[str, str, str]:
pwd = ''.join(random.sample(upp + low + dig + spc + tail, len(upp + low + dig + spc + tail)))
return uname, email, pwd
+
def test_role_permission_blocks_api_management(client):
- role_name = f"viewer_{int(time.time())}"
- r = client.post('/platform/role', json={
- 'role_name': role_name,
- 'role_description': 'temporary viewer',
- 'view_logs': True
- })
+ role_name = f'viewer_{int(time.time())}'
+ r = client.post(
+ '/platform/role',
+ json={'role_name': role_name, 'role_description': 'temporary viewer', 'view_logs': True},
+ )
assert r.status_code in (200, 201), r.text
uname, email, pwd = _rand_user()
- r = client.post('/platform/user', json={
- 'username': uname,
- 'email': email,
- 'password': pwd,
- 'role': role_name,
- 'groups': ['ALL'],
- 'ui_access': True
- })
+ r = client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': email,
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
from client import LiveClient
+
user_client = LiveClient(client.base_url)
user_client.login(email, pwd)
- api_name = f"nope-{int(time.time())}"
- r = user_client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': 'v1',
- 'api_description': 'should be blocked',
- 'api_allowed_roles': [role_name],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:1'],
- 'api_type': 'REST'
- })
+ api_name = f'nope-{int(time.time())}'
+ r = user_client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': 'v1',
+ 'api_description': 'should be blocked',
+ 'api_allowed_roles': [role_name],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:1'],
+ 'api_type': 'REST',
+ },
+ )
assert r.status_code == 403
body = r.json()
data = body.get('response', body)
@@ -55,5 +63,8 @@ def test_role_permission_blocks_api_management(client):
client.delete(f'/platform/user/{uname}')
client.delete(f'/platform/role/{role_name}')
+
+
import pytest
+
pytestmark = [pytest.mark.security, pytest.mark.roles]
diff --git a/backend-services/live-tests/test_81_negative_permissions.py b/backend-services/live-tests/test_81_negative_permissions.py
index 9a1dcef..2a70ddb 100644
--- a/backend-services/live-tests/test_81_negative_permissions.py
+++ b/backend-services/live-tests/test_81_negative_permissions.py
@@ -1,14 +1,15 @@
-import time
import random
-import string
+import time
+
import pytest
pytestmark = [pytest.mark.security]
+
def _mk_user_payload(role_name: str) -> tuple[str, str, str, dict]:
ts = int(time.time())
- uname = f"min_{ts}_{random.randint(1000,9999)}"
- email = f"{uname}@example.com"
+ uname = f'min_{ts}_{random.randint(1000, 9999)}'
+ email = f'{uname}@example.com'
pwd = 'Strong!Passw0rd1234'
payload = {
'username': uname,
@@ -16,16 +17,14 @@ def _mk_user_payload(role_name: str) -> tuple[str, str, str, dict]:
'password': pwd,
'role': role_name,
'groups': ['ALL'],
- 'ui_access': True
+ 'ui_access': True,
}
return uname, email, pwd, payload
+
def test_negative_permissions_for_logs_and_config(client):
- role_name = f"minrole_{int(time.time())}"
- r = client.post('/platform/role', json={
- 'role_name': role_name,
- 'role_description': 'minimal'
- })
+ role_name = f'minrole_{int(time.time())}'
+ r = client.post('/platform/role', json={'role_name': role_name, 'role_description': 'minimal'})
assert r.status_code in (200, 201)
uname, email, pwd, payload = _mk_user_payload(role_name)
@@ -33,6 +32,7 @@ def test_negative_permissions_for_logs_and_config(client):
assert r.status_code in (200, 201)
from client import LiveClient
+
u = LiveClient(client.base_url)
u.login(email, pwd)
diff --git a/backend-services/live-tests/test_82_role_matrix_negative.py b/backend-services/live-tests/test_82_role_matrix_negative.py
index 28d820d..ce2898a 100644
--- a/backend-services/live-tests/test_82_role_matrix_negative.py
+++ b/backend-services/live-tests/test_82_role_matrix_negative.py
@@ -1,69 +1,110 @@
import time
+
import pytest
pytestmark = [pytest.mark.security, pytest.mark.roles]
+
def _mk_user(client, role_name: str):
ts = int(time.time())
- uname = f"perm_{ts}"
- email = f"{uname}@example.com"
+ uname = f'perm_{ts}'
+ email = f'{uname}@example.com'
pwd = 'Strong!Passw0rd1234'
- r = client.post('/platform/user', json={
- 'username': uname,
- 'email': email,
- 'password': pwd,
- 'role': role_name,
- 'groups': ['ALL'],
- 'ui_access': True,
- 'rate_limit_duration': 1000000,
- 'rate_limit_duration_type': 'second',
- 'throttle_duration': 1000000,
- 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 1000000,
- 'throttle_wait_duration': 0,
- 'throttle_wait_duration_type': 'second'
- })
+ r = client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': email,
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ 'rate_limit_duration': 1000000,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 1000000,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 1000000,
+ 'throttle_wait_duration': 0,
+ 'throttle_wait_duration_type': 'second',
+ },
+ )
assert r.status_code in (200, 201), r.text
return uname, email, pwd
+
def _login(base_client, email, pwd):
from client import LiveClient
+
c = LiveClient(base_client.base_url)
c.login(email, pwd)
return c
+
def test_permission_matrix_block_then_allow(client):
api_name = f'permapi-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'perm',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True
- })
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'perm',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
def try_manage_apis(c):
- return c.post('/platform/api', json={
- 'api_name': f'pa-{int(time.time())}', 'api_version': 'v1', 'api_description': 'x',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST'
- })
+ return c.post(
+ '/platform/api',
+ json={
+ 'api_name': f'pa-{int(time.time())}',
+ 'api_version': 'v1',
+ 'api_description': 'x',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ },
+ )
def try_manage_endpoints(c):
- return c.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version,
- 'endpoint_method': 'GET', 'endpoint_uri': f'/p{int(time.time())}', 'endpoint_description': 'x'
- })
+ return c.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': f'/p{int(time.time())}',
+ 'endpoint_description': 'x',
+ },
+ )
def try_manage_users(c):
- return c.post('/platform/user', json={
- 'username': f'u{int(time.time())}', 'email': f'u{int(time.time())}@ex.com',
- 'password': 'Strong!Passw0rd1234', 'role': 'viewer', 'groups': ['ALL'], 'ui_access': False
- })
+ return c.post(
+ '/platform/user',
+ json={
+ 'username': f'u{int(time.time())}',
+ 'email': f'u{int(time.time())}@ex.com',
+ 'password': 'Strong!Passw0rd1234',
+ 'role': 'viewer',
+ 'groups': ['ALL'],
+ 'ui_access': False,
+ },
+ )
def try_manage_groups(c):
- return c.post('/platform/group', json={'group_name': f'g{int(time.time())}', 'group_description': 'x'})
+ return c.post(
+ '/platform/group', json={'group_name': f'g{int(time.time())}', 'group_description': 'x'}
+ )
def try_manage_roles(c):
- return c.post('/platform/role', json={'role_name': f'r{int(time.time())}', 'role_description': 'x'})
+ return c.post(
+ '/platform/role', json={'role_name': f'r{int(time.time())}', 'role_description': 'x'}
+ )
matrix = [
('manage_apis', try_manage_apis, 'API007'),
@@ -74,19 +115,23 @@ def test_permission_matrix_block_then_allow(client):
]
for perm_field, attempt, expected_code in matrix:
- role_name = f"role_{perm_field}_{int(time.time())}"
- r = client.post('/platform/role', json={'role_name': role_name, 'role_description': 'matrix', perm_field: False})
+ role_name = f'role_{perm_field}_{int(time.time())}'
+ r = client.post(
+ '/platform/role',
+ json={'role_name': role_name, 'role_description': 'matrix', perm_field: False},
+ )
assert r.status_code in (200, 201), r.text
uname, email, pwd = _mk_user(client, role_name)
uc = _login(client, email, pwd)
resp = attempt(uc)
- assert resp.status_code == 403, f"{perm_field} should be blocked: {resp.text}"
- data = resp.json(); code = data.get('error_code') or (data.get('response') or {}).get('error_code')
+ assert resp.status_code == 403, f'{perm_field} should be blocked: {resp.text}'
+ data = resp.json()
+ code = data.get('error_code') or (data.get('response') or {}).get('error_code')
assert code == expected_code
client.put(f'/platform/role/{role_name}', json={perm_field: True})
resp2 = attempt(uc)
- assert resp2.status_code != 403, f"{perm_field} still blocked after enable: {resp2.text}"
+ assert resp2.status_code != 403, f'{perm_field} still blocked after enable: {resp2.text}'
client.delete(f'/platform/user/{uname}')
client.delete(f'/platform/role/{role_name}')
diff --git a/backend-services/live-tests/test_83_roles_groups_crud.py b/backend-services/live-tests/test_83_roles_groups_crud.py
index 8f0b7cd..09b3b38 100644
--- a/backend-services/live-tests/test_83_roles_groups_crud.py
+++ b/backend-services/live-tests/test_83_roles_groups_crud.py
@@ -1,12 +1,16 @@
import time
+
import pytest
pytestmark = [pytest.mark.security, pytest.mark.roles]
+
def test_roles_groups_crud_and_list(client):
- role = f"rolex-{int(time.time())}"
- group = f"groupx-{int(time.time())}"
- r = client.post('/platform/role', json={'role_name': role, 'role_description': 'x', 'manage_users': True})
+ role = f'rolex-{int(time.time())}'
+ group = f'groupx-{int(time.time())}'
+ r = client.post(
+ '/platform/role', json={'role_name': role, 'role_description': 'x', 'manage_users': True}
+ )
assert r.status_code in (200, 201)
r = client.post('/platform/group', json={'group_name': group, 'group_description': 'x'})
assert r.status_code in (200, 201)
diff --git a/backend-services/live-tests/test_85_endpoint_validation.py b/backend-services/live-tests/test_85_endpoint_validation.py
index c2429e8..04d0197 100644
--- a/backend-services/live-tests/test_85_endpoint_validation.py
+++ b/backend-services/live-tests/test_85_endpoint_validation.py
@@ -1,30 +1,38 @@
import time
+
from servers import start_rest_echo_server
+
def test_rest_endpoint_validation_blocks_invalid_payload(client):
srv = start_rest_echo_server()
try:
api_name = f'val-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'validation test',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'validation test',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/create',
- 'endpoint_description': 'create'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/create',
+ 'endpoint_description': 'create',
+ },
+ )
assert r.status_code in (200, 201)
r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/create')
@@ -36,17 +44,23 @@ def test_rest_endpoint_validation_blocks_invalid_payload(client):
schema = {
'validation_schema': {
'user.name': {'required': True, 'type': 'string', 'min': 2},
- 'user.age': {'required': True, 'type': 'number', 'min': 1}
+ 'user.age': {'required': True, 'type': 'number', 'min': 1},
}
}
- r = client.post('/platform/endpoint/endpoint/validation', json={
- 'endpoint_id': endpoint_id,
- 'validation_enabled': True,
- 'validation_schema': schema
- })
+ r = client.post(
+ '/platform/endpoint/endpoint/validation',
+ json={
+ 'endpoint_id': endpoint_id,
+ 'validation_enabled': True,
+ 'validation_schema': schema,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ r = client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
assert r.status_code in (200, 201)
r = client.post(f'/api/rest/{api_name}/{api_version}/create', json={'user': {'name': 'A'}})
@@ -55,7 +69,9 @@ def test_rest_endpoint_validation_blocks_invalid_payload(client):
err = body.get('error_code') or body.get('response', {}).get('error_code')
assert err == 'GTW011' or body.get('error_message')
- r = client.post(f'/api/rest/{api_name}/{api_version}/create', json={'user': {'name': 'Alan', 'age': 33}})
+ r = client.post(
+ f'/api/rest/{api_name}/{api_version}/create', json={'user': {'name': 'Alan', 'age': 33}}
+ )
assert r.status_code == 200
finally:
try:
@@ -67,5 +83,8 @@ def test_rest_endpoint_validation_blocks_invalid_payload(client):
except Exception:
pass
srv.stop()
+
+
import pytest
+
pytestmark = [pytest.mark.validation, pytest.mark.rest]
diff --git a/backend-services/live-tests/test_86_validation_edge_cases.py b/backend-services/live-tests/test_86_validation_edge_cases.py
index 21f7f09..1fb22ec 100644
--- a/backend-services/live-tests/test_86_validation_edge_cases.py
+++ b/backend-services/live-tests/test_86_validation_edge_cases.py
@@ -1,61 +1,88 @@
import time
-from servers import start_rest_echo_server
+
import pytest
+from servers import start_rest_echo_server
pytestmark = [pytest.mark.validation]
+
def test_nested_array_and_format_validations(client):
srv = start_rest_echo_server()
try:
api_name = f'valedge-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version,
- 'api_description': 'edge validations', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url], 'api_type': 'REST', 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version,
- 'endpoint_method': 'POST', 'endpoint_uri': '/submit', 'endpoint_description': 'submit'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'edge validations',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/submit',
+ 'endpoint_description': 'submit',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
r = client.get(f'/platform/endpoint/POST/{api_name}/{api_version}/submit')
- ep = r.json().get('response', r.json()); endpoint_id = ep.get('endpoint_id'); assert endpoint_id
+ ep = r.json().get('response', r.json())
+ endpoint_id = ep.get('endpoint_id')
+ assert endpoint_id
schema = {
'validation_schema': {
'user.email': {'required': True, 'type': 'string', 'format': 'email'},
'items': {
- 'required': True, 'type': 'array', 'min': 1,
+ 'required': True,
+ 'type': 'array',
+ 'min': 1,
'array_items': {
'type': 'object',
'nested_schema': {
'id': {'required': True, 'type': 'string', 'format': 'uuid'},
- 'quantity': {'required': True, 'type': 'number', 'min': 1}
- }
- }
- }
+ 'quantity': {'required': True, 'type': 'number', 'min': 1},
+ },
+ },
+ },
}
}
- r = client.post('/platform/endpoint/endpoint/validation', json={
- 'endpoint_id': endpoint_id, 'validation_enabled': True, 'validation_schema': schema
- })
+ r = client.post(
+ '/platform/endpoint/endpoint/validation',
+ json={
+ 'endpoint_id': endpoint_id,
+ 'validation_enabled': True,
+ 'validation_schema': schema,
+ },
+ )
if r.status_code == 422:
import pytest
+
pytest.skip('Validation schema shape not accepted by server (422)')
assert r.status_code in (200, 201)
- bad = {
- 'user': {'email': 'not-an-email'},
- 'items': [{'id': '123', 'quantity': 0}]
- }
+ bad = {'user': {'email': 'not-an-email'}, 'items': [{'id': '123', 'quantity': 0}]}
r = client.post(f'/api/rest/{api_name}/{api_version}/submit', json=bad)
assert r.status_code == 400
import uuid
+
ok = {
'user': {'email': 'u@example.com'},
- 'items': [{'id': str(uuid.uuid4()), 'quantity': 2}]
+ 'items': [{'id': str(uuid.uuid4()), 'quantity': 2}],
}
r = client.post(f'/api/rest/{api_name}/{api_version}/submit', json=ok)
assert r.status_code == 200
diff --git a/backend-services/live-tests/test_90_security_tools_logging.py b/backend-services/live-tests/test_90_security_tools_logging.py
index dd19477..9487d0d 100644
--- a/backend-services/live-tests/test_90_security_tools_logging.py
+++ b/backend-services/live-tests/test_90_security_tools_logging.py
@@ -10,21 +10,29 @@ def test_security_settings_get_put(client):
updated = r.json().get('response', r.json())
assert bool(updated.get('enable_auto_save') or False) == desired
+
def test_tools_cors_check(client):
- r = client.post('/platform/tools/cors/check', json={
- 'origin': 'http://localhost:3000',
- 'method': 'GET',
- 'request_headers': ['Content-Type']
- })
+ r = client.post(
+ '/platform/tools/cors/check',
+ json={
+ 'origin': 'http://localhost:3000',
+ 'method': 'GET',
+ 'request_headers': ['Content-Type'],
+ },
+ )
assert r.status_code == 200
payload = r.json().get('response', r.json())
assert 'config' in payload and 'preflight' in payload
+
def test_clear_all_caches(client):
r = client.delete('/api/caches')
assert r.status_code == 200
body = r.json().get('response', r.json())
- assert 'All caches cleared' in (body.get('message') or body.get('error_message') or 'All caches cleared')
+ assert 'All caches cleared' in (
+ body.get('message') or body.get('error_message') or 'All caches cleared'
+ )
+
def test_logging_endpoints(client):
r = client.get('/platform/logging/logs?limit=10')
@@ -36,5 +44,8 @@ def test_logging_endpoints(client):
assert r.status_code == 200
files = r.json().get('response', r.json())
assert 'count' in files
+
+
import pytest
+
pytestmark = [pytest.mark.security, pytest.mark.tools, pytest.mark.logging]
diff --git a/backend-services/live-tests/test_91_memory_dump_restore.py b/backend-services/live-tests/test_91_memory_dump_restore.py
index 84a052e..9fe3b6a 100644
--- a/backend-services/live-tests/test_91_memory_dump_restore.py
+++ b/backend-services/live-tests/test_91_memory_dump_restore.py
@@ -1,10 +1,14 @@
import os
+
import pytest
pytestmark = [pytest.mark.security]
+
def test_memory_dump_restore_conditionally(client):
- mem_mode = os.environ.get('MEM_OR_EXTERNAL', os.environ.get('MEM_OR_REDIS', 'MEM')).upper() == 'MEM'
+ mem_mode = (
+ os.environ.get('MEM_OR_EXTERNAL', os.environ.get('MEM_OR_REDIS', 'MEM')).upper() == 'MEM'
+ )
key = os.environ.get('MEM_ENCRYPTION_KEY')
if not mem_mode or not key:
pytest.skip('Memory dump/restore only in memory mode with MEM_ENCRYPTION_KEY set')
diff --git a/backend-services/live-tests/test_92_authorization_flows.py b/backend-services/live-tests/test_92_authorization_flows.py
index 187d2ab..19e571c 100644
--- a/backend-services/live-tests/test_92_authorization_flows.py
+++ b/backend-services/live-tests/test_92_authorization_flows.py
@@ -10,24 +10,32 @@ def test_token_refresh_and_invalidate(client):
assert r.status_code == 401
from config import ADMIN_EMAIL, ADMIN_PASSWORD
+
client.login(ADMIN_EMAIL, ADMIN_PASSWORD)
+
def test_admin_revoke_tokens_for_user(client):
- import time, random, string
+ import random
+ import time
+
ts = int(time.time())
- uname = f'revoke_{ts}_{random.randint(1000,9999)}'
+ uname = f'revoke_{ts}_{random.randint(1000, 9999)}'
email = f'{uname}@example.com'
pwd = 'Strong!Passw0rd1234'
- r = client.post('/platform/user', json={
- 'username': uname,
- 'email': email,
- 'password': pwd,
- 'role': 'admin',
- 'groups': ['ALL'],
- 'ui_access': True
- })
+ r = client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': email,
+ 'password': pwd,
+ 'role': 'admin',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
from client import LiveClient
+
user_client = LiveClient(client.base_url)
user_client.login(email, pwd)
r = client.post(f'/platform/authorization/admin/revoke/{uname}', json={})
@@ -35,5 +43,8 @@ def test_admin_revoke_tokens_for_user(client):
r = user_client.get('/platform/user/me')
assert r.status_code == 401
client.delete(f'/platform/user/{uname}')
+
+
import pytest
+
pytestmark = [pytest.mark.auth]
diff --git a/backend-services/live-tests/test_93_public_and_auth_optional.py b/backend-services/live-tests/test_93_public_and_auth_optional.py
index 9327a41..6a296c6 100644
--- a/backend-services/live-tests/test_93_public_and_auth_optional.py
+++ b/backend-services/live-tests/test_93_public_and_auth_optional.py
@@ -1,31 +1,39 @@
import time
-from servers import start_rest_echo_server
+
import requests
+from servers import start_rest_echo_server
+
def test_public_api_no_auth_required(client):
srv = start_rest_echo_server()
try:
api_name = f'public-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'public',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'public',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': True,
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/status',
- 'endpoint_description': 'status'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/status',
+ 'endpoint_description': 'status',
+ },
+ )
assert r.status_code in (200, 201)
s = requests.Session()
url = client.base_url.rstrip('/') + f'/api/rest/{api_name}/{api_version}/status'
@@ -42,33 +50,41 @@ def test_public_api_no_auth_required(client):
pass
srv.stop()
+
def test_auth_not_required_but_not_public_allows_unauthenticated(client):
srv = start_rest_echo_server()
try:
api_name = f'authopt-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'auth optional',
- 'api_allowed_roles': [],
- 'api_allowed_groups': [],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True,
- 'api_public': False,
- 'api_auth_required': False
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'auth optional',
+ 'api_allowed_roles': [],
+ 'api_allowed_groups': [],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_public': False,
+ 'api_auth_required': False,
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/ping',
- 'endpoint_description': 'ping'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/ping',
+ 'endpoint_description': 'ping',
+ },
+ )
assert r.status_code in (200, 201)
import requests
+
s = requests.Session()
url = client.base_url.rstrip('/') + f'/api/rest/{api_name}/{api_version}/ping'
r = s.get(url)
@@ -83,5 +99,8 @@ def test_auth_not_required_but_not_public_allows_unauthenticated(client):
except Exception:
pass
srv.stop()
+
+
import pytest
+
pytestmark = [pytest.mark.rest, pytest.mark.auth]
diff --git a/backend-services/live-tests/test_94_routing_and_header_swap.py b/backend-services/live-tests/test_94_routing_and_header_swap.py
index c17b4dc..f950dff 100644
--- a/backend-services/live-tests/test_94_routing_and_header_swap.py
+++ b/backend-services/live-tests/test_94_routing_and_header_swap.py
@@ -1,6 +1,8 @@
import time
+
from servers import start_rest_echo_server
+
def test_client_routing_overrides_api_servers(client):
srv_a = start_rest_echo_server()
srv_b = start_rest_echo_server()
@@ -9,38 +11,52 @@ def test_client_routing_overrides_api_servers(client):
api_version = 'v1'
client_key = f'ck-{int(time.time())}'
- r = client.post('/platform/routing', json={
- 'routing_name': 'test-routing',
- 'routing_servers': [srv_b.url],
- 'routing_description': 'test',
- 'client_key': client_key,
- 'server_index': 0
- })
+ r = client.post(
+ '/platform/routing',
+ json={
+ 'routing_name': 'test-routing',
+ 'routing_servers': [srv_b.url],
+ 'routing_description': 'test',
+ 'client_key': client_key,
+ 'server_index': 0,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'routing demo',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv_a.url],
- 'api_type': 'REST',
- 'active': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'routing demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv_a.url],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/where',
- 'endpoint_description': 'where'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/where',
+ 'endpoint_description': 'where',
+ },
+ )
assert r.status_code in (200, 201)
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
- r = client.get(f'/api/rest/{api_name}/{api_version}/where', headers={'client-key': client_key})
+ r = client.get(
+ f'/api/rest/{api_name}/{api_version}/where', headers={'client-key': client_key}
+ )
assert r.status_code == 200
data = r.json().get('response', r.json())
hdrs = {k.lower(): v for k, v in (data.get('headers') or {}).items()}
@@ -55,10 +71,12 @@ def test_client_routing_overrides_api_servers(client):
except Exception:
pass
try:
- client.delete(f"/platform/routing/{client_key}")
+ client.delete(f'/platform/routing/{client_key}')
except Exception:
pass
- srv_a.stop(); srv_b.stop()
+ srv_a.stop()
+ srv_b.stop()
+
def test_authorization_field_swap_sets_auth_header(client):
srv = start_rest_echo_server()
@@ -67,30 +85,41 @@ def test_authorization_field_swap_sets_auth_header(client):
api_version = 'v1'
swap_header = 'x-up-auth'
token_value = 'Bearer SHHH_TOKEN'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'auth swap',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': [srv.url],
- 'api_type': 'REST',
- 'active': True,
- 'api_allowed_headers': [swap_header],
- 'api_authorization_field_swap': swap_header
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'auth swap',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': [srv.url],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_allowed_headers': [swap_header],
+ 'api_authorization_field_swap': swap_header,
+ },
+ )
assert r.status_code in (200, 201)
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/secure',
- 'endpoint_description': 'secure'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/secure',
+ 'endpoint_description': 'secure',
+ },
+ )
assert r.status_code in (200, 201)
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
- r = client.get(f'/api/rest/{api_name}/{api_version}/secure', headers={swap_header: token_value})
+ r = client.get(
+ f'/api/rest/{api_name}/{api_version}/secure', headers={swap_header: token_value}
+ )
assert r.status_code == 200
data = r.json().get('response', r.json())
hdrs = {k.lower(): v for k, v in (data.get('headers') or {}).items()}
@@ -105,5 +134,8 @@ def test_authorization_field_swap_sets_auth_header(client):
except Exception:
pass
srv.stop()
+
+
import pytest
+
pytestmark = [pytest.mark.routing, pytest.mark.gateway]
diff --git a/backend-services/live-tests/test_95_api_export_import_roundtrip.py b/backend-services/live-tests/test_95_api_export_import_roundtrip.py
index 141c226..362f9a5 100644
--- a/backend-services/live-tests/test_95_api_export_import_roundtrip.py
+++ b/backend-services/live-tests/test_95_api_export_import_roundtrip.py
@@ -1,30 +1,38 @@
import time
+
def test_single_api_export_import_roundtrip(client):
api_name = f'cfg-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'cfg demo',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:9'],
- 'api_type': 'REST',
- 'active': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/x',
- 'endpoint_description': 'x'
- })
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'cfg demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/x',
+ 'endpoint_description': 'x',
+ },
+ )
r = client.get(f'/platform/config/export/apis?api_name={api_name}&api_version={api_version}')
assert r.status_code == 200
payload = r.json().get('response', r.json())
- exported_api = payload.get('api'); exported_eps = payload.get('endpoints')
+ exported_api = payload.get('api')
+ exported_eps = payload.get('endpoints')
assert exported_api and exported_api.get('api_name') == api_name
assert any(ep.get('endpoint_uri') == '/x' for ep in (exported_eps or []))
@@ -39,5 +47,8 @@ def test_single_api_export_import_roundtrip(client):
assert r.status_code == 200
r = client.get(f'/platform/endpoint/GET/{api_name}/{api_version}/x')
assert r.status_code == 200
+
+
import pytest
+
pytestmark = [pytest.mark.config]
diff --git a/backend-services/live-tests/test_95_chaos_backends.py b/backend-services/live-tests/test_95_chaos_backends.py
index 4d7e9a2..0ad7e88 100644
--- a/backend-services/live-tests/test_95_chaos_backends.py
+++ b/backend-services/live-tests/test_95_chaos_backends.py
@@ -1,12 +1,17 @@
import time
+
import pytest
+
@pytest.mark.order(-10)
def test_redis_outage_during_requests(client):
r = client.get('/platform/authorization/status')
assert r.status_code in (200, 204)
- r = client.post('/platform/tools/chaos/toggle', json={'backend': 'redis', 'enabled': True, 'duration_ms': 1500})
+ r = client.post(
+ '/platform/tools/chaos/toggle',
+ json={'backend': 'redis', 'enabled': True, 'duration_ms': 1500},
+ )
assert r.status_code == 200
t0 = time.time()
@@ -25,13 +30,17 @@ def test_redis_outage_during_requests(client):
data = js.get('response', js)
assert isinstance(data.get('error_budget_burn'), int)
+
@pytest.mark.order(-9)
def test_mongo_outage_during_requests(client):
- t0 = time.time()
+ time.time()
r0 = client.get('/platform/user/me')
assert r0.status_code in (200, 204)
- r = client.post('/platform/tools/chaos/toggle', json={'backend': 'mongo', 'enabled': True, 'duration_ms': 1500})
+ r = client.post(
+ '/platform/tools/chaos/toggle',
+ json={'backend': 'mongo', 'enabled': True, 'duration_ms': 1500},
+ )
assert r.status_code == 200
t1 = time.time()
@@ -49,4 +58,3 @@ def test_mongo_outage_during_requests(client):
js = s.json()
data = js.get('response', js)
assert isinstance(data.get('error_budget_burn'), int)
-
diff --git a/backend-services/live-tests/test_96_config_import_export.py b/backend-services/live-tests/test_96_config_import_export.py
index 7cf43cd..57da1f0 100644
--- a/backend-services/live-tests/test_96_config_import_export.py
+++ b/backend-services/live-tests/test_96_config_import_export.py
@@ -7,5 +7,8 @@ def test_config_export_import_roundtrip(client):
assert r.status_code == 200
data = r.json().get('response', r.json())
assert 'imported' in data
+
+
import pytest
+
pytestmark = [pytest.mark.config]
diff --git a/backend-services/live-tests/test_97_monitor_and_readiness.py b/backend-services/live-tests/test_97_monitor_and_readiness.py
index 6d5b980..d4a3e95 100644
--- a/backend-services/live-tests/test_97_monitor_and_readiness.py
+++ b/backend-services/live-tests/test_97_monitor_and_readiness.py
@@ -12,5 +12,8 @@ def test_monitor_endpoints(client):
assert r.status_code == 200
metrics = r.json().get('response', r.json())
assert isinstance(metrics, dict)
+
+
import pytest
+
pytestmark = [pytest.mark.monitor]
diff --git a/backend-services/live-tests/test_98_cors_preflight.py b/backend-services/live-tests/test_98_cors_preflight.py
index 8ae4ebf..f7a2a8e 100644
--- a/backend-services/live-tests/test_98_cors_preflight.py
+++ b/backend-services/live-tests/test_98_cors_preflight.py
@@ -1,38 +1,51 @@
def test_api_cors_preflight_and_response_headers(client):
import time
+
api_name = f'cors-{int(time.time())}'
api_version = 'v1'
- r = client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'api_description': 'cors test',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:9'],
- 'api_type': 'REST',
- 'active': True,
- 'api_cors_allow_origins': ['http://example.com'],
- 'api_cors_allow_methods': ['GET','POST'],
- 'api_cors_allow_headers': ['Content-Type','X-CSRF-Token'],
- 'api_cors_allow_credentials': True
- })
+ r = client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'cors test',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_cors_allow_origins': ['http://example.com'],
+ 'api_cors_allow_methods': ['GET', 'POST'],
+ 'api_cors_allow_headers': ['Content-Type', 'X-CSRF-Token'],
+ 'api_cors_allow_credentials': True,
+ },
+ )
assert r.status_code in (200, 201), r.text
- r = client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/ok',
- 'endpoint_description': 'ok'
- })
+ r = client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/ok',
+ 'endpoint_description': 'ok',
+ },
+ )
assert r.status_code in (200, 201)
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
path = f'/api/rest/{api_name}/{api_version}/ok'
- r = client.options(path, headers={
- 'Origin': 'http://example.com',
- 'Access-Control-Request-Method': 'GET',
- 'Access-Control-Request-Headers': 'Content-Type'
- })
+ r = client.options(
+ path,
+ headers={
+ 'Origin': 'http://example.com',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'Content-Type',
+ },
+ )
assert r.status_code in (200, 204)
acao = r.headers.get('Access-Control-Allow-Origin')
assert acao in (None, 'http://example.com') or True
@@ -43,5 +56,8 @@ def test_api_cors_preflight_and_response_headers(client):
client.delete(f'/platform/endpoint/GET/{api_name}/{api_version}/ok')
client.delete(f'/platform/api/{api_name}/{api_version}')
+
+
import pytest
+
pytestmark = [pytest.mark.cors]
diff --git a/backend-services/live-tests/test_99_cors_matrices.py b/backend-services/live-tests/test_99_cors_matrices.py
index 1b68afe..3aa69f8 100644
--- a/backend-services/live-tests/test_99_cors_matrices.py
+++ b/backend-services/live-tests/test_99_cors_matrices.py
@@ -1,58 +1,116 @@
import time
+
import pytest
pytestmark = [pytest.mark.cors]
+
def test_cors_wildcard_with_credentials_true_sets_origin(client):
api_name = f'corsw-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'cors wild',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True,
- 'api_cors_allow_origins': ['*'], 'api_cors_allow_methods': ['GET','OPTIONS'], 'api_cors_allow_headers': ['Content-Type'],
- 'api_cors_allow_credentials': True
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/c', 'endpoint_description': 'c'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'cors wild',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_cors_allow_origins': ['*'],
+ 'api_cors_allow_methods': ['GET', 'OPTIONS'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_cors_allow_credentials': True,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/c',
+ 'endpoint_description': 'c',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
path = f'/api/rest/{api_name}/{api_version}/c'
- r = client.options(path, headers={
- 'Origin': 'http://foo.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'Content-Type'
- })
+ r = client.options(
+ path,
+ headers={
+ 'Origin': 'http://foo.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'Content-Type',
+ },
+ )
assert r.status_code in (200, 204)
assert r.headers.get('Access-Control-Allow-Origin') in (None, 'http://foo.example') or True
client.delete(f'/platform/endpoint/GET/{api_name}/{api_version}/c')
client.delete(f'/platform/api/{api_name}/{api_version}')
+
def test_cors_specific_origin_and_headers(client):
api_name = f'corss-{int(time.time())}'
api_version = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name, 'api_version': api_version, 'api_description': 'cors spec',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://127.0.0.1:9'], 'api_type': 'REST', 'active': True,
- 'api_cors_allow_origins': ['http://ok.example'], 'api_cors_allow_methods': ['GET','POST','OPTIONS'],
- 'api_cors_allow_headers': ['Content-Type','X-CSRF-Token'], 'api_cors_allow_credentials': False
- })
- client.post('/platform/endpoint', json={
- 'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'GET', 'endpoint_uri': '/d', 'endpoint_description': 'd'
- })
- client.post('/platform/subscription/subscribe', json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'api_description': 'cors spec',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'active': True,
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET', 'POST', 'OPTIONS'],
+ 'api_cors_allow_headers': ['Content-Type', 'X-CSRF-Token'],
+ 'api_cors_allow_credentials': False,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/d',
+ 'endpoint_description': 'd',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': api_name, 'api_version': api_version, 'username': 'admin'},
+ )
path = f'/api/rest/{api_name}/{api_version}/d'
- r = client.options(path, headers={
- 'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token'
- })
+ r = client.options(
+ path,
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-CSRF-Token',
+ },
+ )
assert r.status_code in (200, 204)
assert r.headers.get('Access-Control-Allow-Origin') in (None, 'http://ok.example') or True
- r = client.options(path, headers={
- 'Origin': 'http://bad.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token'
- })
+ r = client.options(
+ path,
+ headers={
+ 'Origin': 'http://bad.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-CSRF-Token',
+ },
+ )
assert r.status_code in (200, 204)
client.delete(f'/platform/endpoint/GET/{api_name}/{api_version}/d')
diff --git a/backend-services/live-tests/test_api_cors_headers_matrix_live.py b/backend-services/live-tests/test_api_cors_headers_matrix_live.py
index afe2c37..5032069 100644
--- a/backend-services/live-tests/test_api_cors_headers_matrix_live.py
+++ b/backend-services/live-tests/test_api_cors_headers_matrix_live.py
@@ -1,32 +1,59 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
def test_api_cors_allow_origins_allow_methods_headers_credentials_expose_live(client):
import time
+
api_name = f'corslive-{int(time.time())}'
ver = 'v1'
- client.post('/platform/api', json={
- 'api_name': api_name,
- 'api_version': ver,
- 'api_description': 'cors live',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://upstream.example'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET','POST'],
- 'api_cors_allow_headers': ['Content-Type','X-CSRF-Token'],
- 'api_cors_allow_credentials': True,
- 'api_cors_expose_headers': ['X-Resp-Id'],
- })
- client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': ver, 'endpoint_method': 'GET', 'endpoint_uri': '/q', 'endpoint_description': 'q'})
- client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': api_name, 'api_version': ver})
- r = client.options(f'/api/rest/{api_name}/{ver}/q', headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token'})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': api_name,
+ 'api_version': ver,
+ 'api_description': 'cors live',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream.example'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET', 'POST'],
+ 'api_cors_allow_headers': ['Content-Type', 'X-CSRF-Token'],
+ 'api_cors_allow_credentials': True,
+ 'api_cors_expose_headers': ['X-Resp-Id'],
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/q',
+ 'endpoint_description': 'q',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'username': 'admin', 'api_name': api_name, 'api_version': ver},
+ )
+ r = client.options(
+ f'/api/rest/{api_name}/{ver}/q',
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-CSRF-Token',
+ },
+ )
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example'
assert 'GET' in (r.headers.get('Access-Control-Allow-Methods') or '')
diff --git a/backend-services/live-tests/test_bandwidth_limit_live.py b/backend-services/live-tests/test_bandwidth_limit_live.py
index 4c31883..2d91c44 100644
--- a/backend-services/live-tests/test_bandwidth_limit_live.py
+++ b/backend-services/live-tests/test_bandwidth_limit_live.py
@@ -1,30 +1,57 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
def test_bandwidth_limit_enforced_and_window_resets_live(client):
name, ver = 'bwlive', 'v1'
- client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'bw live',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up.example'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- client.post('/platform/endpoint', json={'api_name': name, 'api_version': ver, 'endpoint_method': 'GET', 'endpoint_uri': '/p', 'endpoint_description': 'p'})
- client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver})
- client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 1, 'bandwidth_limit_window': 'second', 'bandwidth_limit_enabled': True})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'bw live',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.example'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/p',
+ 'endpoint_description': 'p',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'username': 'admin', 'api_name': name, 'api_version': ver},
+ )
+ client.put(
+ '/platform/user/admin',
+ json={
+ 'bandwidth_limit_bytes': 1,
+ 'bandwidth_limit_window': 'second',
+ 'bandwidth_limit_enabled': True,
+ },
+ )
client.delete('/api/caches')
r1 = client.get(f'/api/rest/{name}/{ver}/p')
r2 = client.get(f'/api/rest/{name}/{ver}/p')
assert r1.status_code == 200 and r2.status_code == 429
import time
+
time.sleep(1.1)
r3 = client.get(f'/api/rest/{name}/{ver}/p')
assert r3.status_code == 200
diff --git a/backend-services/live-tests/test_graphql_fallback_live.py b/backend-services/live-tests/test_graphql_fallback_live.py
index 8beac52..24df5bb 100644
--- a/backend-services/live-tests/test_graphql_fallback_live.py
+++ b/backend-services/live-tests/test_graphql_fallback_live.py
@@ -1,81 +1,123 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
async def _setup(client, name='gllive', ver='v1'):
- await client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://gql.up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/graphql',
- 'endpoint_description': 'gql'
- })
+ await client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://gql.up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'gql',
+ },
+ )
return name, ver
+
@pytest.mark.asyncio
async def test_graphql_client_fallback_to_httpx_live(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = await _setup(authed_client, name='gll1')
+
class Dummy:
pass
+
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
+
class H:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'ok': True})
+
monkeypatch.setattr(gs, 'Client', Dummy)
monkeypatch.setattr(gs.httpx, 'AsyncClient', H)
- r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ ping }', 'variables': {}})
+ r = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ ping }', 'variables': {}},
+ )
assert r.status_code == 200 and r.json().get('ok') is True
+
@pytest.mark.asyncio
async def test_graphql_errors_live_strict_and_loose(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = await _setup(authed_client, name='gll2')
+
class Dummy:
pass
+
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
+
class H:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'errors': [{'message': 'boom'}]})
+
monkeypatch.setattr(gs, 'Client', Dummy)
monkeypatch.setattr(gs.httpx, 'AsyncClient', H)
monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False)
- r1 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ err }', 'variables': {}})
+ r1 = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ err }', 'variables': {}},
+ )
assert r1.status_code == 200 and isinstance(r1.json().get('errors'), list)
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
- r2 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ err }', 'variables': {}})
+ r2 = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ err }', 'variables': {}},
+ )
assert r2.status_code == 200 and r2.json().get('status_code') == 200
diff --git a/backend-services/live-tests/test_grpc_pkg_override_live.py b/backend-services/live-tests/test_grpc_pkg_override_live.py
index c7813ef..278dbc7 100644
--- a/backend-services/live-tests/test_grpc_pkg_override_live.py
+++ b/backend-services/live-tests/test_grpc_pkg_override_live.py
@@ -1,24 +1,33 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
def _fake_pb2_module(method_name='M'):
class Req:
pass
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
+
def __init__(self, ok=True):
self.ok = ok
+
@staticmethod
def FromString(b):
return Reply(True)
- setattr(Req, '__name__', f'{method_name}Request')
- setattr(Reply, '__name__', f'{method_name}Reply')
+
+ Req.__name__ = f'{method_name}Request'
+ Reply.__name__ = f'{method_name}Reply'
return Req, Reply
+
def _make_import_module_recorder(record, pb2_map):
def _imp(name):
record.append(name)
@@ -27,31 +36,47 @@ def _make_import_module_recorder(record, pb2_map):
mapping = pb2_map.get(name)
if mapping is None:
req_cls, rep_cls = _fake_pb2_module('M')
- setattr(mod, 'MRequest', req_cls)
- setattr(mod, 'MReply', rep_cls)
+ mod.MRequest = req_cls
+ mod.MReply = rep_cls
else:
req_cls, rep_cls = mapping
if req_cls:
- setattr(mod, 'MRequest', req_cls)
+ mod.MRequest = req_cls
if rep_cls:
- setattr(mod, 'MReply', rep_cls)
+ mod.MReply = rep_cls
return mod
if name.endswith('_pb2_grpc'):
+
class Stub:
def __init__(self, ch):
self._ch = ch
+
async def M(self, req):
- return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
+ return type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type(
+ 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]}
+ )(),
+ 'ok': True,
+ },
+ )()
+
mod = type('SVC', (), {'SvcStub': Stub})
return mod
raise ImportError(name)
+
return _imp
+
def _make_fake_grpc_unary(sequence_codes, grpc_mod):
counter = {'i': 0}
+
class AioChan:
async def channel_ready(self):
return True
+
class Chan(AioChan):
def unary_unary(self, method, request_serializer=None, response_deserializer=None):
async def _call(req):
@@ -59,143 +84,214 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod):
code = sequence_codes[idx]
counter['i'] += 1
if code is None:
- return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
+ return type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type(
+ 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]}
+ )(),
+ 'ok': True,
+ },
+ )()
+
class E(Exception):
def code(self):
return code
+
def details(self):
return 'err'
+
raise E()
+
return _call
+
class aio:
@staticmethod
def insecure_channel(url):
return Chan()
+
fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception})
return fake
+
@pytest.mark.asyncio
async def test_grpc_with_api_grpc_package_config(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gplive1', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'g',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['grpc://127.0.0.1:9'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_grpc_package': 'api.pkg'
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'g',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['grpc://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_grpc_package': 'api.pkg',
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
record = []
req_cls, rep_cls = _fake_pb2_module('M')
pb2_map = {'api.pkg_pb2': (req_cls, rep_cls)}
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'})
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'},
+ )
assert r.status_code == 200
assert any(n == 'api.pkg_pb2' for n in record)
+
@pytest.mark.asyncio
async def test_grpc_with_request_package_override(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gplive2', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'g',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['grpc://127.0.0.1:9'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'g',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['grpc://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
record = []
req_cls, rep_cls = _fake_pb2_module('M')
pb2_map = {'req.pkg_pb2': (req_cls, rep_cls)}
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'})
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'},
+ )
assert r.status_code == 200
assert any(n == 'req.pkg_pb2' for n in record)
+
@pytest.mark.asyncio
async def test_grpc_without_package_server_uses_fallback_path(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gplive3', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'g',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['grpc://127.0.0.1:9'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'g',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['grpc://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
pb2_map = {default_pkg: (req_cls, rep_cls)}
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}})
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
+ )
assert r.status_code == 200
assert any(n.endswith(default_pkg) for n in record)
+
@pytest.mark.asyncio
async def test_grpc_unavailable_then_success_with_retry_live(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gplive4', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'g',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['grpc://127.0.0.1:9'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 1,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'g',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['grpc://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 1,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
pb2_map = {default_pkg: (req_cls, rep_cls)}
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, None], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}})
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
+ )
assert r.status_code == 200
diff --git a/backend-services/live-tests/test_memory_dump_sigusr1_live.py b/backend-services/live-tests/test_memory_dump_sigusr1_live.py
index f992f3c..716a8bb 100644
--- a/backend-services/live-tests/test_memory_dump_sigusr1_live.py
+++ b/backend-services/live-tests/test_memory_dump_sigusr1_live.py
@@ -1,16 +1,22 @@
-import pytest
import os
import platform
+import pytest
+
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
@pytest.mark.skipif(platform.system() == 'Windows', reason='SIGUSR1 not available on Windows')
def test_sigusr1_dump_in_memory_mode_live(client, monkeypatch, tmp_path):
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'live-secret-xyz')
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'live' / 'memory_dump.bin'))
- import signal, time
+ import signal
+ import time
+
os.kill(os.getpid(), signal.SIGUSR1)
time.sleep(0.5)
assert True
diff --git a/backend-services/live-tests/test_platform_cors_env_edges_live.py b/backend-services/live-tests/test_platform_cors_env_edges_live.py
index 454b6e0..78497df 100644
--- a/backend-services/live-tests/test_platform_cors_env_edges_live.py
+++ b/backend-services/live-tests/test_platform_cors_env_edges_live.py
@@ -1,22 +1,41 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
def test_platform_cors_strict_wildcard_credentials_edges_live(client, monkeypatch):
monkeypatch.setenv('ALLOWED_ORIGINS', '*')
monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
monkeypatch.setenv('CORS_STRICT', 'true')
- r = client.options('/platform/api', headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'})
+ r = client.options(
+ '/platform/api',
+ headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'},
+ )
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') in (None, '')
+
def test_platform_cors_methods_headers_defaults_live(client, monkeypatch):
monkeypatch.setenv('ALLOW_METHODS', '')
monkeypatch.setenv('ALLOW_HEADERS', '*')
- r = client.options('/platform/api', headers={'Origin': 'http://localhost:3000', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-Rand'})
+ r = client.options(
+ '/platform/api',
+ headers={
+ 'Origin': 'http://localhost:3000',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-Rand',
+ },
+ )
assert r.status_code == 204
- methods = [m.strip() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()]
- assert set(methods) == {'GET','POST','PUT','DELETE','OPTIONS','PATCH','HEAD'}
+ methods = [
+ m.strip()
+ for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',')
+ if m.strip()
+ ]
+ assert set(methods) == {'GET', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH', 'HEAD'}
diff --git a/backend-services/live-tests/test_rest_header_forwarding_live.py b/backend-services/live-tests/test_rest_header_forwarding_live.py
index 81e55b9..bf52a07 100644
--- a/backend-services/live-tests/test_rest_header_forwarding_live.py
+++ b/backend-services/live-tests/test_rest_header_forwarding_live.py
@@ -1,14 +1,20 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
@pytest.mark.asyncio
async def test_forward_allowed_headers_only(monkeypatch, authed_client):
- from conftest import create_api, create_endpoint, subscribe_self
+ from conftest import create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'hforw', 'v1'
payload = {
'api_name': name,
@@ -19,7 +25,7 @@ async def test_forward_allowed_headers_only(monkeypatch, authed_client):
'api_servers': ['http://up'],
'api_type': 'REST',
'api_allowed_retry_count': 0,
- 'api_allowed_headers': ['x-allowed', 'content-type']
+ 'api_allowed_headers': ['x-allowed', 'content-type'],
}
await authed_client.post('/platform/api', json=payload)
await create_endpoint(authed_client, name, ver, 'GET', '/p')
@@ -31,26 +37,37 @@ async def test_forward_allowed_headers_only(monkeypatch, authed_client):
self._p = {'ok': True}
self.headers = {'Content-Type': 'application/json'}
self.text = ''
+
def json(self):
return self._p
+
captured = {}
+
class CapClient:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def get(self, url, params=None, headers=None):
captured['headers'] = headers or {}
return Resp()
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', CapClient)
- await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={'X-Allowed': 'yes', 'X-Blocked': 'no'})
+ await authed_client.get(
+ f'/api/rest/{name}/{ver}/p', headers={'X-Allowed': 'yes', 'X-Blocked': 'no'}
+ )
ch = {k.lower(): v for k, v in (captured.get('headers') or {}).items()}
assert 'x-allowed' in ch and 'x-blocked' not in ch
+
@pytest.mark.asyncio
async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client):
- from conftest import create_api, create_endpoint, subscribe_self
+ from conftest import create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'hresp', 'v1'
payload = {
'api_name': name,
@@ -61,7 +78,7 @@ async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client
'api_servers': ['http://up'],
'api_type': 'REST',
'api_allowed_retry_count': 0,
- 'api_allowed_headers': ['x-upstream']
+ 'api_allowed_headers': ['x-upstream'],
}
await authed_client.post('/platform/api', json=payload)
await create_endpoint(authed_client, name, ver, 'GET', '/p')
@@ -71,17 +88,26 @@ async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client
def __init__(self):
self.status_code = 200
self._p = {'ok': True}
- self.headers = {'Content-Type': 'application/json', 'X-Upstream': 'yes', 'X-Secret': 'no'}
+ self.headers = {
+ 'Content-Type': 'application/json',
+ 'X-Upstream': 'yes',
+ 'X-Secret': 'no',
+ }
self.text = ''
+
def json(self):
return self._p
+
class HC:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def get(self, url, params=None, headers=None):
return Resp()
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', HC)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
diff --git a/backend-services/live-tests/test_rest_retries_live.py b/backend-services/live-tests/test_rest_retries_live.py
index d390a0b..18f135a 100644
--- a/backend-services/live-tests/test_rest_retries_live.py
+++ b/backend-services/live-tests/test_rest_retries_live.py
@@ -1,23 +1,30 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
-from tests.test_gateway_routing_limits import _FakeAsyncClient
@pytest.mark.asyncio
async def test_rest_retries_on_500_then_success(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'rlive500', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/r')
await subscribe_self(authed_client, name, ver)
from utils.database import api_collection
- api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}})
+
+ api_collection.update_one(
+ {'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}}
+ )
await authed_client.delete('/api/caches')
class Resp:
@@ -26,30 +33,42 @@ async def test_rest_retries_on_500_then_success(monkeypatch, authed_client):
self._json = body or {}
self.text = ''
self.headers = headers or {'Content-Type': 'application/json'}
+
def json(self):
return self._json
+
seq = [Resp(500), Resp(200, {'ok': True})]
+
class SeqClient:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def get(self, url, params=None, headers=None):
return seq.pop(0)
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', SeqClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/r')
assert r.status_code == 200 and r.json().get('ok') is True
+
@pytest.mark.asyncio
async def test_rest_retries_on_503_then_success(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'rlive503', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/r')
await subscribe_self(authed_client, name, ver)
from utils.database import api_collection
- api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}})
+
+ api_collection.update_one(
+ {'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}}
+ )
await authed_client.delete('/api/caches')
class Resp:
@@ -57,43 +76,58 @@ async def test_rest_retries_on_503_then_success(monkeypatch, authed_client):
self.status_code = status
self.headers = {'Content-Type': 'application/json'}
self.text = ''
+
def json(self):
return {}
+
seq = [Resp(503), Resp(200)]
+
class SeqClient:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def get(self, url, params=None, headers=None):
return seq.pop(0)
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', SeqClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/r')
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'rlivez0', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/r')
await subscribe_self(authed_client, name, ver)
await authed_client.delete('/api/caches')
+
class Resp:
def __init__(self, status):
self.status_code = status
self.headers = {'Content-Type': 'application/json'}
self.text = ''
+
def json(self):
return {}
+
class OneClient:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def get(self, url, params=None, headers=None):
return Resp(500)
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', OneClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/r')
assert r.status_code == 500
diff --git a/backend-services/live-tests/test_soap_content_type_and_retries_live.py b/backend-services/live-tests/test_soap_content_type_and_retries_live.py
index ba97b98..2c387b3 100644
--- a/backend-services/live-tests/test_soap_content_type_and_retries_live.py
+++ b/backend-services/live-tests/test_soap_content_type_and_retries_live.py
@@ -1,14 +1,20 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
@pytest.mark.asyncio
async def test_soap_content_types_matrix(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'soapct', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/s')
@@ -19,30 +25,43 @@ async def test_soap_content_types_matrix(monkeypatch, authed_client):
self.status_code = 200
self.headers = {'Content-Type': 'application/xml'}
self.text = ''
+
def json(self):
return {'ok': True}
+
class HC:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, params=None, headers=None, content=None):
return Resp()
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', HC)
for ct in ['application/xml', 'text/xml']:
- r = await authed_client.post(f'/api/soap/{name}/{ver}/s', headers={'Content-Type': ct}, content='')
+ r = await authed_client.post(
+ f'/api/soap/{name}/{ver}/s', headers={'Content-Type': ct}, content=''
+ )
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_soap_retries_then_success(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
import services.gateway_service as gs
+
name, ver = 'soaprt', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/s')
await subscribe_self(authed_client, name, ver)
from utils.database import api_collection
- api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}})
+
+ api_collection.update_one(
+ {'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}}
+ )
await authed_client.delete('/api/caches')
class Resp:
@@ -50,16 +69,24 @@ async def test_soap_retries_then_success(monkeypatch, authed_client):
self.status_code = status
self.headers = {'Content-Type': 'application/xml'}
self.text = ''
+
def json(self):
return {'ok': True}
+
seq = [Resp(503), Resp(200)]
+
class HC:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, params=None, headers=None, content=None):
return seq.pop(0)
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', HC)
- r = await authed_client.post(f'/api/soap/{name}/{ver}/s', headers={'Content-Type': 'application/xml'}, content='')
+ r = await authed_client.post(
+ f'/api/soap/{name}/{ver}/s', headers={'Content-Type': 'application/xml'}, content=''
+ )
assert r.status_code == 200
diff --git a/backend-services/live-tests/test_throttle_queue_and_wait_live.py b/backend-services/live-tests/test_throttle_queue_and_wait_live.py
index 3d4454c..64bd7e3 100644
--- a/backend-services/live-tests/test_throttle_queue_and_wait_live.py
+++ b/backend-services/live-tests/test_throttle_queue_and_wait_live.py
@@ -1,68 +1,94 @@
import os
+
import pytest
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
- pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
+ pytestmark = pytest.mark.skip(
+ reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable'
+ )
+
def test_throttle_queue_limit_exceeded_429_live(client):
- from config import ADMIN_EMAIL
name, ver = 'throtq', 'v1'
- client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'live throttle',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up.example'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/t',
- 'endpoint_description': 't'
- })
- client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver})
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'live throttle',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.example'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/t',
+ 'endpoint_description': 't',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'username': 'admin', 'api_name': name, 'api_version': ver},
+ )
client.put('/platform/user/admin', json={'throttle_queue_limit': 1})
client.delete('/api/caches')
- r1 = client.get(f'/api/rest/{name}/{ver}/t')
+ client.get(f'/api/rest/{name}/{ver}/t')
r2 = client.get(f'/api/rest/{name}/{ver}/t')
assert r2.status_code == 429
+
def test_throttle_dynamic_wait_live(client):
name, ver = 'throtw', 'v1'
- client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'live throttle wait',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up.example'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/w',
- 'endpoint_description': 'w'
- })
- client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver})
- client.put('/platform/user/admin', json={
- 'throttle_duration': 1,
- 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 10,
- 'throttle_wait_duration': 0.1,
- 'throttle_wait_duration_type': 'second',
- 'rate_limit_duration': 1000,
- 'rate_limit_duration_type': 'second',
- })
+ client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'live throttle wait',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.example'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/w',
+ 'endpoint_description': 'w',
+ },
+ )
+ client.post(
+ '/platform/subscription/subscribe',
+ json={'username': 'admin', 'api_name': name, 'api_version': ver},
+ )
+ client.put(
+ '/platform/user/admin',
+ json={
+ 'throttle_duration': 1,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 10,
+ 'throttle_wait_duration': 0.1,
+ 'throttle_wait_duration_type': 'second',
+ 'rate_limit_duration': 1000,
+ 'rate_limit_duration_type': 'second',
+ },
+ )
client.delete('/api/caches')
import time
+
t0 = time.perf_counter()
r1 = client.get(f'/api/rest/{name}/{ver}/w')
t1 = time.perf_counter()
diff --git a/backend-services/middleware/analytics_middleware.py b/backend-services/middleware/analytics_middleware.py
index 734e884..e95798d 100644
--- a/backend-services/middleware/analytics_middleware.py
+++ b/backend-services/middleware/analytics_middleware.py
@@ -5,12 +5,13 @@ Automatically records detailed metrics for every request passing through
the gateway, including per-endpoint tracking and full performance data.
"""
-import time
import logging
+import time
+from collections.abc import Callable
+
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
-from typing import Callable
from utils.enhanced_metrics_util import enhanced_metrics_store
@@ -20,7 +21,7 @@ logger = logging.getLogger('doorman.analytics')
class AnalyticsMiddleware(BaseHTTPMiddleware):
"""
Middleware to capture comprehensive request/response metrics.
-
+
Records:
- Response time
- Status code
@@ -29,21 +30,21 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
- Endpoint URI and method
- Request/response sizes
"""
-
+
def __init__(self, app: ASGIApp):
super().__init__(app)
-
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""
Process request and record metrics.
"""
# Start timing
start_time = time.time()
-
+
# Extract request metadata
method = request.method
path = str(request.url.path)
-
+
# Estimate request size (headers + body)
request_size = 0
try:
@@ -54,16 +55,16 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
request_size += int(request.headers['content-length'])
except Exception:
pass
-
+
# Process request
response = await call_next(request)
-
+
# Calculate duration
duration_ms = (time.time() - start_time) * 1000
-
+
# Extract response metadata
status_code = response.status_code
-
+
# Estimate response size
response_size = 0
try:
@@ -74,18 +75,20 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
response_size += int(response.headers['content-length'])
except Exception:
pass
-
+
# Extract user from request state (set by auth middleware)
username = None
try:
if hasattr(request.state, 'user'):
- username = request.state.user.get('sub') if isinstance(request.state.user, dict) else None
+ username = (
+ request.state.user.get('sub') if isinstance(request.state.user, dict) else None
+ )
except Exception:
pass
-
+
# Parse API and endpoint from path
api_key, endpoint_uri = self._parse_api_endpoint(path)
-
+
# Record metrics
try:
enhanced_metrics_store.record(
@@ -96,17 +99,17 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
endpoint_uri=endpoint_uri,
method=method,
bytes_in=request_size,
- bytes_out=response_size
+ bytes_out=response_size,
)
except Exception as e:
- logger.error(f"Failed to record analytics: {str(e)}")
-
+ logger.error(f'Failed to record analytics: {str(e)}')
+
return response
-
+
def _parse_api_endpoint(self, path: str) -> tuple[str | None, str | None]:
"""
Parse API key and endpoint URI from request path.
-
+
Examples:
- /api/rest/customer/v1/users -> ("rest:customer", "/customer/v1/users")
- /platform/analytics/overview -> (None, "/platform/analytics/overview")
@@ -118,34 +121,34 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
parts = path.split('/')
if len(parts) >= 5:
api_name = parts[3]
- api_version = parts[4]
+ parts[4]
endpoint_uri = '/' + '/'.join(parts[3:])
- return f"rest:{api_name}", endpoint_uri
-
+ return f'rest:{api_name}', endpoint_uri
+
elif path.startswith('/api/graphql/'):
# GraphQL API
parts = path.split('/')
if len(parts) >= 4:
api_name = parts[3]
- return f"graphql:{api_name}", path
-
+ return f'graphql:{api_name}', path
+
elif path.startswith('/api/soap/'):
# SOAP API
parts = path.split('/')
if len(parts) >= 4:
api_name = parts[3]
- return f"soap:{api_name}", path
-
+ return f'soap:{api_name}', path
+
elif path.startswith('/api/grpc/'):
# gRPC API
parts = path.split('/')
if len(parts) >= 4:
api_name = parts[3]
- return f"grpc:{api_name}", path
-
+ return f'grpc:{api_name}', path
+
# Platform endpoints (not API requests)
return None, path
-
+
except Exception:
return None, path
@@ -153,8 +156,8 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
def setup_analytics_middleware(app):
"""
Add analytics middleware to FastAPI app.
-
+
Should be called during app initialization.
"""
app.add_middleware(AnalyticsMiddleware)
- logger.info("Analytics middleware initialized")
+ logger.info('Analytics middleware initialized')
diff --git a/backend-services/middleware/rate_limit_middleware.py b/backend-services/middleware/rate_limit_middleware.py
index 2d9668d..6202a6a 100644
--- a/backend-services/middleware/rate_limit_middleware.py
+++ b/backend-services/middleware/rate_limit_middleware.py
@@ -6,16 +6,16 @@ Checks rate limits and quotas, adds headers, and returns 429 when exceeded.
"""
import logging
-import time
-from typing import Optional, List, Callable
+from collections.abc import Callable
+
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
-from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow, TierLimits
-from utils.rate_limiter import get_rate_limiter, RateLimiter
-from utils.quota_tracker import get_quota_tracker, QuotaTracker, QuotaType
+from models.rate_limit_models import RateLimitRule, RuleType, TierLimits, TimeWindow
+from utils.quota_tracker import QuotaTracker, QuotaType, get_quota_tracker
+from utils.rate_limiter import RateLimiter, get_rate_limiter
logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Middleware for rate limiting requests
-
+
Features:
- Applies rate limit rules based on user, API, endpoint, IP
- Checks quotas (monthly, daily)
@@ -31,18 +31,18 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
- Returns 429 Too Many Requests when limits exceeded
- Supports tier-based limits
"""
-
+
def __init__(
self,
app: ASGIApp,
- rate_limiter: Optional[RateLimiter] = None,
- quota_tracker: Optional[QuotaTracker] = None,
- get_rules_func: Optional[Callable] = None,
- get_user_tier_func: Optional[Callable] = None
+ rate_limiter: RateLimiter | None = None,
+ quota_tracker: QuotaTracker | None = None,
+ get_rules_func: Callable | None = None,
+ get_user_tier_func: Callable | None = None,
):
"""
Initialize rate limit middleware
-
+
Args:
app: FastAPI application
rate_limiter: Rate limiter instance
@@ -55,134 +55,134 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
self.quota_tracker = quota_tracker or get_quota_tracker()
self.get_rules_func = get_rules_func or self._default_get_rules
self.get_user_tier_func = get_user_tier_func or self._default_get_user_tier
-
+
async def dispatch(self, request: Request, call_next):
"""
Process request through rate limiting
-
+
Args:
request: Incoming request
call_next: Next middleware/handler
-
+
Returns:
Response (possibly 429 if rate limited)
"""
# Skip rate limiting for certain paths
if self._should_skip(request):
return await call_next(request)
-
+
# Extract identifiers
user_id = self._get_user_id(request)
api_name = self._get_api_name(request)
endpoint_uri = str(request.url.path)
ip_address = self._get_client_ip(request)
-
+
# Get applicable rules
rules = await self.get_rules_func(request, user_id, api_name, endpoint_uri, ip_address)
-
+
# Check rate limits
for rule in rules:
identifier = self._get_identifier(rule, user_id, api_name, endpoint_uri, ip_address)
-
+
if identifier:
result = self.rate_limiter.check_rate_limit(rule, identifier)
-
+
if not result.allowed:
# Rate limit exceeded
return self._create_rate_limit_response(result, rule)
-
+
# Check quotas if user identified
if user_id:
tier_limits = await self.get_user_tier_func(user_id)
-
+
if tier_limits:
quota_result = await self._check_quotas(user_id, tier_limits)
-
+
if not quota_result.allowed:
return self._create_quota_exceeded_response(quota_result)
-
+
# Process request
response = await call_next(request)
-
+
# Add rate limit headers
if rules:
# Use first rule for headers (highest priority)
rule = rules[0]
identifier = self._get_identifier(rule, user_id, api_name, endpoint_uri, ip_address)
-
+
if identifier:
usage = self.rate_limiter.get_current_usage(rule, identifier)
self._add_rate_limit_headers(response, usage.limit, usage.remaining, usage.reset_at)
-
+
# Increment quota (async, don't block response)
if user_id:
try:
self.quota_tracker.increment_quota(user_id, QuotaType.REQUESTS, 1, 'month')
except Exception as e:
- logger.error(f"Error incrementing quota: {e}")
-
+ logger.error(f'Error incrementing quota: {e}')
+
return response
-
+
def _should_skip(self, request: Request) -> bool:
"""
Check if rate limiting should be skipped for this request
-
+
Args:
request: Incoming request
-
+
Returns:
True if should skip
"""
# Skip health checks, metrics, etc.
skip_paths = ['/health', '/metrics', '/docs', '/redoc', '/openapi.json']
-
+
return any(request.url.path.startswith(path) for path in skip_paths)
-
- def _get_user_id(self, request: Request) -> Optional[str]:
+
+ def _get_user_id(self, request: Request) -> str | None:
"""
Extract user ID from request
-
+
Args:
request: Incoming request
-
+
Returns:
User ID or None
"""
# Try to get from request state (set by auth middleware)
if hasattr(request.state, 'user'):
return getattr(request.state.user, 'username', None)
-
+
# Try to get from headers
return request.headers.get('X-User-ID')
-
- def _get_api_name(self, request: Request) -> Optional[str]:
+
+ def _get_api_name(self, request: Request) -> str | None:
"""
Extract API name from request
-
+
Args:
request: Incoming request
-
+
Returns:
API name or None
"""
# Try to get from request state (set by routing)
if hasattr(request.state, 'api_name'):
return request.state.api_name
-
+
# Try to extract from path
path_parts = request.url.path.strip('/').split('/')
if len(path_parts) > 0:
return path_parts[0]
-
+
return None
-
+
def _get_client_ip(self, request: Request) -> str:
"""
Extract client IP address from request
-
+
Args:
request: Incoming request
-
+
Returns:
IP address
"""
@@ -191,36 +191,36 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
if forwarded_for:
# Take first IP (original client)
return forwarded_for.split(',')[0].strip()
-
+
# Check X-Real-IP header
real_ip = request.headers.get('X-Real-IP')
if real_ip:
return real_ip
-
+
# Fall back to direct connection
if request.client:
return request.client.host
-
+
return 'unknown'
-
+
def _get_identifier(
self,
rule: RateLimitRule,
- user_id: Optional[str],
- api_name: Optional[str],
+ user_id: str | None,
+ api_name: str | None,
endpoint_uri: str,
- ip_address: str
- ) -> Optional[str]:
+ ip_address: str,
+ ) -> str | None:
"""
Get identifier for rate limit rule
-
+
Args:
rule: Rate limit rule
user_id: User ID
api_name: API name
endpoint_uri: Endpoint URI
ip_address: IP address
-
+
Returns:
Identifier string or None
"""
@@ -233,188 +233,183 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
elif rule.rule_type == RuleType.PER_IP:
return ip_address
elif rule.rule_type == RuleType.PER_USER_API:
- return f"{user_id}:{api_name}" if user_id and api_name else None
+ return f'{user_id}:{api_name}' if user_id and api_name else None
elif rule.rule_type == RuleType.PER_USER_ENDPOINT:
- return f"{user_id}:{endpoint_uri}" if user_id else None
+ return f'{user_id}:{endpoint_uri}' if user_id else None
elif rule.rule_type == RuleType.GLOBAL:
return 'global'
-
+
return None
-
+
async def _default_get_rules(
self,
request: Request,
- user_id: Optional[str],
- api_name: Optional[str],
+ user_id: str | None,
+ api_name: str | None,
endpoint_uri: str,
- ip_address: str
- ) -> List[RateLimitRule]:
+ ip_address: str,
+ ) -> list[RateLimitRule]:
"""
Default function to get applicable rules
-
+
Priority order:
1. If user has tier assigned → Use tier limits ONLY
2. If user has NO tier → Use per-user rate limit rules
3. Fall back to global rules
-
+
In production, this should query MongoDB for rules.
This is a placeholder that returns default rules.
-
+
Args:
request: Incoming request
user_id: User ID
api_name: API name
endpoint_uri: Endpoint URI
ip_address: IP address
-
+
Returns:
List of applicable rules
"""
# TODO: Query MongoDB for rules
# For now, return default rules
-
+
rules = []
-
+
# Check if user has a tier assigned
user_tier = None
if user_id:
user_tier = await self.get_user_tier_func(user_id)
-
+
if user_tier:
# User has tier → Use tier limits ONLY (priority)
# Convert tier limits to rate limit rules
if user_tier.requests_per_minute:
- rules.append(RateLimitRule(
- rule_id=f'tier_{user_id}',
- rule_type=RuleType.PER_USER,
- time_window=TimeWindow.MINUTE,
- limit=user_tier.requests_per_minute,
- burst_allowance=user_tier.burst_allowance or 0,
- priority=100, # Highest priority
- enabled=True,
- description=f"Tier-based limit for {user_id}"
- ))
+ rules.append(
+ RateLimitRule(
+ rule_id=f'tier_{user_id}',
+ rule_type=RuleType.PER_USER,
+ time_window=TimeWindow.MINUTE,
+ limit=user_tier.requests_per_minute,
+ burst_allowance=user_tier.burst_allowance or 0,
+ priority=100, # Highest priority
+ enabled=True,
+ description=f'Tier-based limit for {user_id}',
+ )
+ )
else:
# User has NO tier → Use per-user rate limit rules
if user_id:
# TODO: Query MongoDB for per-user rules
# For now, use default per-user rule
- rules.append(RateLimitRule(
- rule_id='default_per_user',
- rule_type=RuleType.PER_USER,
- time_window=TimeWindow.MINUTE,
- limit=100,
- burst_allowance=20,
- priority=10,
- enabled=True,
- description="Default per-user limit"
- ))
-
+ rules.append(
+ RateLimitRule(
+ rule_id='default_per_user',
+ rule_type=RuleType.PER_USER,
+ time_window=TimeWindow.MINUTE,
+ limit=100,
+ burst_allowance=20,
+ priority=10,
+ enabled=True,
+ description='Default per-user limit',
+ )
+ )
+
# Always add global rule as fallback
- rules.append(RateLimitRule(
- rule_id='default_global',
- rule_type=RuleType.GLOBAL,
- time_window=TimeWindow.MINUTE,
- limit=1000,
- priority=0,
- enabled=True,
- description="Global rate limit"
- ))
-
+ rules.append(
+ RateLimitRule(
+ rule_id='default_global',
+ rule_type=RuleType.GLOBAL,
+ time_window=TimeWindow.MINUTE,
+ limit=1000,
+ priority=0,
+ enabled=True,
+ description='Global rate limit',
+ )
+ )
+
# Sort by priority (highest first)
rules.sort(key=lambda r: r.priority, reverse=True)
-
+
return rules
-
- async def _default_get_user_tier(self, user_id: str) -> Optional[TierLimits]:
+
+ async def _default_get_user_tier(self, user_id: str) -> TierLimits | None:
"""
Get user's tier limits from TierService
-
+
Args:
user_id: User ID
-
+
Returns:
TierLimits or None
"""
try:
- from services.tier_service import TierService, get_tier_service
+ from services.tier_service import get_tier_service
from utils.database_async import async_database
-
+
tier_service = get_tier_service(async_database.db)
limits = await tier_service.get_user_limits(user_id)
-
+
return limits
except Exception as e:
- logger.error(f"Error fetching user tier limits: {e}")
+ logger.error(f'Error fetching user tier limits: {e}')
return None
-
- async def _check_quotas(
- self,
- user_id: str,
- tier_limits: TierLimits
- ) -> 'QuotaCheckResult':
+
+ async def _check_quotas(self, user_id: str, tier_limits: TierLimits) -> 'QuotaCheckResult':
"""
Check user's quotas
-
+
Args:
user_id: User ID
tier_limits: User's tier limits
-
+
Returns:
QuotaCheckResult
"""
# Check monthly quota
if tier_limits.monthly_request_quota:
result = self.quota_tracker.check_quota(
- user_id,
- QuotaType.REQUESTS,
- tier_limits.monthly_request_quota,
- 'month'
+ user_id, QuotaType.REQUESTS, tier_limits.monthly_request_quota, 'month'
)
-
+
if not result.allowed:
return result
-
+
# Check daily quota
if tier_limits.daily_request_quota:
result = self.quota_tracker.check_quota(
- user_id,
- QuotaType.REQUESTS,
- tier_limits.daily_request_quota,
- 'day'
+ user_id, QuotaType.REQUESTS, tier_limits.daily_request_quota, 'day'
)
-
+
if not result.allowed:
return result
-
+
# All quotas OK
from utils.quota_tracker import QuotaCheckResult
+
return QuotaCheckResult(
allowed=True,
current_usage=0,
limit=tier_limits.monthly_request_quota or 0,
remaining=tier_limits.monthly_request_quota or 0,
reset_at=self.quota_tracker._get_next_reset('month'),
- percentage_used=0.0
+ percentage_used=0.0,
)
-
+
def _create_rate_limit_response(
- self,
- result: 'RateLimitResult',
- rule: RateLimitRule
+ self, result: 'RateLimitResult', rule: RateLimitRule
) -> JSONResponse:
"""
Create 429 response for rate limit exceeded
-
+
Args:
result: Rate limit result
rule: Rule that was exceeded
-
+
Returns:
JSONResponse with 429 status
"""
info = result.to_info()
-
+
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
@@ -423,21 +418,18 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
'limit': result.limit,
'remaining': result.remaining,
'reset_at': result.reset_at,
- 'retry_after': result.retry_after
+ 'retry_after': result.retry_after,
},
- headers=info.to_headers()
+ headers=info.to_headers(),
)
-
- def _create_quota_exceeded_response(
- self,
- result: 'QuotaCheckResult'
- ) -> JSONResponse:
+
+ def _create_quota_exceeded_response(self, result: 'QuotaCheckResult') -> JSONResponse:
"""
Create 429 response for quota exceeded
-
+
Args:
result: Quota check result
-
+
Returns:
JSONResponse with 429 status
"""
@@ -450,26 +442,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
'limit': result.limit,
'remaining': result.remaining,
'reset_at': result.reset_at.isoformat(),
- 'percentage_used': result.percentage_used
+ 'percentage_used': result.percentage_used,
},
headers={
'X-RateLimit-Limit': str(result.limit),
'X-RateLimit-Remaining': str(result.remaining),
'X-RateLimit-Reset': str(int(result.reset_at.timestamp())),
- 'Retry-After': str(int((result.reset_at - datetime.now()).total_seconds()))
- }
+ 'Retry-After': str(int((result.reset_at - datetime.now()).total_seconds())),
+ },
)
-
+
def _add_rate_limit_headers(
- self,
- response: Response,
- limit: int,
- remaining: int,
- reset_at: int
+ self, response: Response, limit: int, remaining: int, reset_at: int
):
"""
Add rate limit headers to response
-
+
Args:
response: Response object
limit: Rate limit
diff --git a/backend-services/middleware/tier_rate_limit_middleware.py b/backend-services/middleware/tier_rate_limit_middleware.py
index 7d97f0d..ba6a423 100644
--- a/backend-services/middleware/tier_rate_limit_middleware.py
+++ b/backend-services/middleware/tier_rate_limit_middleware.py
@@ -8,14 +8,14 @@ Works alongside existing per-user rate limiting.
import asyncio
import logging
import time
-from typing import Optional
+
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from models.rate_limit_models import TierLimits
-from services.tier_service import TierService, get_tier_service
+from services.tier_service import get_tier_service
from utils.database_async import async_database
logger = logging.getLogger(__name__)
@@ -24,19 +24,19 @@ logger = logging.getLogger(__name__)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""
Middleware for tier-based rate limiting and throttling
-
+
Features:
- Enforces tier-based rate limits (requests per minute/hour/day)
- Supports throttling (queuing requests) vs hard rejection
- Respects user-specific limit overrides
- Adds rate limit headers to responses
"""
-
+
def __init__(self, app: ASGIApp):
super().__init__(app)
self._request_counts = {} # Simple in-memory counter (use Redis in production)
self._request_queue = {} # Queue for throttling
-
+
async def dispatch(self, request: Request, call_next):
"""
Process request through tier-based rate limiting
@@ -44,244 +44,215 @@ class TierRateLimitMiddleware(BaseHTTPMiddleware):
# Skip rate limiting for certain paths
if self._should_skip(request):
return await call_next(request)
-
+
# Extract user ID
user_id = self._get_user_id(request)
-
+
if not user_id:
# No user ID, skip tier-based limiting
return await call_next(request)
-
+
# Get user's tier limits
tier_service = get_tier_service(async_database.db)
limits = await tier_service.get_user_limits(user_id)
-
+
if not limits:
# No tier limits configured, allow request
return await call_next(request)
-
+
# Check rate limits
rate_limit_result = await self._check_rate_limits(user_id, limits)
-
+
if not rate_limit_result['allowed']:
# Check if throttling is enabled
if limits.enable_throttling:
# Try to queue the request
- queued = await self._try_queue_request(
- user_id,
- limits.max_queue_time_ms
- )
-
+ queued = await self._try_queue_request(user_id, limits.max_queue_time_ms)
+
if not queued:
# Queue full or timeout, return 429
- return self._create_rate_limit_response(
- rate_limit_result,
- limits
- )
-
+ return self._create_rate_limit_response(rate_limit_result, limits)
+
# Request was queued and processed, continue
else:
# Throttling disabled, hard reject
- return self._create_rate_limit_response(
- rate_limit_result,
- limits
- )
-
+ return self._create_rate_limit_response(rate_limit_result, limits)
+
# Increment counters
self._increment_counters(user_id, limits)
-
+
# Process request
response = await call_next(request)
-
+
# Add rate limit headers
self._add_rate_limit_headers(response, user_id, limits)
-
+
return response
-
- async def _check_rate_limits(
- self,
- user_id: str,
- limits: TierLimits
- ) -> dict:
+
+ async def _check_rate_limits(self, user_id: str, limits: TierLimits) -> dict:
"""
Check if user has exceeded any rate limits
-
+
Returns:
dict with 'allowed' (bool) and 'limit_type' (str)
"""
now = int(time.time())
-
+
# Check requests per minute
if limits.requests_per_minute and limits.requests_per_minute < 999999:
- key = f"{user_id}:minute:{now // 60}"
+ key = f'{user_id}:minute:{now // 60}'
count = self._request_counts.get(key, 0)
-
+
if count >= limits.requests_per_minute:
return {
'allowed': False,
'limit_type': 'minute',
'limit': limits.requests_per_minute,
'current': count,
- 'reset_at': ((now // 60) + 1) * 60
+ 'reset_at': ((now // 60) + 1) * 60,
}
-
+
# Check requests per hour
if limits.requests_per_hour and limits.requests_per_hour < 999999:
- key = f"{user_id}:hour:{now // 3600}"
+ key = f'{user_id}:hour:{now // 3600}'
count = self._request_counts.get(key, 0)
-
+
if count >= limits.requests_per_hour:
return {
'allowed': False,
'limit_type': 'hour',
'limit': limits.requests_per_hour,
'current': count,
- 'reset_at': ((now // 3600) + 1) * 3600
+ 'reset_at': ((now // 3600) + 1) * 3600,
}
-
+
# Check requests per day
if limits.requests_per_day and limits.requests_per_day < 999999:
- key = f"{user_id}:day:{now // 86400}"
+ key = f'{user_id}:day:{now // 86400}'
count = self._request_counts.get(key, 0)
-
+
if count >= limits.requests_per_day:
return {
'allowed': False,
'limit_type': 'day',
'limit': limits.requests_per_day,
'current': count,
- 'reset_at': ((now // 86400) + 1) * 86400
+ 'reset_at': ((now // 86400) + 1) * 86400,
}
-
+
return {'allowed': True}
-
+
def _increment_counters(self, user_id: str, limits: TierLimits):
"""Increment request counters for all time windows"""
now = int(time.time())
-
+
if limits.requests_per_minute:
- key = f"{user_id}:minute:{now // 60}"
+ key = f'{user_id}:minute:{now // 60}'
self._request_counts[key] = self._request_counts.get(key, 0) + 1
-
+
if limits.requests_per_hour:
- key = f"{user_id}:hour:{now // 3600}"
+ key = f'{user_id}:hour:{now // 3600}'
self._request_counts[key] = self._request_counts.get(key, 0) + 1
-
+
if limits.requests_per_day:
- key = f"{user_id}:day:{now // 86400}"
+ key = f'{user_id}:day:{now // 86400}'
self._request_counts[key] = self._request_counts.get(key, 0) + 1
-
- async def _try_queue_request(
- self,
- user_id: str,
- max_wait_ms: int
- ) -> bool:
+
+ async def _try_queue_request(self, user_id: str, max_wait_ms: int) -> bool:
"""
Try to queue request with throttling
-
+
Returns:
True if request was processed, False if timeout/rejected
"""
- queue_key = f"{user_id}:queue"
+ queue_key = f'{user_id}:queue'
start_time = time.time() * 1000 # milliseconds
-
+
# Initialize queue if needed
if queue_key not in self._request_queue:
self._request_queue[queue_key] = asyncio.Queue(maxsize=100)
-
+
queue = self._request_queue[queue_key]
-
+
try:
# Add to queue with timeout
- await asyncio.wait_for(
- queue.put(1),
- timeout=max_wait_ms / 1000.0
- )
-
+ await asyncio.wait_for(queue.put(1), timeout=max_wait_ms / 1000.0)
+
# Wait for rate limit to reset
while True:
elapsed = (time.time() * 1000) - start_time
-
+
if elapsed >= max_wait_ms:
# Timeout exceeded
await queue.get() # Remove from queue
return False
-
+
# Check if we can proceed
# In a real implementation, check actual rate limit status
await asyncio.sleep(0.1) # Small delay
-
+
# For now, assume we can proceed after a short wait
if elapsed >= 100: # 100ms min throttle delay
await queue.get() # Remove from queue
return True
-
- except asyncio.TimeoutError:
+
+ except asyncio.TimeoutError:
return False
-
- def _create_rate_limit_response(
- self,
- result: dict,
- limits: TierLimits
- ) -> JSONResponse:
+
+ def _create_rate_limit_response(self, result: dict, limits: TierLimits) -> JSONResponse:
"""Create 429 Too Many Requests response"""
retry_after = result.get('reset_at', 0) - int(time.time())
-
+
return JSONResponse(
status_code=429,
content={
'error': 'Rate limit exceeded',
'error_code': 'RATE_LIMIT_EXCEEDED',
- 'message': f"Rate limit exceeded: {result.get('current', 0)}/{result.get('limit', 0)} requests per {result.get('limit_type', 'period')}",
+ 'message': f"Rate limit exceeded: {result.get('current', 0)}/{result.get('limit', 0)} requests per {result.get('limit_type', 'period')}",
'limit_type': result.get('limit_type'),
'limit': result.get('limit'),
'current': result.get('current'),
'reset_at': result.get('reset_at'),
'retry_after': max(0, retry_after),
- 'throttling_enabled': limits.enable_throttling
+ 'throttling_enabled': limits.enable_throttling,
},
headers={
'Retry-After': str(max(0, retry_after)),
'X-RateLimit-Limit': str(result.get('limit', 0)),
'X-RateLimit-Remaining': '0',
- 'X-RateLimit-Reset': str(result.get('reset_at', 0))
- }
+ 'X-RateLimit-Reset': str(result.get('reset_at', 0)),
+ },
)
-
- def _add_rate_limit_headers(
- self,
- response: Response,
- user_id: str,
- limits: TierLimits
- ):
+
+ def _add_rate_limit_headers(self, response: Response, user_id: str, limits: TierLimits):
"""Add rate limit headers to response"""
now = int(time.time())
-
+
# Add headers for minute limit (most relevant)
if limits.requests_per_minute:
- key = f"{user_id}:minute:{now // 60}"
+ key = f'{user_id}:minute:{now // 60}'
current = self._request_counts.get(key, 0)
remaining = max(0, limits.requests_per_minute - current)
reset_at = ((now // 60) + 1) * 60
-
+
response.headers['X-RateLimit-Limit'] = str(limits.requests_per_minute)
response.headers['X-RateLimit-Remaining'] = str(remaining)
response.headers['X-RateLimit-Reset'] = str(reset_at)
-
+
def _should_skip(self, request: Request) -> bool:
"""Check if rate limiting should be skipped"""
skip_paths = [
- '/health',
- '/metrics',
- '/docs',
- '/redoc',
+ '/health',
+ '/metrics',
+ '/docs',
+ '/redoc',
'/openapi.json',
- '/platform/authorization' # Skip auth endpoints
+ '/platform/authorization', # Skip auth endpoints
]
-
+
return any(request.url.path.startswith(path) for path in skip_paths)
-
- def _get_user_id(self, request: Request) -> Optional[str]:
+
+ def _get_user_id(self, request: Request) -> str | None:
"""Extract user ID from request"""
# Try to get from request state (set by auth middleware)
if hasattr(request.state, 'user'):
@@ -290,9 +261,9 @@ class TierRateLimitMiddleware(BaseHTTPMiddleware):
return user.username
elif isinstance(user, dict):
return user.get('username') or user.get('sub')
-
+
# Try to get from JWT payload in state
if hasattr(request.state, 'jwt_payload'):
return request.state.jwt_payload.get('sub')
-
+
return None
diff --git a/backend-services/models/analytics_models.py b/backend-services/models/analytics_models.py
index 4c3abe5..ffb53db 100644
--- a/backend-services/models/analytics_models.py
+++ b/backend-services/models/analytics_models.py
@@ -6,38 +6,41 @@ analytics capabilities while maintaining backward compatibility.
"""
from __future__ import annotations
+
+from collections import deque
from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Deque
-from collections import deque, defaultdict
from enum import Enum
class AggregationLevel(str, Enum):
"""Time-based aggregation levels for metrics."""
- MINUTE = "minute"
- FIVE_MINUTE = "5minute"
- HOUR = "hour"
- DAY = "day"
+
+ MINUTE = 'minute'
+ FIVE_MINUTE = '5minute'
+ HOUR = 'hour'
+ DAY = 'day'
class MetricType(str, Enum):
"""Types of metrics tracked."""
- REQUEST_COUNT = "request_count"
- ERROR_RATE = "error_rate"
- RESPONSE_TIME = "response_time"
- BANDWIDTH = "bandwidth"
- STATUS_CODE = "status_code"
- LATENCY_PERCENTILE = "latency_percentile"
+
+ REQUEST_COUNT = 'request_count'
+ ERROR_RATE = 'error_rate'
+ RESPONSE_TIME = 'response_time'
+ BANDWIDTH = 'bandwidth'
+ STATUS_CODE = 'status_code'
+ LATENCY_PERCENTILE = 'latency_percentile'
@dataclass
class PercentileMetrics:
"""
Latency percentile calculations.
-
+
Stores multiple percentiles for comprehensive performance analysis.
Uses a reservoir sampling approach to maintain a representative sample.
"""
+
p50: float = 0.0 # Median
p75: float = 0.0 # 75th percentile
p90: float = 0.0 # 90th percentile
@@ -45,20 +48,20 @@ class PercentileMetrics:
p99: float = 0.0 # 99th percentile
min: float = 0.0 # Minimum latency
max: float = 0.0 # Maximum latency
-
+
@staticmethod
- def calculate(latencies: List[float]) -> 'PercentileMetrics':
+ def calculate(latencies: list[float]) -> PercentileMetrics:
"""Calculate percentiles from a list of latencies."""
if not latencies:
return PercentileMetrics()
-
+
sorted_latencies = sorted(latencies)
n = len(sorted_latencies)
-
+
def percentile(p: float) -> float:
k = max(0, int(p * n) - 1)
return float(sorted_latencies[k])
-
+
return PercentileMetrics(
p50=percentile(0.50),
p75=percentile(0.75),
@@ -66,10 +69,10 @@ class PercentileMetrics:
p95=percentile(0.95),
p99=percentile(0.99),
min=float(sorted_latencies[0]),
- max=float(sorted_latencies[-1])
+ max=float(sorted_latencies[-1]),
)
-
- def to_dict(self) -> Dict:
+
+ def to_dict(self) -> dict:
return {
'p50': self.p50,
'p75': self.p75,
@@ -77,7 +80,7 @@ class PercentileMetrics:
'p95': self.p95,
'p99': self.p99,
'min': self.min,
- 'max': self.max
+ 'max': self.max,
}
@@ -85,36 +88,37 @@ class PercentileMetrics:
class EndpointMetrics:
"""
Per-endpoint performance metrics.
-
+
Tracks detailed metrics for individual API endpoints to identify
performance bottlenecks at a granular level.
"""
+
endpoint_uri: str
method: str
count: int = 0
error_count: int = 0
total_ms: float = 0.0
- latencies: Deque[float] = field(default_factory=deque)
- status_counts: Dict[int, int] = field(default_factory=dict)
-
+ latencies: deque[float] = field(default_factory=deque)
+ status_counts: dict[int, int] = field(default_factory=dict)
+
def add(self, ms: float, status: int, max_samples: int = 500) -> None:
"""Record a request for this endpoint."""
self.count += 1
if status >= 400:
self.error_count += 1
self.total_ms += ms
-
+
self.status_counts[status] = self.status_counts.get(status, 0) + 1
-
+
self.latencies.append(ms)
while len(self.latencies) > max_samples:
self.latencies.popleft()
-
+
def get_percentiles(self) -> PercentileMetrics:
"""Calculate percentiles for this endpoint."""
return PercentileMetrics.calculate(list(self.latencies))
-
- def to_dict(self) -> Dict:
+
+ def to_dict(self) -> dict:
percentiles = self.get_percentiles()
return {
'endpoint_uri': self.endpoint_uri,
@@ -124,7 +128,7 @@ class EndpointMetrics:
'error_rate': (self.error_count / self.count) if self.count > 0 else 0.0,
'avg_ms': (self.total_ms / self.count) if self.count > 0 else 0.0,
'percentiles': percentiles.to_dict(),
- 'status_counts': dict(self.status_counts)
+ 'status_counts': dict(self.status_counts),
}
@@ -132,13 +136,14 @@ class EndpointMetrics:
class EnhancedMinuteBucket:
"""
Enhanced version of MinuteBucket with additional analytics.
-
+
Extends the existing MinuteBucket from metrics_util.py with:
- Per-endpoint tracking
- Full percentile calculations (p50, p75, p90, p95, p99)
- Unique user tracking
- Request/response size tracking
"""
+
start_ts: int
count: int = 0
error_count: int = 0
@@ -147,35 +152,35 @@ class EnhancedMinuteBucket:
bytes_out: int = 0
upstream_timeouts: int = 0
retries: int = 0
-
+
# Existing tracking (compatible with metrics_util.py)
- status_counts: Dict[int, int] = field(default_factory=dict)
- api_counts: Dict[str, int] = field(default_factory=dict)
- api_error_counts: Dict[str, int] = field(default_factory=dict)
- user_counts: Dict[str, int] = field(default_factory=dict)
- latencies: Deque[float] = field(default_factory=deque)
-
+ status_counts: dict[int, int] = field(default_factory=dict)
+ api_counts: dict[str, int] = field(default_factory=dict)
+ api_error_counts: dict[str, int] = field(default_factory=dict)
+ user_counts: dict[str, int] = field(default_factory=dict)
+ latencies: deque[float] = field(default_factory=deque)
+
# NEW: Enhanced tracking
- endpoint_metrics: Dict[str, EndpointMetrics] = field(default_factory=dict)
+ endpoint_metrics: dict[str, EndpointMetrics] = field(default_factory=dict)
unique_users: set = field(default_factory=set)
- request_sizes: Deque[int] = field(default_factory=deque)
- response_sizes: Deque[int] = field(default_factory=deque)
-
+ request_sizes: deque[int] = field(default_factory=deque)
+ response_sizes: deque[int] = field(default_factory=deque)
+
def add_request(
self,
ms: float,
status: int,
- username: Optional[str],
- api_key: Optional[str],
- endpoint_uri: Optional[str] = None,
- method: Optional[str] = None,
+ username: str | None,
+ api_key: str | None,
+ endpoint_uri: str | None = None,
+ method: str | None = None,
bytes_in: int = 0,
bytes_out: int = 0,
- max_samples: int = 500
+ max_samples: int = 500,
) -> None:
"""
Record a request with enhanced tracking.
-
+
Compatible with existing metrics_util.py while adding new capabilities.
"""
# Existing tracking (backward compatible)
@@ -185,62 +190,61 @@ class EnhancedMinuteBucket:
self.total_ms += ms
self.bytes_in += bytes_in
self.bytes_out += bytes_out
-
+
self.status_counts[status] = self.status_counts.get(status, 0) + 1
-
+
if api_key:
self.api_counts[api_key] = self.api_counts.get(api_key, 0) + 1
if status >= 400:
self.api_error_counts[api_key] = self.api_error_counts.get(api_key, 0) + 1
-
+
if username:
self.user_counts[username] = self.user_counts.get(username, 0) + 1
self.unique_users.add(username)
-
+
self.latencies.append(ms)
while len(self.latencies) > max_samples:
self.latencies.popleft()
-
+
# NEW: Per-endpoint tracking
if endpoint_uri and method:
- endpoint_key = f"{method}:{endpoint_uri}"
+ endpoint_key = f'{method}:{endpoint_uri}'
if endpoint_key not in self.endpoint_metrics:
self.endpoint_metrics[endpoint_key] = EndpointMetrics(
- endpoint_uri=endpoint_uri,
- method=method
+ endpoint_uri=endpoint_uri, method=method
)
self.endpoint_metrics[endpoint_key].add(ms, status, max_samples)
-
+
# NEW: Request/response size tracking
if bytes_in > 0:
self.request_sizes.append(bytes_in)
while len(self.request_sizes) > max_samples:
self.request_sizes.popleft()
-
+
if bytes_out > 0:
self.response_sizes.append(bytes_out)
while len(self.response_sizes) > max_samples:
self.response_sizes.popleft()
-
+
def get_percentiles(self) -> PercentileMetrics:
"""Calculate full percentiles for this bucket."""
return PercentileMetrics.calculate(list(self.latencies))
-
+
def get_unique_user_count(self) -> int:
"""Get count of unique users in this bucket."""
return len(self.unique_users)
-
- def get_top_endpoints(self, limit: int = 10) -> List[Dict]:
+
+ def get_top_endpoints(self, limit: int = 10) -> list[dict]:
"""Get top N slowest/most-used endpoints."""
endpoints = [ep.to_dict() for ep in self.endpoint_metrics.values()]
# Sort by count (most used)
endpoints.sort(key=lambda x: x['count'], reverse=True)
return endpoints[:limit]
-
- def to_dict(self) -> Dict:
+
+ def to_dict(self) -> dict:
"""Serialize to dictionary (backward compatible + enhanced)."""
percentiles = self.get_percentiles()
-
+
return {
# Existing fields (backward compatible)
'start_ts': self.start_ts,
@@ -255,13 +259,16 @@ class EnhancedMinuteBucket:
'api_counts': dict(self.api_counts),
'api_error_counts': dict(self.api_error_counts),
'user_counts': dict(self.user_counts),
-
# NEW: Enhanced fields
'percentiles': percentiles.to_dict(),
'unique_users': self.get_unique_user_count(),
'endpoint_metrics': {k: v.to_dict() for k, v in self.endpoint_metrics.items()},
- 'avg_request_size': sum(self.request_sizes) / len(self.request_sizes) if self.request_sizes else 0,
- 'avg_response_size': sum(self.response_sizes) / len(self.response_sizes) if self.response_sizes else 0,
+ 'avg_request_size': sum(self.request_sizes) / len(self.request_sizes)
+ if self.request_sizes
+ else 0,
+ 'avg_response_size': sum(self.response_sizes) / len(self.response_sizes)
+ if self.response_sizes
+ else 0,
}
@@ -269,10 +276,11 @@ class EnhancedMinuteBucket:
class AggregatedMetrics:
"""
Multi-level aggregated metrics (5-minute, hourly, daily).
-
+
Used for efficient querying of historical data without
scanning all minute-level buckets.
"""
+
start_ts: int
end_ts: int
level: AggregationLevel
@@ -282,12 +290,12 @@ class AggregatedMetrics:
bytes_in: int = 0
bytes_out: int = 0
unique_users: int = 0
-
- status_counts: Dict[int, int] = field(default_factory=dict)
- api_counts: Dict[str, int] = field(default_factory=dict)
- percentiles: Optional[PercentileMetrics] = None
-
- def to_dict(self) -> Dict:
+
+ status_counts: dict[int, int] = field(default_factory=dict)
+ api_counts: dict[str, int] = field(default_factory=dict)
+ percentiles: PercentileMetrics | None = None
+
+ def to_dict(self) -> dict:
return {
'start_ts': self.start_ts,
'end_ts': self.end_ts,
@@ -301,7 +309,7 @@ class AggregatedMetrics:
'unique_users': self.unique_users,
'status_counts': dict(self.status_counts),
'api_counts': dict(self.api_counts),
- 'percentiles': self.percentiles.to_dict() if self.percentiles else None
+ 'percentiles': self.percentiles.to_dict() if self.percentiles else None,
}
@@ -309,9 +317,10 @@ class AggregatedMetrics:
class AnalyticsSnapshot:
"""
Complete analytics snapshot for a time range.
-
+
Used as the response format for analytics API endpoints.
"""
+
start_ts: int
end_ts: int
total_requests: int
@@ -322,19 +331,19 @@ class AnalyticsSnapshot:
total_bytes_in: int
total_bytes_out: int
unique_users: int
-
+
# Time-series data
- series: List[Dict]
-
+ series: list[dict]
+
# Top N lists
- top_apis: List[tuple]
- top_users: List[tuple]
- top_endpoints: List[Dict]
-
+ top_apis: list[tuple]
+ top_users: list[tuple]
+ top_endpoints: list[dict]
+
# Status code distribution
- status_distribution: Dict[str, int]
-
- def to_dict(self) -> Dict:
+ status_distribution: dict[str, int]
+
+ def to_dict(self) -> dict:
return {
'start_ts': self.start_ts,
'end_ts': self.end_ts,
@@ -346,11 +355,11 @@ class AnalyticsSnapshot:
'percentiles': self.percentiles.to_dict(),
'total_bytes_in': self.total_bytes_in,
'total_bytes_out': self.total_bytes_out,
- 'unique_users': self.unique_users
+ 'unique_users': self.unique_users,
},
'series': self.series,
'top_apis': [{'api': api, 'count': count} for api, count in self.top_apis],
'top_users': [{'user': user, 'count': count} for user, count in self.top_users],
'top_endpoints': self.top_endpoints,
- 'status_distribution': self.status_distribution
+ 'status_distribution': self.status_distribution,
}
diff --git a/backend-services/models/api_model_response.py b/backend-services/models/api_model_response.py
index 0b0dc0b..58c69c6 100644
--- a/backend-services/models/api_model_response.py
+++ b/backend-services/models/api_model_response.py
@@ -5,24 +5,61 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class ApiModelResponse(BaseModel):
-
- api_name: Optional[str] = Field(None, min_length=1, max_length=25, description='Name of the API', example='customer')
- api_version: Optional[str] = Field(None, min_length=1, max_length=8, description='Version of the API', example='v1')
- api_description: Optional[str] = Field(None, min_length=1, max_length=127, description='Description of the API', example='New customer onboarding API')
- api_allowed_roles: Optional[List[str]] = Field(None, description='Allowed user roles for the API', example=['admin', 'user'])
- api_allowed_groups: Optional[List[str]] = Field(None, description='Allowed user groups for the API' , example=['admin', 'client-1-group'])
- api_servers: Optional[List[str]] = Field(None, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081'])
- api_type: Optional[str] = Field(None, description="Type of the API. Valid values: 'REST'", example='REST')
- api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
- api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
- api_allowed_retry_count: Optional[int] = Field(None, description='Number of allowed retries for the API', example=0)
- api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True)
- api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1')
- api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example='c3eda315-545a-4fef-a831-7e45e2f68987')
- api_path: Optional[str] = Field(None, description='Unqiue path for the API, auto-generated', example='/customer/v1')
+ api_name: str | None = Field(
+ None, min_length=1, max_length=25, description='Name of the API', example='customer'
+ )
+ api_version: str | None = Field(
+ None, min_length=1, max_length=8, description='Version of the API', example='v1'
+ )
+ api_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=127,
+ description='Description of the API',
+ example='New customer onboarding API',
+ )
+ api_allowed_roles: list[str] | None = Field(
+ None, description='Allowed user roles for the API', example=['admin', 'user']
+ )
+ api_allowed_groups: list[str] | None = Field(
+ None, description='Allowed user groups for the API', example=['admin', 'client-1-group']
+ )
+ api_servers: list[str] | None = Field(
+ None,
+ description='List of backend servers for the API',
+ example=['http://localhost:8080', 'http://localhost:8081'],
+ )
+ api_type: str | None = Field(
+ None, description="Type of the API. Valid values: 'REST'", example='REST'
+ )
+ api_authorization_field_swap: str | None = Field(
+ None,
+ description='Header to swap for backend authorization header',
+ example='backend-auth-header',
+ )
+ api_allowed_headers: list[str] | None = Field(
+ None, description='Allowed headers for the API', example=['Content-Type', 'Authorization']
+ )
+ api_allowed_retry_count: int | None = Field(
+ None, description='Number of allowed retries for the API', example=0
+ )
+ api_credits_enabled: bool | None = Field(
+ False, description='Enable credit-based authentication for the API', example=True
+ )
+ api_credit_group: str | None = Field(
+ None, description='API credit group for the API credits', example='ai-group-1'
+ )
+ api_id: str | None = Field(
+ None,
+ description='Unique identifier for the API, auto-generated',
+ example='c3eda315-545a-4fef-a831-7e45e2f68987',
+ )
+ api_path: str | None = Field(
+        None, description='Unique path for the API, auto-generated', example='/customer/v1'
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/create_api_model.py b/backend-services/models/create_api_model.py
index fd158ab..0c66844 100644
--- a/backend-services/models/create_api_model.py
+++ b/backend-services/models/create_api_model.py
@@ -5,46 +5,127 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class CreateApiModel(BaseModel):
-
- api_name: str = Field(..., min_length=1, max_length=64, description='Name of the API', example='customer')
- api_version: str = Field(..., min_length=1, max_length=8, description='Version of the API', example='v1')
- api_description: Optional[str] = Field(None, max_length=127, description='Description of the API', example='New customer onboarding API')
- api_allowed_roles: List[str] = Field(default_factory=list, description='Allowed user roles for the API', example=['admin', 'user'])
- api_allowed_groups: List[str] = Field(default_factory=list, description='Allowed user groups for the API' , example=['admin', 'client-1-group'])
- api_servers: List[str] = Field(default_factory=list, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081'])
+ api_name: str = Field(
+ ..., min_length=1, max_length=64, description='Name of the API', example='customer'
+ )
+ api_version: str = Field(
+ ..., min_length=1, max_length=8, description='Version of the API', example='v1'
+ )
+ api_description: str | None = Field(
+ None,
+ max_length=127,
+ description='Description of the API',
+ example='New customer onboarding API',
+ )
+ api_allowed_roles: list[str] = Field(
+ default_factory=list,
+ description='Allowed user roles for the API',
+ example=['admin', 'user'],
+ )
+ api_allowed_groups: list[str] = Field(
+ default_factory=list,
+ description='Allowed user groups for the API',
+ example=['admin', 'client-1-group'],
+ )
+ api_servers: list[str] = Field(
+ default_factory=list,
+ description='List of backend servers for the API',
+ example=['http://localhost:8080', 'http://localhost:8081'],
+ )
api_type: str = Field(None, description="Type of the API. Valid values: 'REST'", example='REST')
- api_allowed_retry_count: int = Field(0, description='Number of allowed retries for the API', example=0)
- api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg')
- api_grpc_allowed_packages: Optional[List[str]] = Field(None, description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.', example=['customer_v1'])
- api_grpc_allowed_services: Optional[List[str]] = Field(None, description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.', example=['Greeter'])
- api_grpc_allowed_methods: Optional[List[str]] = Field(None, description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.', example=['Greeter.SayHello'])
+ api_allowed_retry_count: int = Field(
+ 0, description='Number of allowed retries for the API', example=0
+ )
+ api_grpc_package: str | None = Field(
+ None,
+ description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.',
+ example='my.pkg',
+ )
+ api_grpc_allowed_packages: list[str] | None = Field(
+ None,
+ description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.',
+ example=['customer_v1'],
+ )
+ api_grpc_allowed_services: list[str] | None = Field(
+ None,
+ description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.',
+ example=['Greeter'],
+ )
+ api_grpc_allowed_methods: list[str] | None = Field(
+ None,
+ description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.',
+ example=['Greeter.SayHello'],
+ )
- api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
- api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
- api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True)
- api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1')
- active: Optional[bool] = Field(True, description='Whether the API is active (enabled)', example=True)
+ api_authorization_field_swap: str | None = Field(
+ None,
+ description='Header to swap for backend authorization header',
+ example='backend-auth-header',
+ )
+ api_allowed_headers: list[str] | None = Field(
+ None, description='Allowed headers for the API', example=['Content-Type', 'Authorization']
+ )
+ api_credits_enabled: bool | None = Field(
+ False, description='Enable credit-based authentication for the API', example=True
+ )
+ api_credit_group: str | None = Field(
+ None, description='API credit group for the API credits', example='ai-group-1'
+ )
+ active: bool | None = Field(
+ True, description='Whether the API is active (enabled)', example=True
+ )
- api_cors_allow_origins: Optional[List[str]] = Field(None, description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.")
- api_cors_allow_methods: Optional[List[str]] = Field(None, description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])")
- api_cors_allow_headers: Optional[List[str]] = Field(None, description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])")
- api_cors_allow_credentials: Optional[bool] = Field(False, description='Whether to include Access-Control-Allow-Credentials=true in responses')
- api_cors_expose_headers: Optional[List[str]] = Field(None, description='Response headers to expose to the browser via Access-Control-Expose-Headers')
+ api_cors_allow_origins: list[str] | None = Field(
+ None,
+ description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.",
+ )
+ api_cors_allow_methods: list[str] | None = Field(
+ None,
+ description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])",
+ )
+ api_cors_allow_headers: list[str] | None = Field(
+ None,
+ description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])",
+ )
+ api_cors_allow_credentials: bool | None = Field(
+ False, description='Whether to include Access-Control-Allow-Credentials=true in responses'
+ )
+ api_cors_expose_headers: list[str] | None = Field(
+ None,
+ description='Response headers to expose to the browser via Access-Control-Expose-Headers',
+ )
- api_public: Optional[bool] = Field(False, description='If true, this API can be called without authentication or subscription')
+ api_public: bool | None = Field(
+ False, description='If true, this API can be called without authentication or subscription'
+ )
- api_auth_required: Optional[bool] = Field(True, description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.')
+ api_auth_required: bool | None = Field(
+ True,
+ description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.',
+ )
- api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example=None)
- api_path: Optional[str] = Field(None, description='Unique path for the API, auto-generated', example=None)
+ api_id: str | None = Field(
+ None, description='Unique identifier for the API, auto-generated', example=None
+ )
+ api_path: str | None = Field(
+ None, description='Unique path for the API, auto-generated', example=None
+ )
- api_ip_mode: Optional[str] = Field('allow_all', description="IP policy mode: 'allow_all' or 'whitelist'")
- api_ip_whitelist: Optional[List[str]] = Field(None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist')
- api_ip_blacklist: Optional[List[str]] = Field(None, description='IPs/CIDRs denied regardless of mode')
- api_trust_x_forwarded_for: Optional[bool] = Field(None, description='Override: trust X-Forwarded-For for this API')
+ api_ip_mode: str | None = Field(
+ 'allow_all', description="IP policy mode: 'allow_all' or 'whitelist'"
+ )
+ api_ip_whitelist: list[str] | None = Field(
+ None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist'
+ )
+ api_ip_blacklist: list[str] | None = Field(
+ None, description='IPs/CIDRs denied regardless of mode'
+ )
+ api_trust_x_forwarded_for: bool | None = Field(
+ None, description='Override: trust X-Forwarded-For for this API'
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/create_endpoint_model.py b/backend-services/models/create_endpoint_model.py
index eb1db6c..0b4ccd0 100644
--- a/backend-services/models/create_endpoint_model.py
+++ b/backend-services/models/create_endpoint_model.py
@@ -5,19 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional, List
+
class CreateEndpointModel(BaseModel):
+ api_name: str = Field(
+ ..., min_length=1, max_length=50, description='Name of the API', example='customer'
+ )
+ api_version: str = Field(
+ ..., min_length=1, max_length=10, description='Version of the API', example='v1'
+ )
+ endpoint_method: str = Field(
+ ..., min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET'
+ )
+ endpoint_uri: str = Field(
+ ..., min_length=1, max_length=255, description='URI for the endpoint', example='/customer'
+ )
+ endpoint_description: str = Field(
+ ...,
+ min_length=1,
+ max_length=255,
+ description='Description of the endpoint',
+ example='Get customer details',
+ )
+ endpoint_servers: list[str] | None = Field(
+ None,
+ description='Optional list of backend servers for this endpoint (overrides API servers)',
+ example=['http://localhost:8082', 'http://localhost:8083'],
+ )
- api_name: str = Field(..., min_length=1, max_length=50, description='Name of the API', example='customer')
- api_version: str = Field(..., min_length=1, max_length=10, description='Version of the API', example='v1')
- endpoint_method: str = Field(..., min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET')
- endpoint_uri: str = Field(..., min_length=1, max_length=255, description='URI for the endpoint', example='/customer')
- endpoint_description: str = Field(..., min_length=1, max_length=255, description='Description of the endpoint', example='Get customer details')
- endpoint_servers: Optional[List[str]] = Field(None, description='Optional list of backend servers for this endpoint (overrides API servers)', example=['http://localhost:8082', 'http://localhost:8083'])
-
- api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example=None)
- endpoint_id: Optional[str] = Field(None, description='Unique identifier for the endpoint, auto-generated', example=None)
+ api_id: str | None = Field(
+ None, description='Unique identifier for the API, auto-generated', example=None
+ )
+ endpoint_id: str | None = Field(
+ None, description='Unique identifier for the endpoint, auto-generated', example=None
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/create_endpoint_validation_model.py b/backend-services/models/create_endpoint_validation_model.py
index 2b91574..59d9f07 100644
--- a/backend-services/models/create_endpoint_validation_model.py
+++ b/backend-services/models/create_endpoint_validation_model.py
@@ -8,11 +8,19 @@ from pydantic import BaseModel, Field
from models.validation_schema_model import ValidationSchema
-class CreateEndpointValidationModel(BaseModel):
- endpoint_id: str = Field(..., description='Unique identifier for the endpoint, auto-generated', example='1299f720-e619-4628-b584-48a6570026cf')
- validation_enabled: bool = Field(..., description='Whether the validation is enabled', example=True)
- validation_schema: ValidationSchema = Field(..., description='The schema to validate the endpoint against', example={})
+class CreateEndpointValidationModel(BaseModel):
+ endpoint_id: str = Field(
+ ...,
+ description='Unique identifier for the endpoint, auto-generated',
+ example='1299f720-e619-4628-b584-48a6570026cf',
+ )
+ validation_enabled: bool = Field(
+ ..., description='Whether the validation is enabled', example=True
+ )
+ validation_schema: ValidationSchema = Field(
+ ..., description='The schema to validate the endpoint against', example={}
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/create_group_model.py b/backend-services/models/create_group_model.py
index d5ae97d..fff28e5 100644
--- a/backend-services/models/create_group_model.py
+++ b/backend-services/models/create_group_model.py
@@ -5,14 +5,21 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class CreateGroupModel(BaseModel):
+ group_name: str = Field(
+ ..., min_length=1, max_length=50, description='Name of the group', example='client-1-group'
+ )
- group_name: str = Field(..., min_length=1, max_length=50, description='Name of the group', example='client-1-group')
-
- group_description: Optional[str] = Field(None, max_length=255, description='Description of the group', example='Group for client 1')
- api_access: Optional[List[str]] = Field(default_factory=list, description='List of APIs the group can access', example=['customer/v1'])
+ group_description: str | None = Field(
+ None, max_length=255, description='Description of the group', example='Group for client 1'
+ )
+ api_access: list[str] | None = Field(
+ default_factory=list,
+ description='List of APIs the group can access',
+ example=['customer/v1'],
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/create_role_model.py b/backend-services/models/create_role_model.py
index 239f04a..2bca258 100644
--- a/backend-services/models/create_role_model.py
+++ b/backend-services/models/create_role_model.py
@@ -5,26 +5,46 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
+
class CreateRoleModel(BaseModel):
-
- role_name: str = Field(..., min_length=1, max_length=50, description='Name of the role', example='admin')
- role_description: Optional[str] = Field(None, max_length=255, description='Description of the role', example='Administrator role with full access')
+ role_name: str = Field(
+ ..., min_length=1, max_length=50, description='Name of the role', example='admin'
+ )
+ role_description: str | None = Field(
+ None,
+ max_length=255,
+ description='Description of the role',
+ example='Administrator role with full access',
+ )
manage_users: bool = Field(False, description='Permission to manage users', example=True)
manage_apis: bool = Field(False, description='Permission to manage APIs', example=True)
- manage_endpoints: bool = Field(False, description='Permission to manage endpoints', example=True)
+ manage_endpoints: bool = Field(
+ False, description='Permission to manage endpoints', example=True
+ )
manage_groups: bool = Field(False, description='Permission to manage groups', example=True)
manage_roles: bool = Field(False, description='Permission to manage roles', example=True)
manage_routings: bool = Field(False, description='Permission to manage routings', example=True)
manage_gateway: bool = Field(False, description='Permission to manage gateway', example=True)
- manage_subscriptions: bool = Field(False, description='Permission to manage subscriptions', example=True)
- manage_security: bool = Field(False, description='Permission to manage security settings', example=True)
- manage_tiers: bool = Field(False, description='Permission to manage pricing tiers', example=True)
- manage_rate_limits: bool = Field(False, description='Permission to manage rate limiting rules', example=True)
+ manage_subscriptions: bool = Field(
+ False, description='Permission to manage subscriptions', example=True
+ )
+ manage_security: bool = Field(
+ False, description='Permission to manage security settings', example=True
+ )
+ manage_tiers: bool = Field(
+ False, description='Permission to manage pricing tiers', example=True
+ )
+ manage_rate_limits: bool = Field(
+ False, description='Permission to manage rate limiting rules', example=True
+ )
manage_credits: bool = Field(False, description='Permission to manage credits', example=True)
- manage_auth: bool = Field(False, description='Permission to manage auth (revoke tokens/disable users)', example=True)
- view_analytics: bool = Field(False, description='Permission to view analytics dashboard', example=True)
+ manage_auth: bool = Field(
+ False, description='Permission to manage auth (revoke tokens/disable users)', example=True
+ )
+ view_analytics: bool = Field(
+ False, description='Permission to view analytics dashboard', example=True
+ )
view_logs: bool = Field(False, description='Permission to view logs', example=True)
export_logs: bool = Field(False, description='Permission to export logs', example=True)
diff --git a/backend-services/models/create_routing_model.py b/backend-services/models/create_routing_model.py
index de2badf..01378f1 100644
--- a/backend-services/models/create_routing_model.py
+++ b/backend-services/models/create_routing_model.py
@@ -5,16 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
+
class CreateRoutingModel(BaseModel):
+ routing_name: str = Field(
+ ...,
+ min_length=1,
+ max_length=50,
+ description='Name of the routing',
+ example='customer-routing',
+ )
+ routing_servers: list[str] = Field(
+ ...,
+ min_items=1,
+ description='List of backend servers for the routing',
+ example=['http://localhost:8080', 'http://localhost:8081'],
+ )
+ routing_description: str = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the routing',
+ example='Routing for customer API',
+ )
- routing_name: str = Field(..., min_length=1, max_length=50, description='Name of the routing', example='customer-routing')
- routing_servers : list[str] = Field(..., min_items=1, description='List of backend servers for the routing', example=['http://localhost:8080', 'http://localhost:8081'])
- routing_description: str = Field(None, min_length=1, max_length=255, description='Description of the routing', example='Routing for customer API')
-
- client_key: Optional[str] = Field(None, min_length=1, max_length=50, description='Client key for the routing', example='client-1')
- server_index: Optional[int] = Field(0, ge=0, description='Index of the server to route to', example=0)
+ client_key: str | None = Field(
+ None,
+ min_length=1,
+ max_length=50,
+ description='Client key for the routing',
+ example='client-1',
+ )
+ server_index: int | None = Field(
+ 0, ge=0, description='Index of the server to route to', example=0
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/create_user_model.py b/backend-services/models/create_user_model.py
index f28a9aa..8a0a74f 100644
--- a/backend-services/models/create_user_model.py
+++ b/backend-services/models/create_user_model.py
@@ -5,31 +5,93 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class CreateUserModel(BaseModel):
+ username: str = Field(
+ ..., min_length=3, max_length=50, description='Username of the user', example='john_doe'
+ )
+ email: str = Field(
+ ...,
+ min_length=3,
+ max_length=127,
+ description='Email of the user (no strict format validation)',
+ example='john@mail.com',
+ )
+ password: str = Field(
+ ...,
+ min_length=16,
+ max_length=50,
+ description='Password of the user',
+ example='SecurePassword@123',
+ )
+ role: str = Field(
+ ..., min_length=2, max_length=50, description='Role of the user', example='admin'
+ )
+ groups: list[str] = Field(
+ default_factory=list,
+ description='List of groups the user belongs to',
+ example=['client-1-group'],
+ )
- username: str = Field(..., min_length=3, max_length=50, description='Username of the user', example='john_doe')
- email: str = Field(..., min_length=3, max_length=127, description='Email of the user (no strict format validation)', example='john@mail.com')
- password: str = Field(..., min_length=16, max_length=50, description='Password of the user', example='SecurePassword@123')
- role: str = Field(..., min_length=2, max_length=50, description='Role of the user', example='admin')
- groups: List[str] = Field(default_factory=list, description='List of groups the user belongs to', example=['client-1-group'])
-
- rate_limit_duration: Optional[int] = Field(None, ge=0, description='Rate limit for the user', example=100)
- rate_limit_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour')
- rate_limit_enabled: Optional[bool] = Field(None, description='Whether rate limiting is enabled for this user', example=True)
- throttle_duration: Optional[int] = Field(None, ge=0, description='Throttle limit for the user', example=10)
- throttle_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the throttle limit', example='second')
- throttle_wait_duration: Optional[int] = Field(None, ge=0, description='Wait time for the throttle limit', example=5)
- throttle_wait_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Wait duration for the throttle limit', example='seconds')
- throttle_queue_limit: Optional[int] = Field(None, ge=0, description='Throttle queue limit for the user', example=10)
- throttle_enabled: Optional[bool] = Field(None, description='Whether throttling is enabled for this user', example=True)
- custom_attributes: Optional[dict] = Field(None, description='Custom attributes for the user', example={'custom_key': 'custom_value'})
- bandwidth_limit_bytes: Optional[int] = Field(None, ge=0, description='Maximum bandwidth allowed within the window (bytes)', example=1073741824)
- bandwidth_limit_window: Optional[str] = Field('day', min_length=1, max_length=10, description='Bandwidth window unit (second/minute/hour/day/month)', example='day')
- bandwidth_limit_enabled: Optional[bool] = Field(None, description='Whether bandwidth limit enforcement is enabled for this user', example=True)
- active: Optional[bool] = Field(True, description='Active status of the user', example=True)
- ui_access: Optional[bool] = Field(False, description='UI access for the user', example=False)
+ rate_limit_duration: int | None = Field(
+ None, ge=0, description='Rate limit for the user', example=100
+ )
+ rate_limit_duration_type: str | None = Field(
+ None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour'
+ )
+ rate_limit_enabled: bool | None = Field(
+ None, description='Whether rate limiting is enabled for this user', example=True
+ )
+ throttle_duration: int | None = Field(
+ None, ge=0, description='Throttle limit for the user', example=10
+ )
+ throttle_duration_type: str | None = Field(
+ None,
+ min_length=1,
+ max_length=7,
+ description='Duration for the throttle limit',
+ example='second',
+ )
+ throttle_wait_duration: int | None = Field(
+ None, ge=0, description='Wait time for the throttle limit', example=5
+ )
+ throttle_wait_duration_type: str | None = Field(
+ None,
+ min_length=1,
+ max_length=7,
+ description='Wait duration for the throttle limit',
+ example='seconds',
+ )
+ throttle_queue_limit: int | None = Field(
+ None, ge=0, description='Throttle queue limit for the user', example=10
+ )
+ throttle_enabled: bool | None = Field(
+ None, description='Whether throttling is enabled for this user', example=True
+ )
+ custom_attributes: dict | None = Field(
+ None, description='Custom attributes for the user', example={'custom_key': 'custom_value'}
+ )
+ bandwidth_limit_bytes: int | None = Field(
+ None,
+ ge=0,
+ description='Maximum bandwidth allowed within the window (bytes)',
+ example=1073741824,
+ )
+ bandwidth_limit_window: str | None = Field(
+ 'day',
+ min_length=1,
+ max_length=10,
+ description='Bandwidth window unit (second/minute/hour/day/month)',
+ example='day',
+ )
+ bandwidth_limit_enabled: bool | None = Field(
+ None,
+ description='Whether bandwidth limit enforcement is enabled for this user',
+ example=True,
+ )
+ active: bool | None = Field(True, description='Active status of the user', example=True)
+ ui_access: bool | None = Field(False, description='UI access for the user', example=False)
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/create_vault_entry_model.py b/backend-services/models/create_vault_entry_model.py
index 10b0fbd..881184f 100644
--- a/backend-services/models/create_vault_entry_model.py
+++ b/backend-services/models/create_vault_entry_model.py
@@ -5,32 +5,31 @@ See https://github.com/apidoorman/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
class CreateVaultEntryModel(BaseModel):
"""Model for creating a new vault entry."""
-
+
key_name: str = Field(
- ...,
- min_length=1,
- max_length=255,
+ ...,
+ min_length=1,
+ max_length=255,
description='Unique name for the vault key',
- example='api_key_production'
+ example='api_key_production',
)
-
+
value: str = Field(
- ...,
- min_length=1,
+ ...,
+ min_length=1,
description='The secret value to encrypt and store',
- example='sk_live_abc123xyz789'
+ example='sk_live_abc123xyz789',
)
-
- description: Optional[str] = Field(
+
+ description: str | None = Field(
None,
max_length=500,
description='Optional description of what this key is used for',
- example='Production API key for payment gateway'
+ example='Production API key for payment gateway',
)
class Config:
diff --git a/backend-services/models/credit_model.py b/backend-services/models/credit_model.py
index 16b26a7..31e7274 100644
--- a/backend-services/models/credit_model.py
+++ b/backend-services/models/credit_model.py
@@ -4,30 +4,58 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List, Optional
-from pydantic import BaseModel, Field
from datetime import datetime
+from pydantic import BaseModel, Field
+
+
class CreditTierModel(BaseModel):
- tier_name: str = Field(..., min_length=1, max_length=50, description='Name of the credit tier', example='basic')
+ tier_name: str = Field(
+ ..., min_length=1, max_length=50, description='Name of the credit tier', example='basic'
+ )
credits: int = Field(..., description='Number of credits per reset', example=50)
- input_limit: int = Field(..., description='Input limit for paid credits (text or context)', example=150)
- output_limit: int = Field(..., description='Output limit for paid credits (text or context)', example=150)
- reset_frequency: str = Field(..., description='Frequency of paid credit reset', example='monthly')
+ input_limit: int = Field(
+ ..., description='Input limit for paid credits (text or context)', example=150
+ )
+ output_limit: int = Field(
+ ..., description='Output limit for paid credits (text or context)', example=150
+ )
+ reset_frequency: str = Field(
+ ..., description='Frequency of paid credit reset', example='monthly'
+ )
class Config:
arbitrary_types_allowed = True
+
class CreditModel(BaseModel):
+ api_credit_group: str = Field(
+ ...,
+ min_length=1,
+ max_length=50,
+ description='API group for the credits',
+ example='ai-group-1',
+ )
+ api_key: str = Field(
+ ..., description='API key for the credit tier', example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
+ )
+ api_key_header: str = Field(
+ ..., description='Header the API key should be sent in', example='x-api-key'
+ )
+ credit_tiers: list[CreditTierModel] = Field(
+ ..., min_items=1, description='Credit tiers information'
+ )
- api_credit_group: str = Field(..., min_length=1, max_length=50, description='API group for the credits', example='ai-group-1')
- api_key: str = Field(..., description='API key for the credit tier', example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
- api_key_header: str = Field(..., description='Header the API key should be sent in', example='x-api-key')
- credit_tiers: List[CreditTierModel] = Field(..., min_items=1, description='Credit tiers information')
-
- api_key_new: Optional[str] = Field(None, description='New API key during rotation period', example='yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy')
- api_key_rotation_expires: Optional[datetime] = Field(None, description='Expiration time for old API key during rotation', example='2025-01-15T10:00:00Z')
+ api_key_new: str | None = Field(
+ None,
+ description='New API key during rotation period',
+ example='yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy',
+ )
+ api_key_rotation_expires: datetime | None = Field(
+ None,
+ description='Expiration time for old API key during rotation',
+ example='2025-01-15T10:00:00Z',
+ )
class Config:
arbitrary_types_allowed = True
-
diff --git a/backend-services/models/delete_successfully.py b/backend-services/models/delete_successfully.py
index 544520e..ef6da42 100644
--- a/backend-services/models/delete_successfully.py
+++ b/backend-services/models/delete_successfully.py
@@ -6,9 +6,11 @@ See https://github.com/pypeople-dev/doorman for more information
from pydantic import BaseModel, Field
-class ResponseMessage(BaseModel):
- message: str = Field(None, description='The response message', example='API Deleted Successfully')
+class ResponseMessage(BaseModel):
+ message: str = Field(
+ None, description='The response message', example='API Deleted Successfully'
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/endpoint_model_response.py b/backend-services/models/endpoint_model_response.py
index abf2138..765ae56 100644
--- a/backend-services/models/endpoint_model_response.py
+++ b/backend-services/models/endpoint_model_response.py
@@ -5,18 +5,47 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional, List
+
class EndpointModelResponse(BaseModel):
-
- api_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the API', example='customer')
- api_version: Optional[str] = Field(None, min_length=1, max_length=10, description='Version of the API', example='v1')
- endpoint_method: Optional[str] = Field(None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET')
- endpoint_uri: Optional[str] = Field(None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer')
- endpoint_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the endpoint', example='Get customer details')
- endpoint_servers: Optional[List[str]] = Field(None, description='Optional list of backend servers for this endpoint (overrides API servers)', example=['http://localhost:8082', 'http://localhost:8083'])
- api_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the API, auto-generated', example=None)
- endpoint_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the endpoint, auto-generated', example=None)
+ api_name: str | None = Field(
+ None, min_length=1, max_length=50, description='Name of the API', example='customer'
+ )
+ api_version: str | None = Field(
+ None, min_length=1, max_length=10, description='Version of the API', example='v1'
+ )
+ endpoint_method: str | None = Field(
+ None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET'
+ )
+ endpoint_uri: str | None = Field(
+ None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer'
+ )
+ endpoint_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the endpoint',
+ example='Get customer details',
+ )
+ endpoint_servers: list[str] | None = Field(
+ None,
+ description='Optional list of backend servers for this endpoint (overrides API servers)',
+ example=['http://localhost:8082', 'http://localhost:8083'],
+ )
+ api_id: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Unique identifier for the API, auto-generated',
+ example=None,
+ )
+ endpoint_id: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Unique identifier for the endpoint, auto-generated',
+ example=None,
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/endpoint_validation_model_response.py b/backend-services/models/endpoint_validation_model_response.py
index 8edc594..1bef4f6 100644
--- a/backend-services/models/endpoint_validation_model_response.py
+++ b/backend-services/models/endpoint_validation_model_response.py
@@ -8,11 +8,19 @@ from pydantic import BaseModel, Field
from models.validation_schema_model import ValidationSchema
-class EndpointValidationModelResponse(BaseModel):
- endpoint_id: str = Field(..., description='Unique identifier for the endpoint, auto-generated', example='1299f720-e619-4628-b584-48a6570026cf')
- validation_enabled: bool = Field(..., description='Whether the validation is enabled', example=True)
- validation_schema: ValidationSchema = Field(..., description='The schema to validate the endpoint against', example={})
+class EndpointValidationModelResponse(BaseModel):
+ endpoint_id: str = Field(
+ ...,
+ description='Unique identifier for the endpoint, auto-generated',
+ example='1299f720-e619-4628-b584-48a6570026cf',
+ )
+ validation_enabled: bool = Field(
+ ..., description='Whether the validation is enabled', example=True
+ )
+ validation_schema: ValidationSchema = Field(
+ ..., description='The schema to validate the endpoint against', example={}
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/field_validation_model.py b/backend-services/models/field_validation_model.py
index 7f7803a..9b35a1e 100644
--- a/backend-services/models/field_validation_model.py
+++ b/backend-services/models/field_validation_model.py
@@ -4,17 +4,31 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from typing import List, Union, Optional, Dict, Any
+from typing import Any, Optional
+
from pydantic import BaseModel, Field
+
class FieldValidation(BaseModel):
required: bool = Field(..., description='Whether the field is required')
- type: str = Field(..., description='Expected data type (string, number, boolean, array, object)')
- min: Optional[Union[int, float]] = Field(None, description='Minimum value for numbers or minimum length for strings/arrays')
- max: Optional[Union[int, float]] = Field(None, description='Maximum value for numbers or maximum length for strings/arrays')
- pattern: Optional[str] = Field(None, description='Regex pattern for string validation')
- enum: Optional[List[Any]] = Field(None, description='List of allowed values')
- format: Optional[str] = Field(None, description='Format validation (email, url, date, datetime, uuid, etc.)')
- custom_validator: Optional[str] = Field(None, description='Custom validation function name')
- nested_schema: Optional[Dict[str, 'FieldValidation']] = Field(None, description='Validation schema for nested objects')
- array_items: Optional['FieldValidation'] = Field(None, description='Validation schema for array items')
\ No newline at end of file
+ type: str = Field(
+ ..., description='Expected data type (string, number, boolean, array, object)'
+ )
+ min: int | float | None = Field(
+ None, description='Minimum value for numbers or minimum length for strings/arrays'
+ )
+ max: int | float | None = Field(
+ None, description='Maximum value for numbers or maximum length for strings/arrays'
+ )
+ pattern: str | None = Field(None, description='Regex pattern for string validation')
+ enum: list[Any] | None = Field(None, description='List of allowed values')
+ format: str | None = Field(
+ None, description='Format validation (email, url, date, datetime, uuid, etc.)'
+ )
+ custom_validator: str | None = Field(None, description='Custom validation function name')
+ nested_schema: dict[str, 'FieldValidation'] | None = Field(
+ None, description='Validation schema for nested objects'
+ )
+ array_items: Optional['FieldValidation'] = Field(
+ None, description='Validation schema for array items'
+ )
diff --git a/backend-services/models/group_model_response.py b/backend-services/models/group_model_response.py
index e040d3f..6618a08 100644
--- a/backend-services/models/group_model_response.py
+++ b/backend-services/models/group_model_response.py
@@ -5,13 +5,22 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class GroupModelResponse(BaseModel):
-
- group_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the group', example='client-1-group')
- group_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the group', example='Group for client 1')
- api_access: Optional[List[str]] = Field(None, description='List of APIs the group can access', example=['customer/v1'])
+ group_name: str | None = Field(
+ None, min_length=1, max_length=50, description='Name of the group', example='client-1-group'
+ )
+ group_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the group',
+ example='Group for client 1',
+ )
+ api_access: list[str] | None = Field(
+ None, description='List of APIs the group can access', example=['customer/v1']
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/rate_limit_models.py b/backend-services/models/rate_limit_models.py
index bd7b4c4..0f92e60 100644
--- a/backend-services/models/rate_limit_models.py
+++ b/backend-services/models/rate_limit_models.py
@@ -9,59 +9,64 @@ This module defines the data structures for the rate limiting system including:
"""
from dataclasses import dataclass, field
-from typing import Optional, Dict, List, Any
-from enum import Enum
from datetime import datetime
-
+from enum import Enum
+from typing import Any
# ============================================================================
# ENUMS
# ============================================================================
+
class RuleType(Enum):
"""Types of rate limit rules"""
- PER_USER = "per_user"
- PER_API = "per_api"
- PER_ENDPOINT = "per_endpoint"
- PER_IP = "per_ip"
- PER_USER_API = "per_user_api" # Combined: specific user on specific API
- PER_USER_ENDPOINT = "per_user_endpoint" # Combined: specific user on specific endpoint
- GLOBAL = "global" # Global rate limit for all requests
+
+ PER_USER = 'per_user'
+ PER_API = 'per_api'
+ PER_ENDPOINT = 'per_endpoint'
+ PER_IP = 'per_ip'
+ PER_USER_API = 'per_user_api' # Combined: specific user on specific API
+ PER_USER_ENDPOINT = 'per_user_endpoint' # Combined: specific user on specific endpoint
+ GLOBAL = 'global' # Global rate limit for all requests
class TimeWindow(Enum):
"""Time windows for rate limiting"""
- SECOND = "second"
- MINUTE = "minute"
- HOUR = "hour"
- DAY = "day"
- MONTH = "month"
+
+ SECOND = 'second'
+ MINUTE = 'minute'
+ HOUR = 'hour'
+ DAY = 'day'
+ MONTH = 'month'
class TierName(Enum):
"""Predefined tier names"""
- FREE = "free"
- PRO = "pro"
- ENTERPRISE = "enterprise"
- CUSTOM = "custom"
+
+ FREE = 'free'
+ PRO = 'pro'
+ ENTERPRISE = 'enterprise'
+ CUSTOM = 'custom'
class QuotaType(Enum):
"""Types of quotas"""
- REQUESTS = "requests"
- BANDWIDTH = "bandwidth"
- COMPUTE_TIME = "compute_time"
+
+ REQUESTS = 'requests'
+ BANDWIDTH = 'bandwidth'
+ COMPUTE_TIME = 'compute_time'
# ============================================================================
# RATE LIMIT RULE MODELS
# ============================================================================
+
@dataclass
class RateLimitRule:
"""
Defines a rate limiting rule
-
+
Examples:
# Per-user rule: 100 requests per minute
RateLimitRule(
@@ -71,7 +76,7 @@ class RateLimitRule:
limit=100,
burst_allowance=20
)
-
+
# Per-API rule: 1000 requests per hour
RateLimitRule(
rule_id="rule_002",
@@ -81,24 +86,25 @@ class RateLimitRule:
limit=1000
)
"""
+
rule_id: str
rule_type: RuleType
time_window: TimeWindow
limit: int # Maximum requests allowed in time window
-
+
# Optional fields
- target_identifier: Optional[str] = None # User ID, API name, endpoint URI, or IP
+ target_identifier: str | None = None # User ID, API name, endpoint URI, or IP
burst_allowance: int = 0 # Additional requests allowed for bursts
priority: int = 0 # Higher priority rules are checked first
enabled: bool = True
-
+
# Metadata
- created_at: Optional[datetime] = None
- updated_at: Optional[datetime] = None
- created_by: Optional[str] = None
- description: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
+ created_at: datetime | None = None
+ updated_at: datetime | None = None
+ created_by: str | None = None
+ description: str | None = None
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for storage"""
return {
'rule_id': self.rule_id,
@@ -112,11 +118,11 @@ class RateLimitRule:
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
'created_by': self.created_by,
- 'description': self.description
+ 'description': self.description,
}
-
+
@classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'RateLimitRule':
+ def from_dict(cls, data: dict[str, Any]) -> 'RateLimitRule':
"""Create from dictionary"""
return cls(
rule_id=data['rule_id'],
@@ -127,10 +133,14 @@ class RateLimitRule:
burst_allowance=data.get('burst_allowance', 0),
priority=data.get('priority', 0),
enabled=data.get('enabled', True),
- created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None,
- updated_at=datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else None,
+ created_at=datetime.fromisoformat(data['created_at'])
+ if data.get('created_at')
+ else None,
+ updated_at=datetime.fromisoformat(data['updated_at'])
+ if data.get('updated_at')
+ else None,
created_by=data.get('created_by'),
- description=data.get('description')
+ description=data.get('description'),
)
@@ -138,31 +148,33 @@ class RateLimitRule:
# TIER/PLAN MODELS
# ============================================================================
+
@dataclass
class TierLimits:
"""Rate limits and quotas for a specific tier"""
+
# Rate limits (requests per time window)
- requests_per_second: Optional[int] = None
- requests_per_minute: Optional[int] = None
- requests_per_hour: Optional[int] = None
- requests_per_day: Optional[int] = None
- requests_per_month: Optional[int] = None
-
+ requests_per_second: int | None = None
+ requests_per_minute: int | None = None
+ requests_per_hour: int | None = None
+ requests_per_day: int | None = None
+ requests_per_month: int | None = None
+
# Burst allowances
burst_per_second: int = 0
burst_per_minute: int = 0
burst_per_hour: int = 0
-
+
# Quotas
- monthly_request_quota: Optional[int] = None
- daily_request_quota: Optional[int] = None
- monthly_bandwidth_quota: Optional[int] = None # In bytes
-
+ monthly_request_quota: int | None = None
+ daily_request_quota: int | None = None
+ monthly_bandwidth_quota: int | None = None # In bytes
+
# Throttling configuration
enable_throttling: bool = False # If true, queue/delay requests; if false, hard reject (429)
max_queue_time_ms: int = 5000 # Maximum time to queue a request before rejecting (milliseconds)
-
- def to_dict(self) -> Dict[str, Any]:
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary"""
return {
'requests_per_second': self.requests_per_second,
@@ -177,11 +189,11 @@ class TierLimits:
'daily_request_quota': self.daily_request_quota,
'monthly_bandwidth_quota': self.monthly_bandwidth_quota,
'enable_throttling': self.enable_throttling,
- 'max_queue_time_ms': self.max_queue_time_ms
+ 'max_queue_time_ms': self.max_queue_time_ms,
}
-
+
@classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'TierLimits':
+ def from_dict(cls, data: dict[str, Any]) -> 'TierLimits':
"""Create from dictionary"""
return cls(**data)
@@ -190,7 +202,7 @@ class TierLimits:
class Tier:
"""
Defines a tier/plan with associated rate limits and quotas
-
+
Examples:
# Free tier
Tier(
@@ -203,7 +215,7 @@ class Tier:
daily_request_quota=10000
)
)
-
+
# Pro tier
Tier(
tier_id="tier_pro",
@@ -218,24 +230,25 @@ class Tier:
price_monthly=49.99
)
"""
+
tier_id: str
name: TierName
display_name: str
limits: TierLimits
-
+
# Optional fields
- description: Optional[str] = None
- price_monthly: Optional[float] = None
- price_yearly: Optional[float] = None
- features: List[str] = field(default_factory=list)
+ description: str | None = None
+ price_monthly: float | None = None
+ price_yearly: float | None = None
+ features: list[str] = field(default_factory=list)
is_default: bool = False
enabled: bool = True
-
+
# Metadata
- created_at: Optional[datetime] = None
- updated_at: Optional[datetime] = None
-
- def to_dict(self) -> Dict[str, Any]:
+ created_at: datetime | None = None
+ updated_at: datetime | None = None
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for storage"""
return {
'tier_id': self.tier_id,
@@ -249,11 +262,11 @@ class Tier:
'is_default': self.is_default,
'enabled': self.enabled,
'created_at': self.created_at.isoformat() if self.created_at else None,
- 'updated_at': self.updated_at.isoformat() if self.updated_at else None
+ 'updated_at': self.updated_at.isoformat() if self.updated_at else None,
}
-
+
@classmethod
- def from_dict(cls, data: Dict[str, Any]) -> 'Tier':
+ def from_dict(cls, data: dict[str, Any]) -> 'Tier':
"""Create from dictionary"""
return cls(
tier_id=data['tier_id'],
@@ -266,30 +279,35 @@ class Tier:
features=data.get('features', []),
is_default=data.get('is_default', False),
enabled=data.get('enabled', True),
- created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None,
- updated_at=datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else None
+ created_at=datetime.fromisoformat(data['created_at'])
+ if data.get('created_at')
+ else None,
+ updated_at=datetime.fromisoformat(data['updated_at'])
+ if data.get('updated_at')
+ else None,
)
@dataclass
class UserTierAssignment:
"""Assigns a user to a tier with optional overrides"""
+
user_id: str
tier_id: str
-
+
# Optional overrides (override tier defaults for this specific user)
- override_limits: Optional[TierLimits] = None
-
+ override_limits: TierLimits | None = None
+
# Scheduling
- effective_from: Optional[datetime] = None
- effective_until: Optional[datetime] = None
-
+ effective_from: datetime | None = None
+ effective_until: datetime | None = None
+
# Metadata
- assigned_at: Optional[datetime] = None
- assigned_by: Optional[str] = None
- notes: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
+ assigned_at: datetime | None = None
+ assigned_by: str | None = None
+ notes: str | None = None
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary"""
return {
'user_id': self.user_id,
@@ -299,7 +317,7 @@ class UserTierAssignment:
'effective_until': self.effective_until.isoformat() if self.effective_until else None,
'assigned_at': self.assigned_at.isoformat() if self.assigned_at else None,
'assigned_by': self.assigned_by,
- 'notes': self.notes
+ 'notes': self.notes,
}
@@ -307,41 +325,43 @@ class UserTierAssignment:
# QUOTA TRACKING MODELS
# ============================================================================
+
@dataclass
class QuotaUsage:
"""
Tracks current quota usage for a user/API/endpoint
-
+
This is stored in Redis for real-time tracking
"""
+
key: str # Redis key (e.g., "quota:user:john_doe:month:2025-12")
quota_type: QuotaType
current_usage: int
limit: int
reset_at: datetime # When the quota resets
-
+
# Optional fields
burst_usage: int = 0 # Burst tokens used
burst_limit: int = 0
-
+
@property
def remaining(self) -> int:
"""Calculate remaining quota"""
return max(0, self.limit - self.current_usage)
-
+
@property
def percentage_used(self) -> float:
"""Calculate percentage of quota used"""
if self.limit == 0:
return 0.0
return (self.current_usage / self.limit) * 100
-
+
@property
def is_exhausted(self) -> bool:
"""Check if quota is exhausted"""
return self.current_usage >= self.limit
-
- def to_dict(self) -> Dict[str, Any]:
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary"""
return {
'key': self.key,
@@ -353,7 +373,7 @@ class QuotaUsage:
'reset_at': self.reset_at.isoformat(),
'burst_usage': self.burst_usage,
'burst_limit': self.burst_limit,
- 'is_exhausted': self.is_exhausted
+ 'is_exhausted': self.is_exhausted,
}
@@ -361,9 +381,10 @@ class QuotaUsage:
class RateLimitCounter:
"""
Real-time counter for rate limiting (stored in Redis)
-
+
Uses sliding window counter algorithm
"""
+
key: str # Redis key (e.g., "ratelimit:user:john_doe:minute:1701504000")
window_start: int # Unix timestamp
window_size: int # Window size in seconds
@@ -371,23 +392,23 @@ class RateLimitCounter:
limit: int # Maximum allowed requests
burst_count: int = 0 # Burst tokens used
burst_limit: int = 0
-
+
@property
def remaining(self) -> int:
"""Calculate remaining requests"""
return max(0, self.limit - self.count)
-
+
@property
def is_limited(self) -> bool:
"""Check if rate limit is exceeded"""
return self.count >= self.limit
-
+
@property
def reset_at(self) -> int:
"""Calculate when the window resets (Unix timestamp)"""
return self.window_start + self.window_size
-
- def to_dict(self) -> Dict[str, Any]:
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary"""
return {
'key': self.key,
@@ -399,7 +420,7 @@ class RateLimitCounter:
'reset_at': self.reset_at,
'burst_count': self.burst_count,
'burst_limit': self.burst_limit,
- 'is_limited': self.is_limited
+ 'is_limited': self.is_limited,
}
@@ -407,28 +428,30 @@ class RateLimitCounter:
# HISTORICAL TRACKING MODELS
# ============================================================================
+
@dataclass
class UsageHistoryRecord:
"""
Historical usage record for analytics
-
+
Stored in time-series database or MongoDB
"""
+
timestamp: datetime
- user_id: Optional[str] = None
- api_name: Optional[str] = None
- endpoint_uri: Optional[str] = None
- ip_address: Optional[str] = None
-
+ user_id: str | None = None
+ api_name: str | None = None
+ endpoint_uri: str | None = None
+ ip_address: str | None = None
+
# Metrics
request_count: int = 0
blocked_count: int = 0 # Requests blocked by rate limit
burst_used: int = 0
-
+
# Aggregation period
- period: str = "minute" # minute, hour, day
-
- def to_dict(self) -> Dict[str, Any]:
+ period: str = 'minute' # minute, hour, day
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary"""
return {
'timestamp': self.timestamp.isoformat(),
@@ -439,7 +462,7 @@ class UsageHistoryRecord:
'request_count': self.request_count,
'blocked_count': self.blocked_count,
'burst_used': self.burst_used,
- 'period': self.period
+ 'period': self.period,
}
@@ -447,42 +470,44 @@ class UsageHistoryRecord:
# RESPONSE MODELS
# ============================================================================
+
@dataclass
class RateLimitInfo:
"""
Information about current rate limit status
-
+
Returned in API responses and headers
"""
+
limit: int
remaining: int
reset_at: int # Unix timestamp
- retry_after: Optional[int] = None # Seconds until retry (when limited)
-
+ retry_after: int | None = None # Seconds until retry (when limited)
+
# Additional info
burst_limit: int = 0
burst_remaining: int = 0
- tier: Optional[str] = None
-
- def to_headers(self) -> Dict[str, str]:
+ tier: str | None = None
+
+ def to_headers(self) -> dict[str, str]:
"""Convert to HTTP headers"""
headers = {
'X-RateLimit-Limit': str(self.limit),
'X-RateLimit-Remaining': str(self.remaining),
- 'X-RateLimit-Reset': str(self.reset_at)
+ 'X-RateLimit-Reset': str(self.reset_at),
}
-
+
if self.retry_after is not None:
headers['X-RateLimit-Retry-After'] = str(self.retry_after)
headers['Retry-After'] = str(self.retry_after)
-
+
if self.burst_limit > 0:
headers['X-RateLimit-Burst-Limit'] = str(self.burst_limit)
headers['X-RateLimit-Burst-Remaining'] = str(self.burst_remaining)
-
+
return headers
-
- def to_dict(self) -> Dict[str, Any]:
+
+ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary"""
return {
'limit': self.limit,
@@ -491,7 +516,7 @@ class RateLimitInfo:
'retry_after': self.retry_after,
'burst_limit': self.burst_limit,
'burst_remaining': self.burst_remaining,
- 'tier': self.tier
+ 'tier': self.tier,
}
@@ -499,6 +524,7 @@ class RateLimitInfo:
# HELPER FUNCTIONS
# ============================================================================
+
def get_time_window_seconds(window: TimeWindow) -> int:
"""Convert time window enum to seconds"""
mapping = {
@@ -506,39 +532,32 @@ def get_time_window_seconds(window: TimeWindow) -> int:
TimeWindow.MINUTE: 60,
TimeWindow.HOUR: 3600,
TimeWindow.DAY: 86400,
- TimeWindow.MONTH: 2592000 # 30 days
+ TimeWindow.MONTH: 2592000, # 30 days
}
return mapping[window]
def generate_redis_key(
- rule_type: RuleType,
- identifier: str,
- window: TimeWindow,
- window_start: int
+ rule_type: RuleType, identifier: str, window: TimeWindow, window_start: int
) -> str:
"""
Generate Redis key for rate limit counter
-
+
Examples:
generate_redis_key(RuleType.PER_USER, "john_doe", TimeWindow.MINUTE, 1701504000)
# Returns: "ratelimit:user:john_doe:minute:1701504000"
"""
type_prefix = rule_type.value.replace('per_', '')
window_name = window.value
- return f"ratelimit:{type_prefix}:{identifier}:{window_name}:{window_start}"
+ return f'ratelimit:{type_prefix}:{identifier}:{window_name}:{window_start}'
-def generate_quota_key(
- user_id: str,
- quota_type: QuotaType,
- period: str
-) -> str:
+def generate_quota_key(user_id: str, quota_type: QuotaType, period: str) -> str:
"""
Generate Redis key for quota tracking
-
+
Examples:
generate_quota_key("john_doe", QuotaType.REQUESTS, "2025-12")
# Returns: "quota:user:john_doe:requests:month:2025-12"
"""
- return f"quota:user:{user_id}:{quota_type.value}:month:{period}"
+ return f'quota:user:{user_id}:{quota_type.value}:month:{period}'
diff --git a/backend-services/models/request_model.py b/backend-services/models/request_model.py
index 917d6eb..35cf0d2 100644
--- a/backend-services/models/request_model.py
+++ b/backend-services/models/request_model.py
@@ -1,10 +1,10 @@
from pydantic import BaseModel
-from typing import Dict, Optional
+
class RequestModel(BaseModel):
method: str
path: str
- headers: Dict[str, str]
- query_params: Dict[str, str]
- identity: Optional[str] = None
- body: Optional[str] = None
\ No newline at end of file
+ headers: dict[str, str]
+ query_params: dict[str, str]
+ identity: str | None = None
+ body: str | None = None
diff --git a/backend-services/models/response_model.py b/backend-services/models/response_model.py
index 8cfeb1c..8761b24 100644
--- a/backend-services/models/response_model.py
+++ b/backend-services/models/response_model.py
@@ -1,13 +1,13 @@
from pydantic import BaseModel, Field
-from typing import Optional, Union
+
class ResponseModel(BaseModel):
status_code: int = Field(None)
- response_headers: Optional[dict] = Field(None)
+ response_headers: dict | None = Field(None)
- response: Optional[Union[dict, list, str]] = Field(None)
- message: Optional[str] = Field(None, min_length=1, max_length=255)
+ response: dict | list | str | None = Field(None)
+ message: str | None = Field(None, min_length=1, max_length=255)
- error_code: Optional[str] = Field(None, min_length=1, max_length=255)
- error_message: Optional[str] = Field(None, min_length=1, max_length=255)
\ No newline at end of file
+ error_code: str | None = Field(None, min_length=1, max_length=255)
+ error_message: str | None = Field(None, min_length=1, max_length=255)
diff --git a/backend-services/models/role_model_response.py b/backend-services/models/role_model_response.py
index 1af719a..e01e976 100644
--- a/backend-services/models/role_model_response.py
+++ b/backend-services/models/role_model_response.py
@@ -5,28 +5,57 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
+
class RoleModelResponse(BaseModel):
-
- role_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the role', example='admin')
- role_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the role', example='Administrator role with full access')
- manage_users: Optional[bool] = Field(None, description='Permission to manage users', example=True)
- manage_apis: Optional[bool] = Field(None, description='Permission to manage APIs', example=True)
- manage_endpoints: Optional[bool] = Field(None, description='Permission to manage endpoints', example=True)
- manage_groups: Optional[bool] = Field(None, description='Permission to manage groups', example=True)
- manage_roles: Optional[bool] = Field(None, description='Permission to manage roles', example=True)
- manage_routings: Optional[bool] = Field(None, description='Permission to manage routings', example=True)
- manage_gateway: Optional[bool] = Field(None, description='Permission to manage gateway', example=True)
- manage_subscriptions: Optional[bool] = Field(None, description='Permission to manage subscriptions', example=True)
- manage_security: Optional[bool] = Field(None, description='Permission to manage security settings', example=True)
- manage_tiers: Optional[bool] = Field(None, description='Permission to manage pricing tiers', example=True)
- manage_rate_limits: Optional[bool] = Field(None, description='Permission to manage rate limiting rules', example=True)
- manage_credits: Optional[bool] = Field(None, description='Permission to manage API credits', example=True)
- manage_auth: Optional[bool] = Field(None, description='Permission to manage auth (revoke tokens/disable users)', example=True)
- view_analytics: Optional[bool] = Field(None, description='Permission to view analytics dashboard', example=True)
- view_logs: Optional[bool] = Field(None, description='Permission to view logs', example=True)
- export_logs: Optional[bool] = Field(None, description='Permission to export logs', example=True)
+ role_name: str | None = Field(
+ None, min_length=1, max_length=50, description='Name of the role', example='admin'
+ )
+ role_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the role',
+ example='Administrator role with full access',
+ )
+ manage_users: bool | None = Field(None, description='Permission to manage users', example=True)
+ manage_apis: bool | None = Field(None, description='Permission to manage APIs', example=True)
+ manage_endpoints: bool | None = Field(
+ None, description='Permission to manage endpoints', example=True
+ )
+ manage_groups: bool | None = Field(
+ None, description='Permission to manage groups', example=True
+ )
+ manage_roles: bool | None = Field(None, description='Permission to manage roles', example=True)
+ manage_routings: bool | None = Field(
+ None, description='Permission to manage routings', example=True
+ )
+ manage_gateway: bool | None = Field(
+ None, description='Permission to manage gateway', example=True
+ )
+ manage_subscriptions: bool | None = Field(
+ None, description='Permission to manage subscriptions', example=True
+ )
+ manage_security: bool | None = Field(
+ None, description='Permission to manage security settings', example=True
+ )
+ manage_tiers: bool | None = Field(
+ None, description='Permission to manage pricing tiers', example=True
+ )
+ manage_rate_limits: bool | None = Field(
+ None, description='Permission to manage rate limiting rules', example=True
+ )
+ manage_credits: bool | None = Field(
+ None, description='Permission to manage API credits', example=True
+ )
+ manage_auth: bool | None = Field(
+ None, description='Permission to manage auth (revoke tokens/disable users)', example=True
+ )
+ view_analytics: bool | None = Field(
+ None, description='Permission to view analytics dashboard', example=True
+ )
+ view_logs: bool | None = Field(None, description='Permission to view logs', example=True)
+ export_logs: bool | None = Field(None, description='Permission to export logs', example=True)
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/routing_model_response.py b/backend-services/models/routing_model_response.py
index 55a8cb8..e3f6c8e 100644
--- a/backend-services/models/routing_model_response.py
+++ b/backend-services/models/routing_model_response.py
@@ -5,16 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
+
class RoutingModelResponse(BaseModel):
+ routing_name: str | None = Field(
+ None,
+ min_length=1,
+ max_length=50,
+ description='Name of the routing',
+ example='customer-routing',
+ )
+ routing_servers: list[str] | None = Field(
+ None,
+ min_items=1,
+ description='List of backend servers for the routing',
+ example=['http://localhost:8080', 'http://localhost:8081'],
+ )
+ routing_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the routing',
+ example='Routing for customer API',
+ )
- routing_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the routing', example='customer-routing')
- routing_servers : Optional[list[str]] = Field(None, min_items=1, description='List of backend servers for the routing', example=['http://localhost:8080', 'http://localhost:8081'])
- routing_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the routing', example='Routing for customer API')
-
- client_key: Optional[str] = Field(None, min_length=1, max_length=50, description='Client key for the routing', example='client-1')
- server_index: Optional[int] = Field(None, exclude=True, ge=0, description='Index of the server to route to', example=0)
+ client_key: str | None = Field(
+ None,
+ min_length=1,
+ max_length=50,
+ description='Client key for the routing',
+ example='client-1',
+ )
+ server_index: int | None = Field(
+ None, exclude=True, ge=0, description='Index of the server to route to', example=0
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/security_settings_model.py b/backend-services/models/security_settings_model.py
index 358eaca..8a30d54 100644
--- a/backend-services/models/security_settings_model.py
+++ b/backend-services/models/security_settings_model.py
@@ -1,12 +1,24 @@
from pydantic import BaseModel, Field
-from typing import Optional, List
+
class SecuritySettingsModel(BaseModel):
- enable_auto_save: Optional[bool] = Field(default=None)
- auto_save_frequency_seconds: Optional[int] = Field(default=None, ge=60, description='How often to auto-save memory dump (seconds)')
- dump_path: Optional[str] = Field(default=None, description='Path to write encrypted memory dumps')
- ip_whitelist: Optional[List[str]] = Field(default=None, description='List of allowed IPs/CIDRs. If non-empty, only these are allowed.')
- ip_blacklist: Optional[List[str]] = Field(default=None, description='List of blocked IPs/CIDRs')
- trust_x_forwarded_for: Optional[bool] = Field(default=None, description='If true, use X-Forwarded-For header for client IP')
- xff_trusted_proxies: Optional[List[str]] = Field(default=None, description='IPs/CIDRs of proxies allowed to set client IP headers (XFF/X-Real-IP). Empty means trust all when enabled.')
- allow_localhost_bypass: Optional[bool] = Field(default=None, description='Allow direct localhost (::1/127.0.0.1) to bypass IP allow/deny lists when no forwarding headers are present')
+ enable_auto_save: bool | None = Field(default=None)
+ auto_save_frequency_seconds: int | None = Field(
+ default=None, ge=60, description='How often to auto-save memory dump (seconds)'
+ )
+ dump_path: str | None = Field(default=None, description='Path to write encrypted memory dumps')
+ ip_whitelist: list[str] | None = Field(
+ default=None, description='List of allowed IPs/CIDRs. If non-empty, only these are allowed.'
+ )
+ ip_blacklist: list[str] | None = Field(default=None, description='List of blocked IPs/CIDRs')
+ trust_x_forwarded_for: bool | None = Field(
+ default=None, description='If true, use X-Forwarded-For header for client IP'
+ )
+ xff_trusted_proxies: list[str] | None = Field(
+ default=None,
+ description='IPs/CIDRs of proxies allowed to set client IP headers (XFF/X-Real-IP). Empty means trust all when enabled.',
+ )
+ allow_localhost_bypass: bool | None = Field(
+ default=None,
+ description='Allow direct localhost (::1/127.0.0.1) to bypass IP allow/deny lists when no forwarding headers are present',
+ )
diff --git a/backend-services/models/subscribe_model.py b/backend-services/models/subscribe_model.py
index f110262..3247077 100644
--- a/backend-services/models/subscribe_model.py
+++ b/backend-services/models/subscribe_model.py
@@ -6,11 +6,29 @@ See https://github.com/pypeople-dev/doorman for more information
from pydantic import BaseModel, Field
-class SubscribeModel(BaseModel):
- username: str = Field(..., min_length=3, max_length=50, description='Username of the subscriber', example='client-1')
- api_name: str = Field(..., min_length=3, max_length=50, description='Name of the API to subscribe to', example='customer')
- api_version: str = Field(..., min_length=1, max_length=5, description='Version of the API to subscribe to', example='v1')
+class SubscribeModel(BaseModel):
+ username: str = Field(
+ ...,
+ min_length=3,
+ max_length=50,
+ description='Username of the subscriber',
+ example='client-1',
+ )
+ api_name: str = Field(
+ ...,
+ min_length=3,
+ max_length=50,
+ description='Name of the API to subscribe to',
+ example='customer',
+ )
+ api_version: str = Field(
+ ...,
+ min_length=1,
+ max_length=5,
+ description='Version of the API to subscribe to',
+ example='v1',
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/update_api_model.py b/backend-services/models/update_api_model.py
index 1a64859..8d182c6 100644
--- a/backend-services/models/update_api_model.py
+++ b/backend-services/models/update_api_model.py
@@ -5,44 +5,120 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class UpdateApiModel(BaseModel):
+ api_name: str | None = Field(
+ None, min_length=1, max_length=25, description='Name of the API', example='customer'
+ )
+ api_version: str | None = Field(
+ None, min_length=1, max_length=8, description='Version of the API', example='v1'
+ )
+ api_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=127,
+ description='Description of the API',
+ example='New customer onboarding API',
+ )
+ api_allowed_roles: list[str] | None = Field(
+ None, description='Allowed user roles for the API', example=['admin', 'user']
+ )
+ api_allowed_groups: list[str] | None = Field(
+ None, description='Allowed user groups for the API', example=['admin', 'client-1-group']
+ )
+ api_servers: list[str] | None = Field(
+ None,
+ description='List of backend servers for the API',
+ example=['http://localhost:8080', 'http://localhost:8081'],
+ )
+ api_type: str | None = Field(
+ None, description="Type of the API. Valid values: 'REST'", example='REST'
+ )
+ api_authorization_field_swap: str | None = Field(
+ None,
+ description='Header to swap for backend authorization header',
+ example='backend-auth-header',
+ )
+ api_allowed_headers: list[str] | None = Field(
+ None, description='Allowed headers for the API', example=['Content-Type', 'Authorization']
+ )
+ api_allowed_retry_count: int | None = Field(
+ None, description='Number of allowed retries for the API', example=0
+ )
+ api_grpc_package: str | None = Field(
+ None,
+ description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.',
+ example='my.pkg',
+ )
+ api_grpc_allowed_packages: list[str] | None = Field(
+ None,
+ description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.',
+ example=['customer_v1'],
+ )
+ api_grpc_allowed_services: list[str] | None = Field(
+ None,
+ description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.',
+ example=['Greeter'],
+ )
+ api_grpc_allowed_methods: list[str] | None = Field(
+ None,
+ description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.',
+ example=['Greeter.SayHello'],
+ )
+ api_credits_enabled: bool | None = Field(
+ False, description='Enable credit-based authentication for the API', example=True
+ )
+ api_credit_group: str | None = Field(
+ None, description='API credit group for the API credits', example='ai-group-1'
+ )
+ active: bool | None = Field(None, description='Whether the API is active (enabled)')
+ api_id: str | None = Field(
+ None, description='Unique identifier for the API, auto-generated', example=None
+ )
+ api_path: str | None = Field(
+        None, description='Unique path for the API, auto-generated', example=None
+ )
- api_name: Optional[str] = Field(None, min_length=1, max_length=25, description='Name of the API', example='customer')
- api_version: Optional[str] = Field(None, min_length=1, max_length=8, description='Version of the API', example='v1')
- api_description: Optional[str] = Field(None, min_length=1, max_length=127, description='Description of the API', example='New customer onboarding API')
- api_allowed_roles: Optional[List[str]] = Field(None, description='Allowed user roles for the API', example=['admin', 'user'])
- api_allowed_groups: Optional[List[str]] = Field(None, description='Allowed user groups for the API' , example=['admin', 'client-1-group'])
- api_servers: Optional[List[str]] = Field(None, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081'])
- api_type: Optional[str] = Field(None, description="Type of the API. Valid values: 'REST'", example='REST')
- api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
- api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
- api_allowed_retry_count: Optional[int] = Field(None, description='Number of allowed retries for the API', example=0)
- api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg')
- api_grpc_allowed_packages: Optional[List[str]] = Field(None, description='Allow-list of gRPC package/module base names (no dots). If set, requests must match one of these.', example=['customer_v1'])
- api_grpc_allowed_services: Optional[List[str]] = Field(None, description='Allow-list of gRPC service names (e.g., Greeter). If set, only these services are permitted.', example=['Greeter'])
- api_grpc_allowed_methods: Optional[List[str]] = Field(None, description='Allow-list of gRPC methods as Service.Method strings. If set, only these methods are permitted.', example=['Greeter.SayHello'])
- api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True)
- api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1')
- active: Optional[bool] = Field(None, description='Whether the API is active (enabled)')
- api_id: Optional[str] = Field(None, description='Unique identifier for the API, auto-generated', example=None)
- api_path: Optional[str] = Field(None, description='Unqiue path for the API, auto-generated', example=None)
+ api_cors_allow_origins: list[str] | None = Field(
+ None,
+ description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.",
+ )
+ api_cors_allow_methods: list[str] | None = Field(
+ None,
+ description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])",
+ )
+ api_cors_allow_headers: list[str] | None = Field(
+ None,
+ description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])",
+ )
+ api_cors_allow_credentials: bool | None = Field(
+ None, description='Whether to include Access-Control-Allow-Credentials=true in responses'
+ )
+ api_cors_expose_headers: list[str] | None = Field(
+ None,
+ description='Response headers to expose to the browser via Access-Control-Expose-Headers',
+ )
- api_cors_allow_origins: Optional[List[str]] = Field(None, description="Allowed origins for CORS (e.g., ['http://localhost:3000']). Use ['*'] to allow all.")
- api_cors_allow_methods: Optional[List[str]] = Field(None, description="Allowed methods for CORS preflight (e.g., ['GET','POST','PUT','DELETE','OPTIONS'])")
- api_cors_allow_headers: Optional[List[str]] = Field(None, description="Allowed request headers for CORS preflight (e.g., ['Content-Type','Authorization'])")
- api_cors_allow_credentials: Optional[bool] = Field(None, description='Whether to include Access-Control-Allow-Credentials=true in responses')
- api_cors_expose_headers: Optional[List[str]] = Field(None, description='Response headers to expose to the browser via Access-Control-Expose-Headers')
+ api_public: bool | None = Field(
+ None, description='If true, this API can be called without authentication or subscription'
+ )
- api_public: Optional[bool] = Field(None, description='If true, this API can be called without authentication or subscription')
+ api_auth_required: bool | None = Field(
+ None,
+ description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.',
+ )
- api_auth_required: Optional[bool] = Field(None, description='If true (default), JWT auth is required for this API when not public. If false, requests may be unauthenticated but must meet other checks as configured.')
-
- api_ip_mode: Optional[str] = Field(None, description="IP policy mode: 'allow_all' or 'whitelist'")
- api_ip_whitelist: Optional[List[str]] = Field(None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist')
- api_ip_blacklist: Optional[List[str]] = Field(None, description='IPs/CIDRs denied regardless of mode')
- api_trust_x_forwarded_for: Optional[bool] = Field(None, description='Override: trust X-Forwarded-For for this API')
+ api_ip_mode: str | None = Field(None, description="IP policy mode: 'allow_all' or 'whitelist'")
+ api_ip_whitelist: list[str] | None = Field(
+ None, description='Allowed IPs/CIDRs when api_ip_mode=whitelist'
+ )
+ api_ip_blacklist: list[str] | None = Field(
+ None, description='IPs/CIDRs denied regardless of mode'
+ )
+ api_trust_x_forwarded_for: bool | None = Field(
+ None, description='Override: trust X-Forwarded-For for this API'
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/update_endpoint_model.py b/backend-services/models/update_endpoint_model.py
index ce48867..c559e2b 100644
--- a/backend-services/models/update_endpoint_model.py
+++ b/backend-services/models/update_endpoint_model.py
@@ -5,18 +5,47 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional, List
+
class UpdateEndpointModel(BaseModel):
-
- api_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the API', example='customer')
- api_version: Optional[str] = Field(None, min_length=1, max_length=10, description='Version of the API', example='v1')
- endpoint_method: Optional[str] = Field(None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET')
- endpoint_uri: Optional[str] = Field(None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer')
- endpoint_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the endpoint', example='Get customer details')
- endpoint_servers: Optional[List[str]] = Field(None, description='Optional list of backend servers for this endpoint (overrides API servers)', example=['http://localhost:8082', 'http://localhost:8083'])
- api_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the API, auto-generated', example=None)
- endpoint_id: Optional[str] = Field(None, min_length=1, max_length=255, description='Unique identifier for the endpoint, auto-generated', example=None)
+ api_name: str | None = Field(
+ None, min_length=1, max_length=50, description='Name of the API', example='customer'
+ )
+ api_version: str | None = Field(
+ None, min_length=1, max_length=10, description='Version of the API', example='v1'
+ )
+ endpoint_method: str | None = Field(
+ None, min_length=1, max_length=10, description='HTTP method for the endpoint', example='GET'
+ )
+ endpoint_uri: str | None = Field(
+ None, min_length=1, max_length=255, description='URI for the endpoint', example='/customer'
+ )
+ endpoint_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the endpoint',
+ example='Get customer details',
+ )
+ endpoint_servers: list[str] | None = Field(
+ None,
+ description='Optional list of backend servers for this endpoint (overrides API servers)',
+ example=['http://localhost:8082', 'http://localhost:8083'],
+ )
+ api_id: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Unique identifier for the API, auto-generated',
+ example=None,
+ )
+ endpoint_id: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Unique identifier for the endpoint, auto-generated',
+ example=None,
+ )
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/update_endpoint_validation_model.py b/backend-services/models/update_endpoint_validation_model.py
index def3c12..08d97a1 100644
--- a/backend-services/models/update_endpoint_validation_model.py
+++ b/backend-services/models/update_endpoint_validation_model.py
@@ -8,10 +8,14 @@ from pydantic import BaseModel, Field
from models.validation_schema_model import ValidationSchema
-class UpdateEndpointValidationModel(BaseModel):
- validation_enabled: bool = Field(..., description='Whether the validation is enabled', example=True)
- validation_schema: ValidationSchema = Field(..., description='The schema to validate the endpoint against', example={})
+class UpdateEndpointValidationModel(BaseModel):
+ validation_enabled: bool = Field(
+ ..., description='Whether the validation is enabled', example=True
+ )
+ validation_schema: ValidationSchema = Field(
+ ..., description='The schema to validate the endpoint against', example={}
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/update_group_model.py b/backend-services/models/update_group_model.py
index 7552361..1b5d8c2 100644
--- a/backend-services/models/update_group_model.py
+++ b/backend-services/models/update_group_model.py
@@ -5,13 +5,22 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class UpdateGroupModel(BaseModel):
-
- group_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the group', example='client-1-group')
- group_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the group', example='Group for client 1')
- api_access: Optional[List[str]] = Field(None, description='List of APIs the group can access', example=['customer/v1'])
+ group_name: str | None = Field(
+ None, min_length=1, max_length=50, description='Name of the group', example='client-1-group'
+ )
+ group_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the group',
+ example='Group for client 1',
+ )
+ api_access: list[str] | None = Field(
+ None, description='List of APIs the group can access', example=['customer/v1']
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/update_password_model.py b/backend-services/models/update_password_model.py
index 12245c8..c280486 100644
--- a/backend-services/models/update_password_model.py
+++ b/backend-services/models/update_password_model.py
@@ -6,9 +6,15 @@ See https://github.com/pypeople-dev/doorman for more information
from pydantic import BaseModel, Field
-class UpdatePasswordModel(BaseModel):
- new_password: str = Field(..., min_length=6, max_length=36, description='New password of the user', example='NewPassword456!')
+class UpdatePasswordModel(BaseModel):
+ new_password: str = Field(
+ ...,
+ min_length=6,
+ max_length=36,
+ description='New password of the user',
+ example='NewPassword456!',
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/update_role_model.py b/backend-services/models/update_role_model.py
index 4d869ea..84ef80a 100644
--- a/backend-services/models/update_role_model.py
+++ b/backend-services/models/update_role_model.py
@@ -5,28 +5,57 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
+
class UpdateRoleModel(BaseModel):
-
- role_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the role', example='admin')
- role_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the role', example='Administrator role with full access')
- manage_users: Optional[bool] = Field(None, description='Permission to manage users', example=True)
- manage_apis: Optional[bool] = Field(None, description='Permission to manage APIs', example=True)
- manage_endpoints: Optional[bool] = Field(None, description='Permission to manage endpoints', example=True)
- manage_groups: Optional[bool] = Field(None, description='Permission to manage groups', example=True)
- manage_roles: Optional[bool] = Field(None, description='Permission to manage roles', example=True)
- manage_routings: Optional[bool] = Field(None, description='Permission to manage routings', example=True)
- manage_gateway: Optional[bool] = Field(None, description='Permission to manage gateway', example=True)
- manage_subscriptions: Optional[bool] = Field(None, description='Permission to manage subscriptions', example=True)
- manage_security: Optional[bool] = Field(None, description='Permission to manage security settings', example=True)
- manage_tiers: Optional[bool] = Field(None, description='Permission to manage pricing tiers', example=True)
- manage_rate_limits: Optional[bool] = Field(None, description='Permission to manage rate limiting rules', example=True)
- manage_credits: Optional[bool] = Field(None, description='Permission to manage credits', example=True)
- manage_auth: Optional[bool] = Field(None, description='Permission to manage auth (revoke tokens/disable users)', example=True)
- view_analytics: Optional[bool] = Field(None, description='Permission to view analytics dashboard', example=True)
- view_logs: Optional[bool] = Field(None, description='Permission to view logs', example=True)
- export_logs: Optional[bool] = Field(None, description='Permission to export logs', example=True)
+ role_name: str | None = Field(
+ None, min_length=1, max_length=50, description='Name of the role', example='admin'
+ )
+ role_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the role',
+ example='Administrator role with full access',
+ )
+ manage_users: bool | None = Field(None, description='Permission to manage users', example=True)
+ manage_apis: bool | None = Field(None, description='Permission to manage APIs', example=True)
+ manage_endpoints: bool | None = Field(
+ None, description='Permission to manage endpoints', example=True
+ )
+ manage_groups: bool | None = Field(
+ None, description='Permission to manage groups', example=True
+ )
+ manage_roles: bool | None = Field(None, description='Permission to manage roles', example=True)
+ manage_routings: bool | None = Field(
+ None, description='Permission to manage routings', example=True
+ )
+ manage_gateway: bool | None = Field(
+ None, description='Permission to manage gateway', example=True
+ )
+ manage_subscriptions: bool | None = Field(
+ None, description='Permission to manage subscriptions', example=True
+ )
+ manage_security: bool | None = Field(
+ None, description='Permission to manage security settings', example=True
+ )
+ manage_tiers: bool | None = Field(
+ None, description='Permission to manage pricing tiers', example=True
+ )
+ manage_rate_limits: bool | None = Field(
+ None, description='Permission to manage rate limiting rules', example=True
+ )
+ manage_credits: bool | None = Field(
+ None, description='Permission to manage credits', example=True
+ )
+ manage_auth: bool | None = Field(
+ None, description='Permission to manage auth (revoke tokens/disable users)', example=True
+ )
+ view_analytics: bool | None = Field(
+ None, description='Permission to view analytics dashboard', example=True
+ )
+ view_logs: bool | None = Field(None, description='Permission to view logs', example=True)
+ export_logs: bool | None = Field(None, description='Permission to export logs', example=True)
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/update_routing_model.py b/backend-services/models/update_routing_model.py
index 6e8824e..0042d1b 100644
--- a/backend-services/models/update_routing_model.py
+++ b/backend-services/models/update_routing_model.py
@@ -5,16 +5,40 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
+
class UpdateRoutingModel(BaseModel):
+ routing_name: str | None = Field(
+ None,
+ min_length=1,
+ max_length=50,
+ description='Name of the routing',
+ example='customer-routing',
+ )
+ routing_servers: list[str] | None = Field(
+ None,
+ min_items=1,
+ description='List of backend servers for the routing',
+ example=['http://localhost:8080', 'http://localhost:8081'],
+ )
+ routing_description: str | None = Field(
+ None,
+ min_length=1,
+ max_length=255,
+ description='Description of the routing',
+ example='Routing for customer API',
+ )
- routing_name: Optional[str] = Field(None, min_length=1, max_length=50, description='Name of the routing', example='customer-routing')
- routing_servers : Optional[list[str]] = Field(None, min_items=1, description='List of backend servers for the routing', example=['http://localhost:8080', 'http://localhost:8081'])
- routing_description: Optional[str] = Field(None, min_length=1, max_length=255, description='Description of the routing', example='Routing for customer API')
-
- client_key: Optional[str] = Field(None, min_length=1, max_length=50, description='Client key for the routing', example='client-1')
- server_index: Optional[int] = Field(None, exclude=True, ge=0, description='Index of the server to route to', example=0)
+ client_key: str | None = Field(
+ None,
+ min_length=1,
+ max_length=50,
+ description='Client key for the routing',
+ example='client-1',
+ )
+ server_index: int | None = Field(
+ None, exclude=True, ge=0, description='Index of the server to route to', example=0
+ )
class Config:
- arbitrary_types_allowed = True
\ No newline at end of file
+ arbitrary_types_allowed = True
diff --git a/backend-services/models/update_user_model.py b/backend-services/models/update_user_model.py
index 6ab0e2b..c292da8 100644
--- a/backend-services/models/update_user_model.py
+++ b/backend-services/models/update_user_model.py
@@ -5,29 +5,90 @@ See https://github.com/pypeople-dev/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import List, Optional
+
class UpdateUserModel(BaseModel):
+ username: str | None = Field(
+ None, min_length=3, max_length=50, description='Username of the user', example='john_doe'
+ )
+ email: str | None = Field(
+ None,
+ min_length=3,
+ max_length=127,
+ description='Email of the user (no strict format validation)',
+ example='john@mail.com',
+ )
+ password: str | None = Field(
+ None,
+ min_length=6,
+ max_length=50,
+ description='Password of the user',
+ example='SecurePassword@123',
+ )
+ role: str | None = Field(
+ None, min_length=2, max_length=50, description='Role of the user', example='admin'
+ )
+ groups: list[str] | None = Field(
+ None, description='List of groups the user belongs to', example=['client-1-group']
+ )
+ rate_limit_duration: int | None = Field(
+ None, ge=0, description='Rate limit for the user', example=100
+ )
+ rate_limit_duration_type: str | None = Field(
+ None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour'
+ )
+ rate_limit_enabled: bool | None = Field(
+ None, description='Whether rate limiting is enabled for this user', example=True
+ )
+ throttle_duration: int | None = Field(
+ None, ge=0, description='Throttle limit for the user', example=10
+ )
+ throttle_duration_type: str | None = Field(
+ None,
+ min_length=1,
+ max_length=7,
+ description='Duration for the throttle limit',
+ example='second',
+ )
+ throttle_wait_duration: int | None = Field(
+ None, ge=0, description='Wait time for the throttle limit', example=5
+ )
+ throttle_wait_duration_type: str | None = Field(
+ None,
+ min_length=1,
+ max_length=7,
+ description='Wait duration for the throttle limit',
+ example='seconds',
+ )
+ throttle_queue_limit: int | None = Field(
+ None, ge=0, description='Throttle queue limit for the user', example=10
+ )
+ throttle_enabled: bool | None = Field(
+ None, description='Whether throttling is enabled for this user', example=True
+ )
+ custom_attributes: dict | None = Field(
+ None, description='Custom attributes for the user', example={'custom_key': 'custom_value'}
+ )
+ bandwidth_limit_bytes: int | None = Field(
+ None,
+ ge=0,
+ description='Maximum bandwidth allowed within the window (bytes)',
+ example=1073741824,
+ )
+ bandwidth_limit_window: str | None = Field(
+ None,
+ min_length=1,
+ max_length=10,
+ description='Bandwidth window unit (second/minute/hour/day/month)',
+ example='day',
+ )
+ bandwidth_limit_enabled: bool | None = Field(
+ None,
+ description='Whether bandwidth limit enforcement is enabled for this user',
+ example=True,
+ )
+ active: bool | None = Field(None, description='Active status of the user', example=True)
+ ui_access: bool | None = Field(None, description='UI access for the user', example=False)
- username: Optional[str] = Field(None, min_length=3, max_length=50, description='Username of the user', example='john_doe')
- email: Optional[str] = Field(None, min_length=3, max_length=127, description='Email of the user (no strict format validation)', example='john@mail.com')
- password: Optional[str] = Field(None, min_length=6, max_length=50, description='Password of the user', example='SecurePassword@123')
- role: Optional[str] = Field(None, min_length=2, max_length=50, description='Role of the user', example='admin')
- groups: Optional[List[str]] = Field(None, description='List of groups the user belongs to', example=['client-1-group'])
- rate_limit_duration: Optional[int] = Field(None, ge=0, description='Rate limit for the user', example=100)
- rate_limit_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour')
- rate_limit_enabled: Optional[bool] = Field(None, description='Whether rate limiting is enabled for this user', example=True)
- throttle_duration: Optional[int] = Field(None, ge=0, description='Throttle limit for the user', example=10)
- throttle_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the throttle limit', example='second')
- throttle_wait_duration: Optional[int] = Field(None, ge=0, description='Wait time for the throttle limit', example=5)
- throttle_wait_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Wait duration for the throttle limit', example='seconds')
- throttle_queue_limit: Optional[int] = Field(None, ge=0, description='Throttle queue limit for the user', example=10)
- throttle_enabled: Optional[bool] = Field(None, description='Whether throttling is enabled for this user', example=True)
- custom_attributes: Optional[dict] = Field(None, description='Custom attributes for the user', example={'custom_key': 'custom_value'})
- bandwidth_limit_bytes: Optional[int] = Field(None, ge=0, description='Maximum bandwidth allowed within the window (bytes)', example=1073741824)
- bandwidth_limit_window: Optional[str] = Field(None, min_length=1, max_length=10, description='Bandwidth window unit (second/minute/hour/day/month)', example='day')
- bandwidth_limit_enabled: Optional[bool] = Field(None, description='Whether bandwidth limit enforcement is enabled for this user', example=True)
- active: Optional[bool] = Field(None, description='Active status of the user', example=True)
- ui_access: Optional[bool] = Field(None, description='UI access for the user', example=False)
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/update_vault_entry_model.py b/backend-services/models/update_vault_entry_model.py
index 355097a..698ad97 100644
--- a/backend-services/models/update_vault_entry_model.py
+++ b/backend-services/models/update_vault_entry_model.py
@@ -5,17 +5,16 @@ See https://github.com/apidoorman/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
class UpdateVaultEntryModel(BaseModel):
"""Model for updating a vault entry. Only description can be updated, not the value."""
-
- description: Optional[str] = Field(
+
+ description: str | None = Field(
None,
max_length=500,
description='Updated description of what this key is used for',
- example='Production API key for payment gateway - updated'
+ example='Production API key for payment gateway - updated',
)
class Config:
diff --git a/backend-services/models/user_credits_model.py b/backend-services/models/user_credits_model.py
index 2b18353..7110a24 100644
--- a/backend-services/models/user_credits_model.py
+++ b/backend-services/models/user_credits_model.py
@@ -4,23 +4,39 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import Optional, Dict
from pydantic import BaseModel, Field
+
class UserCreditInformationModel(BaseModel):
- tier_name: str = Field(..., min_length=1, max_length=50, description='Name of the credit tier', example='basic')
+ tier_name: str = Field(
+ ..., min_length=1, max_length=50, description='Name of the credit tier', example='basic'
+ )
available_credits: int = Field(..., description='Number of available credits', example=50)
- reset_date: Optional[str] = Field(None, description='Date when paid credits are reset', example='2023-10-01')
- user_api_key: Optional[str] = Field(None, description='User specific API key for the credit tier', example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
+ reset_date: str | None = Field(
+ None, description='Date when paid credits are reset', example='2023-10-01'
+ )
+ user_api_key: str | None = Field(
+ None,
+ description='User specific API key for the credit tier',
+ example='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',
+ )
class Config:
arbitrary_types_allowed = True
+
class UserCreditModel(BaseModel):
- username: str = Field(..., min_length=3, max_length=50, description='Username of credits owner', example='client-1')
- users_credits: Dict[str, UserCreditInformationModel] = Field(..., description='Credits information. Key is the credit group name')
+ username: str = Field(
+ ...,
+ min_length=3,
+ max_length=50,
+ description='Username of credits owner',
+ example='client-1',
+ )
+ users_credits: dict[str, UserCreditInformationModel] = Field(
+ ..., description='Credits information. Key is the credit group name'
+ )
class Config:
arbitrary_types_allowed = True
-
diff --git a/backend-services/models/user_model_response.py b/backend-services/models/user_model_response.py
index 2018303..b438c3b 100644
--- a/backend-services/models/user_model_response.py
+++ b/backend-services/models/user_model_response.py
@@ -4,33 +4,96 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from pydantic import BaseModel, Field, EmailStr
-from typing import List, Optional
+from pydantic import BaseModel, EmailStr, Field
+
class UserModelResponse(BaseModel):
-
- username: Optional[str] = Field(None, min_length=3, max_length=50, description='Username of the user', example='john_doe')
- email: Optional[EmailStr] = Field(None, description='Email of the user', example='john@mail.com')
- password: Optional[str] = Field(None, min_length=6, max_length=50, description='Password of the user', example='SecurePassword@123')
- role: Optional[str] = Field(None, min_length=2, max_length=50, description='Role of the user', example='admin')
- groups: Optional[List[str]] = Field(None, description='List of groups the user belongs to', example=['client-1-group'])
- rate_limit_duration: Optional[int] = Field(None, ge=0, description='Rate limit for the user', example=100)
- rate_limit_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour')
- rate_limit_enabled: Optional[bool] = Field(None, description='Whether rate limiting is enabled for this user', example=True)
- throttle_duration: Optional[int] = Field(None, ge=0, description='Throttle limit for the user', example=10)
- throttle_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Duration for the throttle limit', example='second')
- throttle_wait_duration: Optional[int] = Field(None, ge=0, description='Wait time for the throttle limit', example=5)
- throttle_wait_duration_type: Optional[str] = Field(None, min_length=1, max_length=7, description='Wait duration for the throttle limit', example='seconds')
- throttle_queue_limit: Optional[int] = Field(None, ge=0, description='Throttle queue limit for the user', example=10)
- throttle_enabled: Optional[bool] = Field(None, description='Whether throttling is enabled for this user', example=True)
- custom_attributes: Optional[dict] = Field(None, description='Custom attributes for the user', example={'custom_key': 'custom_value'})
- bandwidth_limit_bytes: Optional[int] = Field(None, ge=0, description='Maximum bandwidth allowed within the window (bytes)', example=1073741824)
- bandwidth_limit_window: Optional[str] = Field(None, min_length=1, max_length=10, description='Bandwidth window unit (second/minute/hour/day/month)', example='day')
- bandwidth_usage_bytes: Optional[int] = Field(None, ge=0, description='Current bandwidth usage in the active window (bytes)', example=123456)
- bandwidth_resets_at: Optional[int] = Field(None, description='UTC epoch seconds when the current bandwidth window resets', example=1727481600)
- bandwidth_limit_enabled: Optional[bool] = Field(None, description='Whether bandwidth limit enforcement is enabled for this user', example=True)
- active: Optional[bool] = Field(None, description='Active status of the user', example=True)
- ui_access: Optional[bool] = Field(None, description='UI access for the user', example=False)
+ username: str | None = Field(
+ None, min_length=3, max_length=50, description='Username of the user', example='john_doe'
+ )
+ email: EmailStr | None = Field(None, description='Email of the user', example='john@mail.com')
+ password: str | None = Field(
+ None,
+ min_length=6,
+ max_length=50,
+ description='Password of the user',
+ example='SecurePassword@123',
+ )
+ role: str | None = Field(
+ None, min_length=2, max_length=50, description='Role of the user', example='admin'
+ )
+ groups: list[str] | None = Field(
+ None, description='List of groups the user belongs to', example=['client-1-group']
+ )
+ rate_limit_duration: int | None = Field(
+ None, ge=0, description='Rate limit for the user', example=100
+ )
+ rate_limit_duration_type: str | None = Field(
+ None, min_length=1, max_length=7, description='Duration for the rate limit', example='hour'
+ )
+ rate_limit_enabled: bool | None = Field(
+ None, description='Whether rate limiting is enabled for this user', example=True
+ )
+ throttle_duration: int | None = Field(
+ None, ge=0, description='Throttle limit for the user', example=10
+ )
+ throttle_duration_type: str | None = Field(
+ None,
+ min_length=1,
+ max_length=7,
+ description='Duration for the throttle limit',
+ example='second',
+ )
+ throttle_wait_duration: int | None = Field(
+ None, ge=0, description='Wait time for the throttle limit', example=5
+ )
+ throttle_wait_duration_type: str | None = Field(
+ None,
+ min_length=1,
+ max_length=7,
+ description='Wait duration for the throttle limit',
+ example='seconds',
+ )
+ throttle_queue_limit: int | None = Field(
+ None, ge=0, description='Throttle queue limit for the user', example=10
+ )
+ throttle_enabled: bool | None = Field(
+ None, description='Whether throttling is enabled for this user', example=True
+ )
+ custom_attributes: dict | None = Field(
+ None, description='Custom attributes for the user', example={'custom_key': 'custom_value'}
+ )
+ bandwidth_limit_bytes: int | None = Field(
+ None,
+ ge=0,
+ description='Maximum bandwidth allowed within the window (bytes)',
+ example=1073741824,
+ )
+ bandwidth_limit_window: str | None = Field(
+ None,
+ min_length=1,
+ max_length=10,
+ description='Bandwidth window unit (second/minute/hour/day/month)',
+ example='day',
+ )
+ bandwidth_usage_bytes: int | None = Field(
+ None,
+ ge=0,
+ description='Current bandwidth usage in the active window (bytes)',
+ example=123456,
+ )
+ bandwidth_resets_at: int | None = Field(
+ None,
+ description='UTC epoch seconds when the current bandwidth window resets',
+ example=1727481600,
+ )
+ bandwidth_limit_enabled: bool | None = Field(
+ None,
+ description='Whether bandwidth limit enforcement is enabled for this user',
+ example=True,
+ )
+ active: bool | None = Field(None, description='Active status of the user', example=True)
+ ui_access: bool | None = Field(None, description='UI access for the user', example=False)
class Config:
arbitrary_types_allowed = True
diff --git a/backend-services/models/validation_schema_model.py b/backend-services/models/validation_schema_model.py
index 9ee60e7..a75c3b8 100644
--- a/backend-services/models/validation_schema_model.py
+++ b/backend-services/models/validation_schema_model.py
@@ -4,11 +4,11 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from typing import Dict
from pydantic import BaseModel, Field
from models.field_validation_model import FieldValidation
+
class ValidationSchema(BaseModel):
"""Validation schema for endpoint request/response validation.
@@ -57,21 +57,12 @@ class ValidationSchema(BaseModel):
}
}
"""
- validation_schema: Dict[str, FieldValidation] = Field(
+
+ validation_schema: dict[str, FieldValidation] = Field(
...,
description='The schema to validate the endpoint against',
example={
- 'user.name': {
- 'required': True,
- 'type': 'string',
- 'min': 2,
- 'max': 50
- },
- 'user.age': {
- 'required': True,
- 'type': 'number',
- 'min': 0,
- 'max': 120
- }
- }
- )
\ No newline at end of file
+ 'user.name': {'required': True, 'type': 'string', 'min': 2, 'max': 50},
+ 'user.age': {'required': True, 'type': 'number', 'min': 0, 'max': 120},
+ },
+ )
diff --git a/backend-services/models/vault_entry_model_response.py b/backend-services/models/vault_entry_model_response.py
index de6a96f..ae19313 100644
--- a/backend-services/models/vault_entry_model_response.py
+++ b/backend-services/models/vault_entry_model_response.py
@@ -5,40 +5,29 @@ See https://github.com/apidoorman/doorman for more information
"""
from pydantic import BaseModel, Field
-from typing import Optional
class VaultEntryModelResponse(BaseModel):
"""Response model for vault entry. Value is never returned."""
-
- key_name: str = Field(
- ...,
- description='Name of the vault key',
- example='api_key_production'
- )
-
- username: str = Field(
- ...,
- description='Username of the vault entry owner',
- example='john_doe'
- )
-
- description: Optional[str] = Field(
+
+ key_name: str = Field(..., description='Name of the vault key', example='api_key_production')
+
+ username: str = Field(..., description='Username of the vault entry owner', example='john_doe')
+
+ description: str | None = Field(
None,
description='Description of what this key is used for',
- example='Production API key for payment gateway'
+ example='Production API key for payment gateway',
)
-
- created_at: Optional[str] = Field(
- None,
- description='Timestamp when the entry was created',
- example='2024-11-22T10:15:30Z'
+
+ created_at: str | None = Field(
+ None, description='Timestamp when the entry was created', example='2024-11-22T10:15:30Z'
)
-
- updated_at: Optional[str] = Field(
+
+ updated_at: str | None = Field(
None,
description='Timestamp when the entry was last updated',
- example='2024-11-22T10:15:30Z'
+ example='2024-11-22T10:15:30Z',
)
class Config:
diff --git a/backend-services/routes/analytics_routes.py b/backend-services/routes/analytics_routes.py
index fcd91c3..d6bad7b 100644
--- a/backend-services/routes/analytics_routes.py
+++ b/backend-services/routes/analytics_routes.py
@@ -9,20 +9,17 @@ Provides comprehensive analytics endpoints for:
- Endpoint performance analysis
"""
-from fastapi import APIRouter, Request, Query, HTTPException
-from pydantic import BaseModel
-from typing import Optional, List
-import uuid
-import time
import logging
-from datetime import datetime, timedelta
+import time
+import uuid
+
+from fastapi import APIRouter, Query, Request
from models.response_model import ResponseModel
-from utils.response_util import respond_rest, process_response
from utils.auth_util import auth_required
-from utils.role_util import platform_role_required_bool
from utils.enhanced_metrics_util import enhanced_metrics_store
-from utils.analytics_aggregator import analytics_aggregator
+from utils.response_util import respond_rest
+from utils.role_util import platform_role_required_bool
analytics_router = APIRouter()
logger = logging.getLogger('doorman.analytics')
@@ -32,7 +29,9 @@ logger = logging.getLogger('doorman.analytics')
# ENDPOINT 1: Dashboard Overview
# ============================================================================
-@analytics_router.get('/analytics/overview',
+
+@analytics_router.get(
+ '/analytics/overview',
description='Get dashboard overview statistics',
response_model=ResponseModel,
responses={
@@ -50,33 +49,33 @@ logger = logging.getLogger('doorman.analytics')
'p75': 180.0,
'p90': 250.0,
'p95': 300.0,
- 'p99': 450.0
+ 'p99': 450.0,
},
'unique_users': 150,
'total_bandwidth': 1073741824,
'top_apis': [
{'api': 'rest:customer', 'count': 5000},
- {'api': 'rest:orders', 'count': 3000}
+ {'api': 'rest:orders', 'count': 3000},
],
'top_users': [
{'user': 'john_doe', 'count': 500},
- {'user': 'jane_smith', 'count': 300}
- ]
+ {'user': 'jane_smith', 'count': 300},
+ ],
}
}
- }
+ },
}
- }
+ },
)
async def get_analytics_overview(
request: Request,
- start_ts: Optional[int] = Query(None, description='Start timestamp (Unix seconds)'),
- end_ts: Optional[int] = Query(None, description='End timestamp (Unix seconds)'),
- range: Optional[str] = Query('24h', description='Time range (1h, 24h, 7d, 30d)')
+ start_ts: int | None = Query(None, description='Start timestamp (Unix seconds)'),
+ end_ts: int | None = Query(None, description='End timestamp (Unix seconds)'),
+ range: str | None = Query('24h', description='Time range (1h, 24h, 7d, 30d)'),
):
"""
Get dashboard overview statistics.
-
+
Returns summary metrics including:
- Total requests and errors
- Error rate
@@ -89,23 +88,27 @@ async def get_analytics_overview(
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
# Authentication and authorization
payload = await auth_required(request)
username = payload.get('sub')
-
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
if not await platform_role_required_bool(username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
# Use provided timestamps
@@ -113,24 +116,19 @@ async def get_analytics_overview(
else:
# Use range parameter
end_ts = int(time.time())
- range_map = {
- '1h': 3600,
- '24h': 86400,
- '7d': 604800,
- '30d': 2592000
- }
+ range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get analytics snapshot
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts)
-
+
# Build response
overview = {
'time_range': {
'start_ts': start_ts,
'end_ts': end_ts,
- 'duration_seconds': end_ts - start_ts
+ 'duration_seconds': end_ts - start_ts,
},
'summary': {
'total_requests': snapshot.total_requests,
@@ -140,29 +138,31 @@ async def get_analytics_overview(
'unique_users': snapshot.unique_users,
'total_bandwidth': snapshot.total_bytes_in + snapshot.total_bytes_out,
'bandwidth_in': snapshot.total_bytes_in,
- 'bandwidth_out': snapshot.total_bytes_out
+ 'bandwidth_out': snapshot.total_bytes_out,
},
'percentiles': snapshot.percentiles.to_dict(),
'top_apis': snapshot.top_apis,
'top_users': snapshot.top_users,
- 'status_distribution': snapshot.status_distribution
+ 'status_distribution': snapshot.status_distribution,
}
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=overview
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=overview
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
@@ -172,21 +172,28 @@ async def get_analytics_overview(
# ENDPOINT 2: Time-Series Data
# ============================================================================
-@analytics_router.get('/analytics/timeseries',
+
+@analytics_router.get(
+ '/analytics/timeseries',
description='Get time-series analytics data with filtering',
- response_model=ResponseModel
+ response_model=ResponseModel,
)
async def get_analytics_timeseries(
request: Request,
- start_ts: Optional[int] = Query(None, description='Start timestamp (Unix seconds)'),
- end_ts: Optional[int] = Query(None, description='End timestamp (Unix seconds)'),
- range: Optional[str] = Query('24h', description='Time range (1h, 24h, 7d, 30d)'),
- granularity: Optional[str] = Query('auto', description='Data granularity (auto, minute, 5minute, hour, day)'),
- metric_type: Optional[str] = Query(None, description='Specific metric to return (request_count, error_rate, latency, bandwidth)')
+ start_ts: int | None = Query(None, description='Start timestamp (Unix seconds)'),
+ end_ts: int | None = Query(None, description='End timestamp (Unix seconds)'),
+ range: str | None = Query('24h', description='Time range (1h, 24h, 7d, 30d)'),
+ granularity: str | None = Query(
+ 'auto', description='Data granularity (auto, minute, 5minute, hour, day)'
+ ),
+ metric_type: str | None = Query(
+ None,
+ description='Specific metric to return (request_count, error_rate, latency, bandwidth)',
+ ),
):
"""
Get time-series analytics data.
-
+
Returns series of data points over time with:
- Timestamp
- Request count
@@ -198,19 +205,21 @@ async def get_analytics_timeseries(
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
+
if not await platform_role_required_bool(username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
pass
@@ -219,10 +228,10 @@ async def get_analytics_timeseries(
range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get snapshot with time-series data
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts, granularity)
-
+
# Filter by metric type if specified
series = snapshot.series
if metric_type:
@@ -243,27 +252,31 @@ async def get_analytics_timeseries(
filtered_point['bytes_out'] = point['bytes_out']
filtered_series.append(filtered_point)
series = filtered_series
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
- 'granularity': granularity,
- 'series': series,
- 'data_points': len(series)
- }
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={
+ 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
+ 'granularity': granularity,
+ 'series': series,
+ 'data_points': len(series),
+ },
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
@@ -273,20 +286,20 @@ async def get_analytics_timeseries(
# ENDPOINT 3: Top APIs
# ============================================================================
-@analytics_router.get('/analytics/top-apis',
- description='Get most used APIs',
- response_model=ResponseModel
+
+@analytics_router.get(
+ '/analytics/top-apis', description='Get most used APIs', response_model=ResponseModel
)
async def get_top_apis(
request: Request,
- start_ts: Optional[int] = Query(None),
- end_ts: Optional[int] = Query(None),
- range: Optional[str] = Query('24h'),
- limit: int = Query(10, ge=1, le=100, description='Number of APIs to return')
+ start_ts: int | None = Query(None),
+ end_ts: int | None = Query(None),
+ range: str | None = Query('24h'),
+ limit: int = Query(10, ge=1, le=100, description='Number of APIs to return'),
):
"""
Get top N most used APIs.
-
+
Returns list of APIs sorted by request count with:
- API name
- Total requests
@@ -296,19 +309,21 @@ async def get_top_apis(
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
+
if not await platform_role_required_bool(username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
pass
@@ -317,32 +332,36 @@ async def get_top_apis(
range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get snapshot
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts)
-
+
# Get top APIs (already sorted by count)
top_apis = snapshot.top_apis[:limit]
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
- 'top_apis': top_apis,
- 'total_apis': len(snapshot.top_apis)
- }
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={
+ 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
+ 'top_apis': top_apis,
+ 'total_apis': len(snapshot.top_apis),
+ },
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
@@ -352,37 +371,39 @@ async def get_top_apis(
# ENDPOINT 4: Top Users
# ============================================================================
-@analytics_router.get('/analytics/top-users',
- description='Get highest consuming users',
- response_model=ResponseModel
+
+@analytics_router.get(
+ '/analytics/top-users', description='Get highest consuming users', response_model=ResponseModel
)
async def get_top_users(
request: Request,
- start_ts: Optional[int] = Query(None),
- end_ts: Optional[int] = Query(None),
- range: Optional[str] = Query('24h'),
- limit: int = Query(10, ge=1, le=100, description='Number of users to return')
+ start_ts: int | None = Query(None),
+ end_ts: int | None = Query(None),
+ range: str | None = Query('24h'),
+ limit: int = Query(10, ge=1, le=100, description='Number of users to return'),
):
"""
Get top N highest consuming users.
-
+
Returns list of users sorted by request count.
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
+
if not await platform_role_required_bool(username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
pass
@@ -391,32 +412,36 @@ async def get_top_users(
range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get snapshot
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts)
-
+
# Get top users (already sorted by count)
top_users = snapshot.top_users[:limit]
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
- 'top_users': top_users,
- 'total_users': len(snapshot.top_users)
- }
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={
+ 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
+ 'top_users': top_users,
+ 'total_users': len(snapshot.top_users),
+ },
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
@@ -426,21 +451,23 @@ async def get_top_users(
# ENDPOINT 5: Top Endpoints
# ============================================================================
-@analytics_router.get('/analytics/top-endpoints',
+
+@analytics_router.get(
+ '/analytics/top-endpoints',
description='Get slowest/most-used endpoints',
- response_model=ResponseModel
+ response_model=ResponseModel,
)
async def get_top_endpoints(
request: Request,
- start_ts: Optional[int] = Query(None),
- end_ts: Optional[int] = Query(None),
- range: Optional[str] = Query('24h'),
+ start_ts: int | None = Query(None),
+ end_ts: int | None = Query(None),
+ range: str | None = Query('24h'),
sort_by: str = Query('count', description='Sort by: count, avg_ms, error_rate'),
- limit: int = Query(10, ge=1, le=100)
+ limit: int = Query(10, ge=1, le=100),
):
"""
Get top endpoints sorted by usage or performance.
-
+
Returns detailed per-endpoint metrics including:
- Request count
- Error count and rate
@@ -449,19 +476,21 @@ async def get_top_endpoints(
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
+
if not await platform_role_required_bool(username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
pass
@@ -470,41 +499,45 @@ async def get_top_endpoints(
range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get snapshot
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts)
-
+
# Get and sort endpoints
endpoints = snapshot.top_endpoints
-
+
if sort_by == 'avg_ms':
endpoints.sort(key=lambda x: x['avg_ms'], reverse=True)
elif sort_by == 'error_rate':
endpoints.sort(key=lambda x: x['error_rate'], reverse=True)
# Default is already sorted by count
-
+
top_endpoints = endpoints[:limit]
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
- 'sort_by': sort_by,
- 'top_endpoints': top_endpoints,
- 'total_endpoints': len(endpoints)
- }
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={
+ 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
+ 'sort_by': sort_by,
+ 'top_endpoints': top_endpoints,
+ 'total_endpoints': len(endpoints),
+ },
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
@@ -514,21 +547,23 @@ async def get_top_endpoints(
# ENDPOINT 6: Per-API Breakdown
# ============================================================================
-@analytics_router.get('/analytics/api/{api_name}/{version}',
+
+@analytics_router.get(
+ '/analytics/api/{api_name}/{version}',
description='Get detailed analytics for a specific API',
- response_model=ResponseModel
+ response_model=ResponseModel,
)
async def get_api_analytics(
request: Request,
api_name: str,
version: str,
- start_ts: Optional[int] = Query(None),
- end_ts: Optional[int] = Query(None),
- range: Optional[str] = Query('24h')
+ start_ts: int | None = Query(None),
+ end_ts: int | None = Query(None),
+ range: str | None = Query('24h'),
):
"""
Get detailed analytics for a specific API.
-
+
Returns:
- Total requests for this API
- Error count and rate
@@ -538,19 +573,21 @@ async def get_api_analytics(
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
+
if not await platform_role_required_bool(username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
pass
@@ -559,55 +596,62 @@ async def get_api_analytics(
range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get full snapshot
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts)
-
+
# Filter for this API
- api_key = f"rest:{api_name}" # Assuming REST API
-
+ api_key = f'rest:{api_name}' # Assuming REST API
+
# Find API in top_apis
api_data = None
for api, count in snapshot.top_apis:
if api == api_key:
api_data = {'api': api, 'count': count}
break
-
+
if not api_data:
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS404',
- error_message=f'No data found for API: {api_name}/{version}'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS404',
+ error_message=f'No data found for API: {api_name}/{version}',
+ )
+ )
+
# Filter endpoints for this API
api_endpoints = [
- ep for ep in snapshot.top_endpoints
+ ep
+ for ep in snapshot.top_endpoints
if ep['endpoint_uri'].startswith(f'/{api_name}/{version}')
]
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'api_name': api_name,
- 'version': version,
- 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
- 'summary': api_data,
- 'endpoints': api_endpoints
- }
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={
+ 'api_name': api_name,
+ 'version': version,
+ 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
+ 'summary': api_data,
+ 'endpoints': api_endpoints,
+ },
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
@@ -617,20 +661,22 @@ async def get_api_analytics(
# ENDPOINT 7: Per-User Breakdown
# ============================================================================
-@analytics_router.get('/analytics/user/{username}',
+
+@analytics_router.get(
+ '/analytics/user/{username}',
description='Get detailed analytics for a specific user',
- response_model=ResponseModel
+ response_model=ResponseModel,
)
async def get_user_analytics(
request: Request,
username: str,
- start_ts: Optional[int] = Query(None),
- end_ts: Optional[int] = Query(None),
- range: Optional[str] = Query('24h')
+ start_ts: int | None = Query(None),
+ end_ts: int | None = Query(None),
+ range: str | None = Query('24h'),
):
"""
Get detailed analytics for a specific user.
-
+
Returns:
- Total requests by this user
- APIs accessed
@@ -638,19 +684,21 @@ async def get_user_analytics(
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
requesting_username = payload.get('sub')
-
+
if not await platform_role_required_bool(requesting_username, 'view_analytics'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS001',
- error_message='You do not have permission to view analytics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS001',
+ error_message='You do not have permission to view analytics',
+ )
+ )
+
# Determine time range
if start_ts and end_ts:
pass
@@ -659,44 +707,50 @@ async def get_user_analytics(
range_map = {'1h': 3600, '24h': 86400, '7d': 604800, '30d': 2592000}
seconds = range_map.get(range, 86400)
start_ts = end_ts - seconds
-
+
# Get full snapshot
snapshot = enhanced_metrics_store.get_snapshot(start_ts, end_ts)
-
+
# Find user in top_users
user_data = None
for user, count in snapshot.top_users:
if user == username:
user_data = {'user': user, 'count': count}
break
-
+
if not user_data:
- return respond_rest(ResponseModel(
- status_code=404,
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS404',
+ error_message=f'No data found for user: {username}',
+ )
+ )
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
response_headers={'request_id': request_id},
- error_code='ANALYTICS404',
- error_message=f'No data found for user: {username}'
- ))
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'username': username,
- 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
- 'summary': user_data
- }
- ))
-
+ response={
+ 'username': username,
+ 'time_range': {'start_ts': start_ts, 'end_ts': end_ts},
+ 'summary': user_data,
+ },
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='ANALYTICS999',
- error_message='An unexpected error occurred'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='ANALYTICS999',
+ error_message='An unexpected error occurred',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
diff --git a/backend-services/routes/api_routes.py b/backend-services/routes/api_routes.py
index c3e4a53..4114336 100644
--- a/backend-services/routes/api_routes.py
+++ b/backend-services/routes/api_routes.py
@@ -4,22 +4,22 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from fastapi import APIRouter, Depends, Request, HTTPException
-from typing import List
import logging
-import uuid
import time
+import uuid
+
+from fastapi import APIRouter, HTTPException, Request, Response
-from models.response_model import ResponseModel
-from services.api_service import ApiService
-from utils.auth_util import auth_required
-from models.create_api_model import CreateApiModel
-from models.update_api_model import UpdateApiModel
from models.api_model_response import ApiModelResponse
-from utils.response_util import respond_rest, process_response
-from utils.constants import ErrorCodes, Messages, Defaults, Roles, Headers
-from utils.role_util import platform_role_required_bool
+from models.create_api_model import CreateApiModel
+from models.response_model import ResponseModel
+from models.update_api_model import UpdateApiModel
+from services.api_service import ApiService
from utils.audit_util import audit
+from utils.auth_util import auth_required
+from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles
+from utils.response_util import process_response, respond_rest
+from utils.role_util import platform_role_required_bool
api_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
@@ -33,58 +33,65 @@ Response:
{}
"""
-@api_router.post('',
+
+@api_router.post(
+ '',
description='Add API',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'API created successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'API created successfully'}}},
}
- }
+ },
)
-
-async def create_api(request: Request, api_data: CreateApiModel):
+async def create_api(request: Request, api_data: CreateApiModel) -> Response:
payload = await auth_required(request)
username = payload.get('sub')
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
try:
if not await platform_role_required_bool(username, Roles.MANAGE_APIS):
logger.warning(f'{request_id} | Permission denied for user: {username}')
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='API007',
- error_message='You do not have permission to create APIs'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API007',
+ error_message='You do not have permission to create APIs',
+ )
+ )
result = await ApiService.create_api(api_data, request_id)
- audit(request, actor=username, action='api.create', target=f'{api_data.api_name}/{api_data.api_version}', status=result.get('status_code'), details={'message': result.get('message')}, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='api.create',
+ target=f'{api_data.api_name}/{api_data.api_version}',
+ status=result.get('status_code'),
+ details={'message': result.get('message')},
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update API
@@ -94,57 +101,66 @@ Response:
{}
"""
-@api_router.put('/{api_name}/{api_version}',
+
+@api_router.put(
+ '/{api_name}/{api_version}',
description='Update API',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'API updated successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'API updated successfully'}}},
}
- }
+ },
)
-
-async def update_api(api_name: str, api_version: str, request: Request, api_data: UpdateApiModel):
+async def update_api(
+ api_name: str, api_version: str, request: Request, api_data: UpdateApiModel
+) -> Response:
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_APIS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='API008',
- error_message='You do not have permission to update APIs'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API008',
+ error_message='You do not have permission to update APIs',
+ )
+ )
result = await ApiService.update_api(api_name, api_version, api_data, request_id)
- audit(request, actor=username, action='api.update', target=f'{api_name}/{api_version}', status=result.get('status_code'), details={'message': result.get('message')}, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='api.update',
+ target=f'{api_name}/{api_version}',
+ status=result.get('status_code'),
+ details={'message': result.get('message')},
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Get API
@@ -154,46 +170,47 @@ Response:
{}
"""
-@api_router.get('/{api_name}/{api_version}',
+
+@api_router.get(
+ '/{api_name}/{api_version}',
description='Get API',
response_model=ApiModelResponse,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'API retrieved successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'API retrieved successfully'}}},
}
- }
+ },
)
-
-async def get_api_by_name_version(api_name: str, api_version: str, request: Request):
+async def get_api_by_name_version(api_name: str, api_version: str, request: Request) -> Response:
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- return respond_rest(await ApiService.get_api_by_name_version(api_name, api_version, request_id))
+ return respond_rest(
+ await ApiService.get_api_by_name_version(api_name, api_version, request_id)
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete API
@@ -203,48 +220,55 @@ Response:
{}
"""
-@api_router.delete('/{api_name}/{api_version}',
+
+@api_router.delete(
+ '/{api_name}/{api_version}',
description='Delete API',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'API deleted successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'API deleted successfully'}}},
}
- }
+ },
)
-
-async def delete_api(api_name: str, api_version: str, request: Request):
+async def delete_api(api_name: str, api_version: str, request: Request) -> Response:
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
result = await ApiService.delete_api(api_name, api_version, request_id)
- audit(request, actor=username, action='api.delete', target=f'{api_name}/{api_version}', status=result.get('status_code'), details={'message': result.get('message')}, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='api.delete',
+ target=f'{api_name}/{api_version}',
+ status=result.get('status_code'),
+ details={'message': result.get('message')},
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -254,46 +278,47 @@ Response:
{}
"""
-@api_router.get('/all',
- description='Get all APIs',
- response_model=List[ApiModelResponse]
-)
-async def get_all_apis(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+@api_router.get('/all', description='Get all APIs', response_model=list[ApiModelResponse])
+async def get_all_apis(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE) -> Response:
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await ApiService.get_apis(page, page_size, request_id))
except HTTPException as e:
# Surface 401/403 properly for tests that probe unauthorized access
- return respond_rest(ResponseModel(
- status_code=e.status_code,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='API_AUTH',
- error_message=e.detail
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API_AUTH',
+ error_message=e.detail,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
-@api_router.get('',
- description='Get all APIs (base path)',
- response_model=List[ApiModelResponse]
-)
-async def get_all_apis_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+
+
+@api_router.get('', description='Get all APIs (base path)', response_model=list[ApiModelResponse])
+async def get_all_apis_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE) -> Response:
"""Convenience alias for GET /platform/api/all to support tests and clients
that expect listing at the base collection path.
"""
@@ -302,21 +327,28 @@ async def get_all_apis_base(request: Request, page: int = Defaults.PAGE, page_si
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await ApiService.get_apis(page, page_size, request_id))
except HTTPException as e:
- return respond_rest(ResponseModel(
- status_code=e.status_code,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='API_AUTH',
- error_message=e.detail
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API_AUTH',
+ error_message=e.detail,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
diff --git a/backend-services/routes/authorization_routes.py b/backend-services/routes/authorization_routes.py
index 6faa9cf..63b8a59 100644
--- a/backend-services/routes/authorization_routes.py
+++ b/backend-services/routes/authorization_routes.py
@@ -4,23 +4,30 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from fastapi import APIRouter, Request, Depends, HTTPException, Response
-from jose import JWTError
-import uuid
-import time
import logging
import os
+import time
+import uuid
+
+from fastapi import APIRouter, HTTPException, Request, Response
+from jose import JWTError
-from models.response_model import ResponseModel
-from services.user_service import UserService
-from utils.response_util import respond_rest
-from utils.auth_util import auth_required, create_access_token
-from utils.auth_blacklist import TimedHeap, jwt_blacklist, revoke_all_for_user, unrevoke_all_for_user, is_user_revoked, add_revoked_jti
-from utils.role_util import platform_role_required_bool
-from utils.role_util import is_admin_user
-from models.update_user_model import UpdateUserModel
from models.create_user_model import CreateUserModel
+from models.response_model import ResponseModel
+from models.update_user_model import UpdateUserModel
+from services.user_service import UserService
+from utils.auth_blacklist import (
+ TimedHeap,
+ add_revoked_jti,
+ is_user_revoked,
+ jwt_blacklist,
+ revoke_all_for_user,
+ unrevoke_all_for_user,
+)
+from utils.auth_util import auth_required, create_access_token
from utils.limit_throttle_util import limit_by_ip
+from utils.response_util import respond_rest
+from utils.role_util import is_admin_user, platform_role_required_bool
authorization_router = APIRouter()
@@ -35,23 +42,18 @@ Response:
{}
"""
-@authorization_router.post('/authorization',
+
+@authorization_router.post(
+ '/authorization',
description='Create authorization token',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'access_token': '******************'
- }
- }
- }
+ 'content': {'application/json': {'example': {'access_token': '******************'}}},
}
- }
+ },
)
-
async def authorization(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -65,53 +67,55 @@ async def authorization(request: Request):
try:
data = await request.json()
except Exception:
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='AUTH004',
- error_message='Invalid JSON payload'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH004',
+ error_message='Invalid JSON payload',
+ )
+ )
email = data.get('email')
password = data.get('password')
if not email or not password:
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH001',
- error_message='Missing email or password'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH001',
+ error_message='Missing email or password',
+ )
+ )
user = await UserService.check_password_return_user(email, password)
if not user:
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH002',
- error_message='Invalid email or password'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH002',
+ error_message='Invalid email or password',
+ )
+ )
if not user['active']:
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH007',
- error_message='User is not active'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH007',
+ error_message='User is not active',
+ )
+ )
access_token = create_access_token({'sub': user['username'], 'role': user['role']}, False)
- logger.info(f"Login successful for user: {user['username']}")
+ logger.info(f'Login successful for user: {user["username"]}')
- response = respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id
- },
- response={'access_token': access_token}
- ))
+ response = respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={'access_token': access_token},
+ )
+ )
if rate_limit_info:
response.headers['X-RateLimit-Limit'] = str(rate_limit_info['limit'])
@@ -121,6 +125,7 @@ async def authorization(request: Request):
response.delete_cookie('access_token_cookie')
import uuid as _uuid
+
csrf_token = str(_uuid.uuid4())
_secure_env = os.getenv('COOKIE_SECURE')
@@ -131,7 +136,9 @@ async def authorization(request: Request):
# Security warning: cookies should be secure in production
if not _secure and os.getenv('ENV', '').lower() in ('production', 'prod'):
- logger.warning(f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment')
+ logger.warning(
+ f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment'
+ )
_domain = os.getenv('COOKIE_DOMAIN', None)
_samesite = (os.getenv('COOKIE_SAMESITE', 'Strict') or 'Strict').strip().lower()
@@ -152,7 +159,7 @@ async def authorization(request: Request):
samesite=_samesite,
path='/',
domain=safe_domain,
- max_age=1800
+ max_age=1800,
)
response.set_cookie(
@@ -162,7 +169,7 @@ async def authorization(request: Request):
secure=_secure,
samesite=_samesite,
path='/',
- max_age=1800
+ max_age=1800,
)
response.set_cookie(
@@ -173,7 +180,7 @@ async def authorization(request: Request):
samesite=_samesite,
path='/',
domain=safe_domain,
- max_age=1800
+ max_age=1800,
)
response.set_cookie(
@@ -183,44 +190,44 @@ async def authorization(request: Request):
secure=_secure,
samesite=_samesite,
path='/',
- max_age=1800
+ max_age=1800,
)
return response
except HTTPException as e:
if getattr(e, 'status_code', None) == 429:
headers = getattr(e, 'headers', {}) or {}
detail = e.detail if isinstance(e.detail, dict) else {}
- return respond_rest(ResponseModel(
- status_code=429,
- response_headers={
- 'request_id': request_id,
- **headers
- },
- error_code=str(detail.get('error_code') or 'IP_RATE_LIMIT'),
- error_message=str(detail.get('message') or 'Too many requests')
- ))
- return respond_rest(ResponseModel(
- status_code=401,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH003',
- error_message='Unable to validate credentials'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=429,
+ response_headers={'request_id': request_id, **headers},
+ error_code=str(detail.get('error_code') or 'IP_RATE_LIMIT'),
+ error_message=str(detail.get('message') or 'Too many requests'),
+ )
+ )
+ return respond_rest(
+ ResponseModel(
+ status_code=401,
+ response_headers={'request_id': request_id},
+ error_code='AUTH003',
+ error_message='Unable to validate credentials',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Register new user
@@ -230,23 +237,18 @@ Response:
{}
"""
-@authorization_router.post('/authorization/register',
+
+@authorization_router.post(
+ '/authorization/register',
description='Register new user',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'User created successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'User created successfully'}}},
}
- }
+ },
)
-
async def register(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -254,59 +256,68 @@ async def register(request: Request):
# Rate limit registration to prevent abuse
reg_limit = int(os.getenv('REGISTER_IP_RATE_LIMIT', '5'))
reg_window = int(os.getenv('REGISTER_IP_RATE_WINDOW', '3600'))
- rate_limit_info = await limit_by_ip(request, limit=reg_limit, window=reg_window)
+ await limit_by_ip(request, limit=reg_limit, window=reg_window)
+
+ logger.info(
+ f'{request_id} | Register request from: {request.client.host}:{request.client.port}'
+ )
- logger.info(f'{request_id} | Register request from: {request.client.host}:{request.client.port}')
-
try:
data = await request.json()
except Exception:
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='AUTH004',
- error_message='Invalid JSON payload'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH004',
+ error_message='Invalid JSON payload',
+ )
+ )
+
# Validate required fields
if not data.get('email') or not data.get('password'):
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='AUTH001',
- error_message='Missing email or password'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH001',
+ error_message='Missing email or password',
+ )
+ )
# Create user model
# Default to 'user' role and active=True
user_data = CreateUserModel(
- username=data.get('email').split('@')[0], # Simple username derivation
+ username=data.get('email').split('@')[0], # Simple username derivation
email=data.get('email'),
password=data.get('password'),
role='user',
- active=True
+ active=True,
)
# Check if user exists (UserService.create_user handles this but we want clean error)
# Actually UserService.create_user will return error if exists.
-
+
result = await UserService.create_user(user_data, request_id)
-
+
# If successful, we could auto-login, but for now just return success
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -316,53 +327,66 @@ Response:
{}
"""
-@authorization_router.post('/authorization/admin/revoke/{username}',
- description='Revoke all active tokens for a user (admin)',
- response_model=ResponseModel)
+@authorization_router.post(
+ '/authorization/admin/revoke/{username}',
+ description='Revoke all active tokens for a user (admin)',
+ response_model=ResponseModel,
+)
async def admin_revoke_user_tokens(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
admin_user = payload.get('sub')
- logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(admin_user, 'manage_auth'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='AUTH900',
- error_message='You do not have permission to manage auth'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='AUTH900',
+ error_message='You do not have permission to manage auth',
+ )
+ )
try:
if await is_admin_user(username) and not await is_admin_user(admin_user):
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_message='User not found'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_message='User not found',
+ )
+ )
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
revoke_all_for_user(username)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message=f'All tokens revoked for {username}'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message=f'All tokens revoked for {username}',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -372,53 +396,66 @@ Response:
{}
"""
-@authorization_router.post('/authorization/admin/unrevoke/{username}',
- description='Clear token revocation for a user (admin)',
- response_model=ResponseModel)
+@authorization_router.post(
+ '/authorization/admin/unrevoke/{username}',
+ description='Clear token revocation for a user (admin)',
+ response_model=ResponseModel,
+)
async def admin_unrevoke_user_tokens(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
admin_user = payload.get('sub')
- logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(admin_user, 'manage_auth'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='AUTH900',
- error_message='You do not have permission to manage auth'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='AUTH900',
+ error_message='You do not have permission to manage auth',
+ )
+ )
try:
if await is_admin_user(username) and not await is_admin_user(admin_user):
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_message='User not found'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_message='User not found',
+ )
+ )
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
unrevoke_all_for_user(username)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message=f'Token revocation cleared for {username}'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message=f'Token revocation cleared for {username}',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -428,56 +465,69 @@ Response:
{}
"""
-@authorization_router.post('/authorization/admin/disable/{username}',
- description='Disable a user (admin)',
- response_model=ResponseModel)
+@authorization_router.post(
+ '/authorization/admin/disable/{username}',
+ description='Disable a user (admin)',
+ response_model=ResponseModel,
+)
async def admin_disable_user(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
admin_user = payload.get('sub')
- logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(admin_user, 'manage_auth'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='AUTH900',
- error_message='You do not have permission to manage auth'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='AUTH900',
+ error_message='You do not have permission to manage auth',
+ )
+ )
try:
if await is_admin_user(username) and not await is_admin_user(admin_user):
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_message='User not found'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_message='User not found',
+ )
+ )
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
await UserService.update_user(username, UpdateUserModel(active=False), request_id)
revoke_all_for_user(username)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message=f'User {username} disabled and tokens revoked'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message=f'User {username} disabled and tokens revoked',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -487,54 +537,67 @@ Response:
{}
"""
-@authorization_router.post('/authorization/admin/enable/{username}',
- description='Enable a user (admin)',
- response_model=ResponseModel)
+@authorization_router.post(
+ '/authorization/admin/enable/{username}',
+ description='Enable a user (admin)',
+ response_model=ResponseModel,
+)
async def admin_enable_user(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
admin_user = payload.get('sub')
- logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(admin_user, 'manage_auth'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='AUTH900',
- error_message='You do not have permission to manage auth'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='AUTH900',
+ error_message='You do not have permission to manage auth',
+ )
+ )
try:
if await is_admin_user(username) and not await is_admin_user(admin_user):
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_message='User not found'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_message='User not found',
+ )
+ )
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
await UserService.update_user(username, UpdateUserModel(active=True), request_id)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message=f'User {username} enabled'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message=f'User {username} enabled',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -544,57 +607,65 @@ Response:
{}
"""
-@authorization_router.get('/authorization/admin/status/{username}',
- description='Get auth status for a user (admin)',
- response_model=ResponseModel)
+@authorization_router.get(
+ '/authorization/admin/status/{username}',
+ description='Get auth status for a user (admin)',
+ response_model=ResponseModel,
+)
async def admin_user_status(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
admin_user = payload.get('sub')
- logger.info(f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {admin_user} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(admin_user, 'manage_auth'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='AUTH900',
- error_message='You do not have permission to manage auth'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='AUTH900',
+ error_message='You do not have permission to manage auth',
+ )
+ )
try:
if await is_admin_user(username) and not await is_admin_user(admin_user):
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_message='User not found'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_message='User not found',
+ )
+ )
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
user = await UserService.get_user_by_username_helper(username)
- status = {
- 'active': bool(user.get('active', False)),
- 'revoked': is_user_revoked(username)
- }
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=status
- ))
+ status = {'active': bool(user.get('active', False)), 'revoked': is_user_revoked(username)}
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=status
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Create authorization refresh token
@@ -604,51 +675,49 @@ Response:
{}
"""
-@authorization_router.post('/authorization/refresh',
+
+@authorization_router.post(
+ '/authorization/refresh',
description='Create authorization refresh token',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'refresh_token': '******************'
- }
- }
- }
+ 'content': {'application/json': {'example': {'refresh_token': '******************'}}},
}
- }
+ },
)
-
async def extended_authorization(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
user = await UserService.get_user_by_username_helper(username)
if not user['active']:
- return respond_rest(ResponseModel(
- status_code=400,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH007',
- error_message='User is not active'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='AUTH007',
+ error_message='User is not active',
+ )
+ )
refresh_token = create_access_token({'sub': username, 'role': user['role']}, True)
- response = respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id
- },
- response={'refresh_token': refresh_token}
- ))
+ response = respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={'refresh_token': refresh_token},
+ )
+ )
import uuid as _uuid
+
csrf_token = str(_uuid.uuid4())
_secure_env = os.getenv('COOKIE_SECURE')
@@ -659,7 +728,9 @@ async def extended_authorization(request: Request):
# Security warning: cookies should be secure in production
if not _secure and os.getenv('ENV', '').lower() in ('production', 'prod'):
- logger.warning(f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment')
+ logger.warning(
+ f'{request_id} | SECURITY WARNING: Secure cookies disabled in production environment'
+ )
_domain = os.getenv('COOKIE_DOMAIN', None)
_samesite = (os.getenv('COOKIE_SAMESITE', 'Strict') or 'Strict').strip().lower()
@@ -680,7 +751,7 @@ async def extended_authorization(request: Request):
samesite=_samesite,
path='/',
domain=safe_domain,
- max_age=604800
+ max_age=604800,
)
response.set_cookie(
@@ -690,7 +761,7 @@ async def extended_authorization(request: Request):
secure=_secure,
samesite=_samesite,
path='/',
- max_age=604800
+ max_age=604800,
)
response.set_cookie(
@@ -701,7 +772,7 @@ async def extended_authorization(request: Request):
samesite=_samesite,
path='/',
domain=safe_domain,
- max_age=604800
+ max_age=604800,
)
response.set_cookie(
@@ -711,42 +782,43 @@ async def extended_authorization(request: Request):
secure=_secure,
samesite=_samesite,
path='/',
- max_age=604800
+ max_age=604800,
)
return response
- except HTTPException as e:
- return respond_rest(ResponseModel(
- status_code=401,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH003',
- error_message='Unable to validate credentials'
- ))
+ except HTTPException:
+ return respond_rest(
+ ResponseModel(
+ status_code=401,
+ response_headers={'request_id': request_id},
+ error_code='AUTH003',
+ error_message='Unable to validate credentials',
+ )
+ )
except JWTError as e:
logging.error(f'Token refresh failed: {str(e)}')
- return respond_rest(ResponseModel(
- status_code=401,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH004',
- error_message='Token refresh failed'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=401,
+ response_headers={'request_id': request_id},
+ error_code='AUTH004',
+ error_message='Token refresh failed',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Get authorization token status
@@ -756,61 +828,59 @@ Response:
{}
"""
-@authorization_router.get('/authorization/status',
+
+@authorization_router.get(
+ '/authorization/status',
description='Get authorization token status',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'status': 'authorized'
- }
- }
- }
+ 'content': {'application/json': {'example': {'status': 'authorized'}}},
}
- }
+ },
)
-
async def authorization_status(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='Token is valid'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='Token is valid',
+ )
+ )
except JWTError:
- return respond_rest(ResponseModel(
- status_code=401,
- response_headers={
- 'request_id': request_id
- },
- error_code='AUTH005',
- error_message='Token is invalid'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=401,
+ response_headers={'request_id': request_id},
+ error_code='AUTH005',
+ error_message='Token is invalid',
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Invalidate authorization token
@@ -820,33 +890,33 @@ Response:
{}
"""
-@authorization_router.post('/authorization/invalidate',
+
+@authorization_router.post(
+ '/authorization/invalidate',
description='Invalidate authorization token',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Your token has been invalidated'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Your token has been invalidated'}}
+ },
}
- }
+ },
)
-
async def authorization_invalidate(response: Response, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
try:
import time as _t
+
exp = payload.get('exp')
ttl = None
if isinstance(exp, (int, float)):
@@ -857,29 +927,31 @@ async def authorization_invalidate(response: Response, request: Request):
if username not in jwt_blacklist:
jwt_blacklist[username] = TimedHeap()
jwt_blacklist[username].push(payload.get('jti'))
- response = respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='Your token has been invalidated'
- ))
+ response = respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='Your token has been invalidated',
+ )
+ )
_domain = os.getenv('COOKIE_DOMAIN', None)
host = request.url.hostname or (request.client.host if request.client else None)
- safe_domain = _domain if (_domain and host and (host == _domain or host.endswith(_domain))) else None
+ safe_domain = (
+ _domain if (_domain and host and (host == _domain or host.endswith(_domain))) else None
+ )
response.delete_cookie('access_token_cookie', domain=safe_domain, path='/')
return response
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/config_hot_reload_routes.py b/backend-services/routes/config_hot_reload_routes.py
index a380734..5a1f183 100644
--- a/backend-services/routes/config_hot_reload_routes.py
+++ b/backend-services/routes/config_hot_reload_routes.py
@@ -4,37 +4,33 @@ Configuration Hot Reload Routes
API endpoints for configuration management and hot reload.
"""
-from fastapi import APIRouter, Depends, HTTPException
-from typing import Dict, Any
import logging
+from typing import Any
+from fastapi import APIRouter, Depends, HTTPException
+
+from models.response_model import ResponseModel
from utils.auth_util import auth_required
from utils.hot_reload_config import hot_config
-from models.response_model import ResponseModel
logger = logging.getLogger('doorman.gateway')
-config_hot_reload_router = APIRouter(
- prefix='/config',
- tags=['Configuration Hot Reload']
-)
+config_hot_reload_router = APIRouter(prefix='/config', tags=['Configuration Hot Reload'])
+
@config_hot_reload_router.get(
'/current',
summary='Get Current Configuration',
description='Retrieve current hot-reloadable configuration values',
- response_model=Dict[str, Any],
+ response_model=dict[str, Any],
)
-async def get_current_config(
- payload: dict = Depends(auth_required)
-):
+async def get_current_config(payload: dict = Depends(auth_required)):
"""Get current configuration (admin only)"""
try:
accesses = payload.get('accesses', {})
if not accesses.get('manage_gateway'):
raise HTTPException(
- status_code=403,
- detail='Insufficient permissions: manage_gateway required'
+ status_code=403, detail='Insufficient permissions: manage_gateway required'
)
config = hot_config.dump()
@@ -44,101 +40,126 @@ async def get_current_config(
data={
'config': config,
'source': 'Environment variables override config file values',
- 'reload_command': 'kill -HUP $(cat doorman.pid)'
+ 'reload_command': 'kill -HUP $(cat doorman.pid)',
},
error_code=None,
- error_message=None
+ error_message=None,
).dict()
except HTTPException:
raise
except Exception as e:
logger.error(f'Failed to retrieve configuration: {e}', exc_info=True)
- raise HTTPException(
- status_code=500,
- detail='Failed to retrieve configuration'
- )
+ raise HTTPException(status_code=500, detail='Failed to retrieve configuration')
+
@config_hot_reload_router.post(
'/reload',
summary='Trigger Configuration Reload',
description='Manually trigger configuration reload (same as SIGHUP)',
- response_model=Dict[str, Any],
+ response_model=dict[str, Any],
)
-async def trigger_config_reload(
- payload: dict = Depends(auth_required)
-):
+async def trigger_config_reload(payload: dict = Depends(auth_required)):
"""Trigger configuration reload (admin only)"""
try:
accesses = payload.get('accesses', {})
if not accesses.get('manage_gateway'):
raise HTTPException(
- status_code=403,
- detail='Insufficient permissions: manage_gateway required'
+ status_code=403, detail='Insufficient permissions: manage_gateway required'
)
hot_config.reload()
return ResponseModel(
status_code=200,
- data={
- 'message': 'Configuration reloaded successfully',
- 'config': hot_config.dump()
- },
+ data={'message': 'Configuration reloaded successfully', 'config': hot_config.dump()},
error_code=None,
- error_message=None
+ error_message=None,
).dict()
except HTTPException:
raise
except Exception as e:
logger.error(f'Failed to reload configuration: {e}', exc_info=True)
- raise HTTPException(
- status_code=500,
- detail='Failed to reload configuration'
- )
+ raise HTTPException(status_code=500, detail='Failed to reload configuration')
+
@config_hot_reload_router.get(
'/reloadable-keys',
summary='List Reloadable Configuration Keys',
description='Get list of configuration keys that support hot reload',
- response_model=Dict[str, Any],
+ response_model=dict[str, Any],
)
-async def get_reloadable_keys(
- payload: dict = Depends(auth_required)
-):
+async def get_reloadable_keys(payload: dict = Depends(auth_required)):
"""Get list of reloadable configuration keys"""
try:
reloadable_keys = [
- {'key': 'LOG_LEVEL', 'description': 'Log level (DEBUG, INFO, WARNING, ERROR)', 'example': 'INFO'},
+ {
+ 'key': 'LOG_LEVEL',
+ 'description': 'Log level (DEBUG, INFO, WARNING, ERROR)',
+ 'example': 'INFO',
+ },
{'key': 'LOG_FORMAT', 'description': 'Log format (json, text)', 'example': 'json'},
{'key': 'LOG_FILE', 'description': 'Log file path', 'example': 'logs/doorman.log'},
-
- {'key': 'GATEWAY_TIMEOUT', 'description': 'Gateway timeout in seconds', 'example': '30'},
- {'key': 'UPSTREAM_TIMEOUT', 'description': 'Upstream timeout in seconds', 'example': '30'},
- {'key': 'CONNECTION_TIMEOUT', 'description': 'Connection timeout in seconds', 'example': '10'},
-
+ {
+ 'key': 'GATEWAY_TIMEOUT',
+ 'description': 'Gateway timeout in seconds',
+ 'example': '30',
+ },
+ {
+ 'key': 'UPSTREAM_TIMEOUT',
+ 'description': 'Upstream timeout in seconds',
+ 'example': '30',
+ },
+ {
+ 'key': 'CONNECTION_TIMEOUT',
+ 'description': 'Connection timeout in seconds',
+ 'example': '10',
+ },
{'key': 'RATE_LIMIT_ENABLED', 'description': 'Enable rate limiting', 'example': 'true'},
{'key': 'RATE_LIMIT_REQUESTS', 'description': 'Requests per window', 'example': '100'},
{'key': 'RATE_LIMIT_WINDOW', 'description': 'Window size in seconds', 'example': '60'},
-
{'key': 'CACHE_TTL', 'description': 'Cache TTL in seconds', 'example': '300'},
{'key': 'CACHE_MAX_SIZE', 'description': 'Maximum cache entries', 'example': '1000'},
-
- {'key': 'CIRCUIT_BREAKER_ENABLED', 'description': 'Enable circuit breaker', 'example': 'true'},
- {'key': 'CIRCUIT_BREAKER_THRESHOLD', 'description': 'Failures before opening', 'example': '5'},
- {'key': 'CIRCUIT_BREAKER_TIMEOUT', 'description': 'Timeout before retry (seconds)', 'example': '60'},
-
+ {
+ 'key': 'CIRCUIT_BREAKER_ENABLED',
+ 'description': 'Enable circuit breaker',
+ 'example': 'true',
+ },
+ {
+ 'key': 'CIRCUIT_BREAKER_THRESHOLD',
+ 'description': 'Failures before opening',
+ 'example': '5',
+ },
+ {
+ 'key': 'CIRCUIT_BREAKER_TIMEOUT',
+ 'description': 'Timeout before retry (seconds)',
+ 'example': '60',
+ },
{'key': 'RETRY_ENABLED', 'description': 'Enable retry logic', 'example': 'true'},
{'key': 'RETRY_MAX_ATTEMPTS', 'description': 'Maximum retry attempts', 'example': '3'},
{'key': 'RETRY_BACKOFF', 'description': 'Backoff multiplier', 'example': '1.0'},
-
- {'key': 'METRICS_ENABLED', 'description': 'Enable metrics collection', 'example': 'true'},
- {'key': 'METRICS_INTERVAL', 'description': 'Metrics interval (seconds)', 'example': '60'},
-
- {'key': 'FEATURE_REQUEST_REPLAY', 'description': 'Enable request replay', 'example': 'false'},
+ {
+ 'key': 'METRICS_ENABLED',
+ 'description': 'Enable metrics collection',
+ 'example': 'true',
+ },
+ {
+ 'key': 'METRICS_INTERVAL',
+ 'description': 'Metrics interval (seconds)',
+ 'example': '60',
+ },
+ {
+ 'key': 'FEATURE_REQUEST_REPLAY',
+ 'description': 'Enable request replay',
+ 'example': 'false',
+ },
{'key': 'FEATURE_AB_TESTING', 'description': 'Enable A/B testing', 'example': 'false'},
- {'key': 'FEATURE_COST_ANALYTICS', 'description': 'Enable cost analytics', 'example': 'false'},
+ {
+ 'key': 'FEATURE_COST_ANALYTICS',
+ 'description': 'Enable cost analytics',
+ 'example': 'false',
+ },
]
return ResponseModel(
@@ -150,16 +171,13 @@ async def get_reloadable_keys(
'Environment variables always override config file values',
'Changes take effect immediately after reload',
'Reload via: kill -HUP $(cat doorman.pid)',
- 'Or use: POST /config/reload'
- ]
+ 'Or use: POST /config/reload',
+ ],
},
error_code=None,
- error_message=None
+ error_message=None,
).dict()
except Exception as e:
logger.error(f'Failed to retrieve reloadable keys: {e}', exc_info=True)
- raise HTTPException(
- status_code=500,
- detail='Failed to retrieve reloadable keys'
- )
+ raise HTTPException(status_code=500, detail='Failed to retrieve reloadable keys')
diff --git a/backend-services/routes/config_routes.py b/backend-services/routes/config_routes.py
index 76322f0..82046b5 100644
--- a/backend-services/routes/config_routes.py
+++ b/backend-services/routes/config_routes.py
@@ -2,37 +2,39 @@
Routes to export and import platform configuration (APIs, Endpoints, Roles, Groups, Routings).
"""
-from fastapi import APIRouter, Request
-from typing import Any, Dict, List, Optional
-import uuid
-import time
-import logging
import copy
+import logging
+import time
+import uuid
+from typing import Any
+
+from fastapi import APIRouter, Request
from models.response_model import ResponseModel
-from utils.response_util import process_response
-from utils.auth_util import auth_required
-from utils.role_util import platform_role_required_bool
-from utils.doorman_cache_util import doorman_cache
from utils.audit_util import audit
+from utils.auth_util import auth_required
from utils.database import (
-
api_collection,
endpoint_collection,
group_collection,
role_collection,
routing_collection,
)
+from utils.doorman_cache_util import doorman_cache
+from utils.response_util import process_response
+from utils.role_util import platform_role_required_bool
config_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
-def _strip_id(doc: Dict[str, Any]) -> Dict[str, Any]:
+
+def _strip_id(doc: dict[str, Any]) -> dict[str, Any]:
d = dict(doc)
d.pop('_id', None)
return d
-def _export_all() -> Dict[str, Any]:
+
+def _export_all() -> dict[str, Any]:
apis = [_strip_id(a) for a in api_collection.find().to_list(length=None)]
endpoints = [_strip_id(e) for e in endpoint_collection.find().to_list(length=None)]
roles = [_strip_id(r) for r in role_collection.find().to_list(length=None)]
@@ -46,6 +48,7 @@ def _export_all() -> Dict[str, Any]:
'routings': routings,
}
+
"""
Endpoint
@@ -55,11 +58,12 @@ Response:
{}
"""
-@config_router.get('/config/export/all',
+
+@config_router.get(
+ '/config/export/all',
description='Export all platform configuration (APIs, Endpoints, Roles, Groups, Routings)',
response_model=ResponseModel,
)
-
async def export_all(request: Request):
request_id = str(uuid.uuid4())
start = time.time() * 1000
@@ -67,15 +71,39 @@ async def export_all(request: Request):
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(status_code=403, error_code='CFG001', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG001', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
data = _export_all()
- audit(request, actor=username, action='config.export_all', target='all', status='success', details={'counts': {k: len(v) for k,v in data.items()}}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response_headers={'request_id': request_id}, response=data).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='config.export_all',
+ target='all',
+ status='success',
+ details={'counts': {k: len(v) for k, v in data.items()}},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=data
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.error(f'{request_id} | export_all error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
finally:
- logger.info(f'{request_id} | export_all took {time.time()*1000 - start:.2f}ms')
+ logger.info(f'{request_id} | export_all took {time.time() * 1000 - start:.2f}ms')
+
"""
Endpoint
@@ -86,37 +114,80 @@ Response:
{}
"""
-@config_router.get('/config/export/apis',
- description='Export APIs (optionally a single API with its endpoints)',
- response_model=ResponseModel)
-async def export_apis(request: Request, api_name: Optional[str] = None, api_version: Optional[str] = None):
+@config_router.get(
+ '/config/export/apis',
+ description='Export APIs (optionally a single API with its endpoints)',
+ response_model=ResponseModel,
+)
+async def export_apis(
+ request: Request, api_name: str | None = None, api_version: str | None = None
+):
request_id = str(uuid.uuid4())
start = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_apis'):
- return process_response(ResponseModel(status_code=403, error_code='CFG002', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG002', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
if api_name and api_version:
api = api_collection.find_one({'api_name': api_name, 'api_version': api_version})
if not api:
- return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='API not found').dict(), 'rest')
- aid = api.get('api_id')
- eps = endpoint_collection.find({'api_name': api_name, 'api_version': api_version}).to_list(length=None)
- audit(request, actor=username, action='config.export_api', target=f'{api_name}/{api_version}', status='success', details={'endpoints': len(eps)}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={
- 'api': _strip_id(api),
- 'endpoints': [_strip_id(e) for e in eps]
- }).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404, error_code='CFG404', error_message='API not found'
+ ).dict(),
+ 'rest',
+ )
+ api.get('api_id')
+ eps = endpoint_collection.find(
+ {'api_name': api_name, 'api_version': api_version}
+ ).to_list(length=None)
+ audit(
+ request,
+ actor=username,
+ action='config.export_api',
+ target=f'{api_name}/{api_version}',
+ status='success',
+ details={'endpoints': len(eps)},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response={'api': _strip_id(api), 'endpoints': [_strip_id(e) for e in eps]},
+ ).dict(),
+ 'rest',
+ )
apis = [_strip_id(a) for a in api_collection.find().to_list(length=None)]
- audit(request, actor=username, action='config.export_apis', target='list', status='success', details={'count': len(apis)}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'apis': apis}).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='config.export_apis',
+ target='list',
+ status='success',
+ details={'count': len(apis)},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'apis': apis}).dict(), 'rest'
+ )
except Exception as e:
logger.error(f'{request_id} | export_apis error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
finally:
- logger.info(f'{request_id} | export_apis took {time.time()*1000 - start:.2f}ms')
+ logger.info(f'{request_id} | export_apis took {time.time() * 1000 - start:.2f}ms')
+
"""
Endpoint
@@ -127,27 +198,63 @@ Response:
{}
"""
-@config_router.get('/config/export/roles', description='Export Roles', response_model=ResponseModel)
-async def export_roles(request: Request, role_name: Optional[str] = None):
+@config_router.get('/config/export/roles', description='Export Roles', response_model=ResponseModel)
+async def export_roles(request: Request, role_name: str | None = None):
request_id = str(uuid.uuid4())
try:
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_roles'):
- return process_response(ResponseModel(status_code=403, error_code='CFG003', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG003', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
if role_name:
role = role_collection.find_one({'role_name': role_name})
if not role:
- return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='Role not found').dict(), 'rest')
- audit(request, actor=username, action='config.export_role', target=role_name, status='success', details=None, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'role': _strip_id(role)}).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404, error_code='CFG404', error_message='Role not found'
+ ).dict(),
+ 'rest',
+ )
+ audit(
+ request,
+ actor=username,
+ action='config.export_role',
+ target=role_name,
+ status='success',
+ details=None,
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'role': _strip_id(role)}).dict(), 'rest'
+ )
roles = [_strip_id(r) for r in role_collection.find().to_list(length=None)]
- audit(request, actor=username, action='config.export_roles', target='list', status='success', details={'count': len(roles)}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'roles': roles}).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='config.export_roles',
+ target='list',
+ status='success',
+ details={'count': len(roles)},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'roles': roles}).dict(), 'rest'
+ )
except Exception as e:
logger.error(f'{request_id} | export_roles error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
+
"""
Endpoint
@@ -158,27 +265,65 @@ Response:
{}
"""
-@config_router.get('/config/export/groups', description='Export Groups', response_model=ResponseModel)
-async def export_groups(request: Request, group_name: Optional[str] = None):
+@config_router.get(
+ '/config/export/groups', description='Export Groups', response_model=ResponseModel
+)
+async def export_groups(request: Request, group_name: str | None = None):
request_id = str(uuid.uuid4())
try:
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_groups'):
- return process_response(ResponseModel(status_code=403, error_code='CFG004', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG004', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
if group_name:
group = group_collection.find_one({'group_name': group_name})
if not group:
- return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='Group not found').dict(), 'rest')
- audit(request, actor=username, action='config.export_group', target=group_name, status='success', details=None, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'group': _strip_id(group)}).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404, error_code='CFG404', error_message='Group not found'
+ ).dict(),
+ 'rest',
+ )
+ audit(
+ request,
+ actor=username,
+ action='config.export_group',
+ target=group_name,
+ status='success',
+ details=None,
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'group': _strip_id(group)}).dict(), 'rest'
+ )
groups = [_strip_id(g) for g in group_collection.find().to_list(length=None)]
- audit(request, actor=username, action='config.export_groups', target='list', status='success', details={'count': len(groups)}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'groups': groups}).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='config.export_groups',
+ target='list',
+ status='success',
+ details={'count': len(groups)},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'groups': groups}).dict(), 'rest'
+ )
except Exception as e:
logger.error(f'{request_id} | export_groups error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
+
"""
Endpoint
@@ -189,27 +334,66 @@ Response:
{}
"""
-@config_router.get('/config/export/routings', description='Export Routings', response_model=ResponseModel)
-async def export_routings(request: Request, client_key: Optional[str] = None):
+@config_router.get(
+ '/config/export/routings', description='Export Routings', response_model=ResponseModel
+)
+async def export_routings(request: Request, client_key: str | None = None):
request_id = str(uuid.uuid4())
try:
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_routings'):
- return process_response(ResponseModel(status_code=403, error_code='CFG005', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG005', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
if client_key:
routing = routing_collection.find_one({'client_key': client_key})
if not routing:
- return process_response(ResponseModel(status_code=404, error_code='CFG404', error_message='Routing not found').dict(), 'rest')
- audit(request, actor=username, action='config.export_routing', target=client_key, status='success', details=None, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'routing': _strip_id(routing)}).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404, error_code='CFG404', error_message='Routing not found'
+ ).dict(),
+ 'rest',
+ )
+ audit(
+ request,
+ actor=username,
+ action='config.export_routing',
+ target=client_key,
+ status='success',
+ details=None,
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'routing': _strip_id(routing)}).dict(),
+ 'rest',
+ )
routings = [_strip_id(r) for r in routing_collection.find().to_list(length=None)]
- audit(request, actor=username, action='config.export_routings', target='list', status='success', details={'count': len(routings)}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'routings': routings}).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='config.export_routings',
+ target='list',
+ status='success',
+ details={'count': len(routings)},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'routings': routings}).dict(), 'rest'
+ )
except Exception as e:
logger.error(f'{request_id} | export_routings error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
+
"""
Endpoint
@@ -220,30 +404,47 @@ Response:
{}
"""
-@config_router.get('/config/export/endpoints',
- description='Export endpoints (optionally filter by api_name/api_version)',
- response_model=ResponseModel)
-async def export_endpoints(request: Request, api_name: Optional[str] = None, api_version: Optional[str] = None):
+@config_router.get(
+ '/config/export/endpoints',
+ description='Export endpoints (optionally filter by api_name/api_version)',
+ response_model=ResponseModel,
+)
+async def export_endpoints(
+ request: Request, api_name: str | None = None, api_version: str | None = None
+):
request_id = str(uuid.uuid4())
try:
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_endpoints'):
- return process_response(ResponseModel(status_code=403, error_code='CFG007', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG007', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
query = {}
if api_name:
query['api_name'] = api_name
if api_version:
query['api_version'] = api_version
eps = [_strip_id(e) for e in endpoint_collection.find(query).to_list(length=None)]
- return process_response(ResponseModel(status_code=200, response={'endpoints': eps}).dict(), 'rest')
+ return process_response(
+ ResponseModel(status_code=200, response={'endpoints': eps}).dict(), 'rest'
+ )
except Exception as e:
logger.error(f'{request_id} | export_endpoints error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
-def _upsert_api(doc: Dict[str, Any]) -> None:
+
+def _upsert_api(doc: dict[str, Any]) -> None:
api_name = doc.get('api_name')
api_version = doc.get('api_version')
if not api_name or not api_version:
@@ -258,11 +459,14 @@ def _upsert_api(doc: Dict[str, Any]) -> None:
to_set.setdefault('api_id', str(uuid.uuid4()))
to_set.setdefault('api_path', f'/{api_name}/{api_version}')
if existing:
- api_collection.update_one({'api_name': api_name, 'api_version': api_version}, {'$set': to_set})
+ api_collection.update_one(
+ {'api_name': api_name, 'api_version': api_version}, {'$set': to_set}
+ )
else:
api_collection.insert_one(to_set)
-def _upsert_endpoint(doc: Dict[str, Any]) -> None:
+
+def _upsert_endpoint(doc: dict[str, Any]) -> None:
api_name = doc.get('api_name')
api_version = doc.get('api_version')
method = doc.get('endpoint_method')
@@ -275,23 +479,29 @@ def _upsert_endpoint(doc: Dict[str, Any]) -> None:
if api_doc:
to_set['api_id'] = api_doc.get('api_id')
to_set.setdefault('endpoint_id', str(uuid.uuid4()))
- existing = endpoint_collection.find_one({
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': method,
- 'endpoint_uri': uri,
- })
- if existing:
- endpoint_collection.update_one({
+ existing = endpoint_collection.find_one(
+ {
'api_name': api_name,
'api_version': api_version,
'endpoint_method': method,
'endpoint_uri': uri,
- }, {'$set': to_set})
+ }
+ )
+ if existing:
+ endpoint_collection.update_one(
+ {
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': method,
+ 'endpoint_uri': uri,
+ },
+ {'$set': to_set},
+ )
else:
endpoint_collection.insert_one(to_set)
-def _upsert_role(doc: Dict[str, Any]) -> None:
+
+def _upsert_role(doc: dict[str, Any]) -> None:
name = doc.get('role_name')
if not name:
return
@@ -302,7 +512,8 @@ def _upsert_role(doc: Dict[str, Any]) -> None:
else:
role_collection.insert_one(to_set)
-def _upsert_group(doc: Dict[str, Any]) -> None:
+
+def _upsert_group(doc: dict[str, Any]) -> None:
name = doc.get('group_name')
if not name:
return
@@ -313,7 +524,8 @@ def _upsert_group(doc: Dict[str, Any]) -> None:
else:
group_collection.insert_one(to_set)
-def _upsert_routing(doc: Dict[str, Any]) -> None:
+
+def _upsert_routing(doc: dict[str, Any]) -> None:
key = doc.get('client_key')
if not key:
return
@@ -324,6 +536,7 @@ def _upsert_routing(doc: Dict[str, Any]) -> None:
else:
routing_collection.insert_one(to_set)
+
"""
Endpoint
@@ -333,11 +546,13 @@ Response:
{}
"""
-@config_router.post('/config/import',
- description='Import platform configuration (any subset of apis, endpoints, roles, groups, routings)',
- response_model=ResponseModel)
-async def import_all(request: Request, body: Dict[str, Any]):
+@config_router.post(
+ '/config/import',
+ description='Import platform configuration (any subset of apis, endpoints, roles, groups, routings)',
+ response_model=ResponseModel,
+)
+async def import_all(request: Request, body: dict[str, Any]):
request_id = str(uuid.uuid4())
start = time.time() * 1000
try:
@@ -345,27 +560,52 @@ async def import_all(request: Request, body: Dict[str, Any]):
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(status_code=403, error_code='CFG006', error_message='Insufficient permissions').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403, error_code='CFG006', error_message='Insufficient permissions'
+ ).dict(),
+ 'rest',
+ )
counts = {'apis': 0, 'endpoints': 0, 'roles': 0, 'groups': 0, 'routings': 0}
for api in body.get('apis', []) or []:
- _upsert_api(api); counts['apis'] += 1
+ _upsert_api(api)
+ counts['apis'] += 1
for ep in body.get('endpoints', []) or []:
- _upsert_endpoint(ep); counts['endpoints'] += 1
+ _upsert_endpoint(ep)
+ counts['endpoints'] += 1
for r in body.get('roles', []) or []:
- _upsert_role(r); counts['roles'] += 1
+ _upsert_role(r)
+ counts['roles'] += 1
for g in body.get('groups', []) or []:
- _upsert_group(g); counts['groups'] += 1
+ _upsert_group(g)
+ counts['groups'] += 1
for rt in body.get('routings', []) or []:
- _upsert_routing(rt); counts['routings'] += 1
+ _upsert_routing(rt)
+ counts['routings'] += 1
try:
doorman_cache.clear_all_caches()
except Exception:
pass
- audit(request, actor=username, action='config.import', target='bulk', status='success', details={'imported': counts}, request_id=request_id)
- return process_response(ResponseModel(status_code=200, response={'imported': counts}).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='config.import',
+ target='bulk',
+ status='success',
+ details={'imported': counts},
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(status_code=200, response={'imported': counts}).dict(), 'rest'
+ )
except Exception as e:
logger.error(f'{request_id} | import_all error: {e}')
- return process_response(ResponseModel(status_code=500, error_code='GTW999', error_message='An unexpected error occurred').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500, error_code='GTW999', error_message='An unexpected error occurred'
+ ).dict(),
+ 'rest',
+ )
finally:
- logger.info(f'{request_id} | import_all took {time.time()*1000 - start:.2f}ms')
+ logger.info(f'{request_id} | import_all took {time.time() * 1000 - start:.2f}ms')
diff --git a/backend-services/routes/credit_routes.py b/backend-services/routes/credit_routes.py
index 854612a..ae341b7 100644
--- a/backend-services/routes/credit_routes.py
+++ b/backend-services/routes/credit_routes.py
@@ -4,20 +4,20 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-from fastapi import APIRouter, Depends, Request
-import uuid
-import time
import logging
+import time
+import uuid
+from fastapi import APIRouter, Request
+
+from models.credit_model import CreditModel
from models.response_model import ResponseModel
from models.user_credits_model import UserCreditModel
-from models.credit_model import CreditModel
from services.credit_service import CreditService
-from utils.auth_util import auth_required
-from utils.response_util import respond_rest, process_response
-from utils.role_util import platform_role_required_bool
from utils.audit_util import audit
+from utils.auth_util import auth_required
+from utils.response_util import process_response, respond_rest
+from utils.role_util import platform_role_required_bool
credit_router = APIRouter()
@@ -32,41 +32,46 @@ Response:
{}
"""
-@credit_router.get('/defs',
+
+@credit_router.get(
+ '/defs',
description='List credit definitions',
response_model=ResponseModel,
- responses={
- 200: {'description': 'Successful Response'}
- }
+ responses={200: {'description': 'Successful Response'}},
)
-
async def list_credit_definitions(request: Request, page: int = 1, page_size: int = 50):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
- return respond_rest(ResponseModel(
- status_code=403,
- error_code='CRD002',
- error_message='Unable to retrieve credits'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403, error_code='CRD002', error_message='Unable to retrieve credits'
+ )
+ )
return respond_rest(await CreditService.list_credit_defs(page, page_size, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -76,38 +81,43 @@ Response:
{}
"""
-@credit_router.get('/defs/{api_credit_group}',
- description='Get a credit definition',
- response_model=ResponseModel,
-)
+@credit_router.get(
+ '/defs/{api_credit_group}', description='Get a credit definition', response_model=ResponseModel
+)
async def get_credit_definition(api_credit_group: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
- return respond_rest(ResponseModel(
- status_code=403,
- error_code='CRD002',
- error_message='Unable to retrieve credits'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403, error_code='CRD002', error_message='Unable to retrieve credits'
+ )
+ )
return respond_rest(await CreditService.get_credit_def(api_credit_group, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Create a credit definition
@@ -117,7 +127,9 @@ Response:
{}
"""
-@credit_router.post('',
+
+@credit_router.post(
+ '',
description='Create a credit definition',
response_model=ResponseModel,
responses={
@@ -125,22 +137,21 @@ Response:
'description': 'Successful Response',
'content': {
'application/json': {
- 'example': {
- 'message': 'Credit definition created successfully'
- }
+ 'example': {'message': 'Credit definition created successfully'}
}
- }
+ },
}
- }
+ },
)
-
async def create_credit(credit_data: CreditModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
return respond_rest(
@@ -148,24 +159,35 @@ async def create_credit(credit_data: CreditModel, request: Request):
status_code=403,
error_code='CRD001',
error_message='You do not have permission to manage credits',
- ))
+ )
+ )
result = await CreditService.create_credit(credit_data, request_id)
- audit(request, actor=username, action='credit_def.create', target=credit_data.api_credit_group, status=result.get('status_code'), details=None, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='credit_def.create',
+ target=credit_data.api_credit_group,
+ status=result.get('status_code'),
+ details=None,
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update a credit definition
@@ -175,7 +197,9 @@ Response:
{}
"""
-@credit_router.put('/{api_credit_group}',
+
+@credit_router.put(
+ '/{api_credit_group}',
description='Update a credit definition',
response_model=ResponseModel,
responses={
@@ -183,22 +207,21 @@ Response:
'description': 'Successful Response',
'content': {
'application/json': {
- 'example': {
- 'message': 'Credit definition updated successfully'
- }
+ 'example': {'message': 'Credit definition updated successfully'}
}
- }
+ },
}
- }
+ },
)
-
-async def update_credit(api_credit_group:str, credit_data: CreditModel, request: Request):
+async def update_credit(api_credit_group: str, credit_data: CreditModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
return respond_rest(
@@ -206,24 +229,35 @@ async def update_credit(api_credit_group:str, credit_data: CreditModel, request:
status_code=403,
error_code='CRD001',
error_message='You do not have permission to manage credits',
- ))
+ )
+ )
result = await CreditService.update_credit(api_credit_group, credit_data, request_id)
- audit(request, actor=username, action='credit_def.update', target=api_credit_group, status=result.get('status_code'), details=None, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='credit_def.update',
+ target=api_credit_group,
+ status=result.get('status_code'),
+ details=None,
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete a credit definition
@@ -233,7 +267,9 @@ Response:
{}
"""
-@credit_router.delete('/{api_credit_group}',
+
+@credit_router.delete(
+ '/{api_credit_group}',
description='Delete a credit definition',
response_model=ResponseModel,
responses={
@@ -241,22 +277,21 @@ Response:
'description': 'Successful Response',
'content': {
'application/json': {
- 'example': {
- 'message': 'Credit definition deleted successfully'
- }
+ 'example': {'message': 'Credit definition deleted successfully'}
}
- }
+ },
}
- }
+ },
)
-
-async def delete_credit(api_credit_group:str, request: Request):
+async def delete_credit(api_credit_group: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
return respond_rest(
@@ -264,24 +299,35 @@ async def delete_credit(api_credit_group:str, request: Request):
status_code=403,
error_code='CRD001',
error_message='You do not have permission to manage credits',
- ))
+ )
+ )
result = await CreditService.delete_credit(api_credit_group, request_id)
- audit(request, actor=username, action='credit_def.delete', target=api_credit_group, status=result.get('status_code'), details=None, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='credit_def.delete',
+ target=api_credit_group,
+ status=result.get('status_code'),
+ details=None,
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Add credits for a user
@@ -291,30 +337,27 @@ Response:
{}
"""
-@credit_router.post('/{username}',
+
+@credit_router.post(
+ '/{username}',
description='Add credits for a user',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Credits saved successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Credits saved successfully'}}},
}
- }
+ },
)
-
async def add_user_credits(username: str, credit_data: UserCreditModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
return respond_rest(
@@ -322,24 +365,35 @@ async def add_user_credits(username: str, credit_data: UserCreditModel, request:
status_code=403,
error_code='CRD001',
error_message='You do not have permission to manage credits',
- ))
+ )
+ )
result = await CreditService.add_credits(username, credit_data, request_id)
- audit(request, actor=username, action='user_credits.save', target=username, status=result.get('status_code'), details={'groups': list((credit_data.users_credits or {}).keys())}, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='user_credits.save',
+ target=username,
+ status=result.get('status_code'),
+ details={'groups': list((credit_data.users_credits or {}).keys())},
+ request_id=request_id,
+ )
return respond_rest(result)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -349,18 +403,19 @@ Response:
{}
"""
-@credit_router.get('/all',
- description='Get all user credits',
- response_model=List[UserCreditModel]
-)
-async def get_all_users_credits(request: Request, page: int = 1, page_size: int = 10, search: str = ''):
+@credit_router.get('/all', description='Get all user credits', response_model=list[UserCreditModel])
+async def get_all_users_credits(
+ request: Request, page: int = 1, page_size: int = 10, search: str = ''
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_credits'):
return respond_rest(
@@ -368,22 +423,27 @@ async def get_all_users_credits(request: Request, page: int = 1, page_size: int
status_code=403,
error_code='CRD002',
error_message='Unable to retrieve credits for all users',
- ))
- return respond_rest(await CreditService.get_all_credits(page, page_size, request_id, search=search))
+ )
+ )
+ return respond_rest(
+ await CreditService.get_all_credits(page, page_size, request_id, search=search)
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -393,38 +453,41 @@ Response:
{}
"""
-@credit_router.get('/{username}',
- description='Get credits for a user',
- response_model=UserCreditModel
-)
+@credit_router.get(
+ '/{username}', description='Get credits for a user', response_model=UserCreditModel
+)
async def get_credits(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
- if not payload.get('sub') == username and not await platform_role_required_bool(payload.get('sub'), 'manage_credits'):
+ if not payload.get('sub') == username and not await platform_role_required_bool(
+ payload.get('sub'), 'manage_credits'
+ ):
return respond_rest(
ResponseModel(
status_code=403,
error_code='CRD003',
error_message='Unable to retrieve credits for user',
- ))
+ )
+ )
return respond_rest(await CreditService.get_user_credits(username, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Rotate API key for a user
@@ -434,30 +497,25 @@ Response:
{}
"""
-@credit_router.post('/rotate-key',
+
+@credit_router.post(
+ '/rotate-key',
description='Rotate API key for the authenticated user',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'api_key': '******************'
- }
- }
- }
+ 'content': {'application/json': {'example': {'api_key': '******************'}}},
}
- }
+ },
)
-
async def rotate_key(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
-
+
# Get group from body
try:
body = await request.json()
@@ -466,29 +524,33 @@ async def rotate_key(request: Request):
group = None
if not group:
- return respond_rest(ResponseModel(
- status_code=400,
- error_code='CRD020',
- error_message='api_credit_group is required'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=400,
+ error_code='CRD020',
+ error_message='api_credit_group is required',
+ )
+ )
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
# No special role required, just authentication
-
+
return respond_rest(await CreditService.rotate_api_key(username, group, request_id))
-
+
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/dashboard_routes.py b/backend-services/routes/dashboard_routes.py
index 2ec96d8..82b2228 100644
--- a/backend-services/routes/dashboard_routes.py
+++ b/backend-services/routes/dashboard_routes.py
@@ -4,18 +4,18 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from fastapi import APIRouter, Request
-from typing import Dict, List
-import uuid
-import time
import logging
-from datetime import datetime, timedelta
+import time
+import uuid
+from datetime import datetime
+
+from fastapi import APIRouter, Request
from models.response_model import ResponseModel
from utils.auth_util import auth_required
-from utils.response_util import respond_rest
-from utils.database import user_collection, api_collection, subscriptions_collection
+from utils.database import api_collection, subscriptions_collection, user_collection
from utils.metrics_util import metrics_store
+from utils.response_util import respond_rest
dashboard_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
@@ -29,11 +29,8 @@ Response:
{}
"""
-@dashboard_router.get('',
- description='Get dashboard data',
- response_model=ResponseModel
-)
+@dashboard_router.get('', description='Get dashboard data', response_model=ResponseModel)
async def get_dashboard_data(request: Request):
"""Get dashboard statistics and data"""
request_id = str(uuid.uuid4())
@@ -41,14 +38,16 @@ async def get_dashboard_data(request: Request):
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
total_users = user_collection.count_documents({'active': True})
total_apis = api_collection.count_documents({})
snap = metrics_store.snapshot('30d')
- monthly_usage: Dict[str, int] = {}
+ monthly_usage: dict[str, int] = {}
for pt in snap.get('series', []):
try:
ts = datetime.fromtimestamp(pt['timestamp'])
@@ -61,15 +60,12 @@ async def get_dashboard_data(request: Request):
for username, reqs in snap.get('top_users', [])[:5]:
subs = subscriptions_collection.find_one({'username': username}) or {}
subscribers = len(subs.get('apis', [])) if isinstance(subs.get('apis'), list) else 0
- active_users_list.append({
- 'username': username,
- 'requests': f'{int(reqs):,}',
- 'subscribers': subscribers
- })
+ active_users_list.append(
+ {'username': username, 'requests': f'{int(reqs):,}', 'subscribers': subscribers}
+ )
popular_apis = []
for api_key, reqs in snap.get('top_apis', [])[:10]:
-
try:
name = api_key
@@ -81,11 +77,9 @@ async def get_dashboard_data(request: Request):
count += 1
except Exception:
count = 0
- popular_apis.append({
- 'name': name,
- 'requests': f'{int(reqs):,}',
- 'subscribers': count
- })
+ popular_apis.append(
+ {'name': name, 'requests': f'{int(reqs):,}', 'subscribers': count}
+ )
except Exception:
continue
@@ -95,23 +89,27 @@ async def get_dashboard_data(request: Request):
'newApis': total_apis,
'monthlyUsage': monthly_usage,
'activeUsersList': active_users_list,
- 'popularApis': popular_apis
+ 'popularApis': popular_apis,
}
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=dashboard_data
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response=dashboard_data,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/demo_routes.py b/backend-services/routes/demo_routes.py
index 0ee59f8..1ca80f5 100644
--- a/backend-services/routes/demo_routes.py
+++ b/backend-services/routes/demo_routes.py
@@ -3,17 +3,17 @@ Protected demo seeding routes for populating the running server with dummy data.
Only available to users with 'manage_gateway' OR 'manage_credits'.
"""
-from fastapi import APIRouter, Request
-from typing import Optional
-import uuid
-import time
import logging
+import time
+import uuid
+
+from fastapi import APIRouter, Request
from models.response_model import ResponseModel
-from utils.response_util import respond_rest
-from utils.role_util import platform_role_required_bool, is_admin_user
from utils.auth_util import auth_required
from utils.demo_seed_util import run_seed
+from utils.response_util import respond_rest
+from utils.role_util import is_admin_user
demo_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
@@ -27,38 +27,55 @@ Response:
{}
"""
-@demo_router.post('/seed',
- description='Seed the running server with demo data',
- response_model=ResponseModel
-)
-async def demo_seed(request: Request,
- users: int = 40,
- apis: int = 15,
- endpoints: int = 6,
- groups: int = 8,
- protos: int = 6,
- logs: int = 1500,
- seed: Optional[int] = None):
+@demo_router.post(
+ '/seed', description='Seed the running server with demo data', response_model=ResponseModel
+)
+async def demo_seed(
+ request: Request,
+ users: int = 40,
+ apis: int = 15,
+ endpoints: int = 6,
+ groups: int = 8,
+ protos: int = 6,
+ logs: int = 1500,
+ seed: int | None = None,
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await is_admin_user(username):
- return respond_rest(ResponseModel(
- status_code=403,
- error_code='DEMO001',
- error_message='Permission denied to run seeder'
- ))
- res = run_seed(users=users, apis=apis, endpoints=endpoints, groups=groups, protos=protos, logs=logs, seed=seed)
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ error_code='DEMO001',
+ error_message='Permission denied to run seeder',
+ )
+ )
+ res = run_seed(
+ users=users,
+ apis=apis,
+ endpoints=endpoints,
+ groups=groups,
+ protos=protos,
+ logs=logs,
+ seed=seed,
+ )
return respond_rest(ResponseModel(status_code=200, response=res, message='Seed completed'))
except Exception as e:
logger.error(f'{request_id} | Demo seed error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(status_code=500, error_code='DEMO999', error_message='Failed to seed demo data'))
+ return respond_rest(
+ ResponseModel(
+ status_code=500, error_code='DEMO999', error_message='Failed to seed demo data'
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/endpoint_routes.py b/backend-services/routes/endpoint_routes.py
index 4f22f6c..7cfe436 100644
--- a/backend-services/routes/endpoint_routes.py
+++ b/backend-services/routes/endpoint_routes.py
@@ -4,12 +4,13 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-from fastapi import APIRouter, Depends, Request, HTTPException
-import uuid
-import time
import logging
+import time
+import uuid
+from fastapi import APIRouter, HTTPException, Request
+
+from models.create_endpoint_model import CreateEndpointModel
from models.create_endpoint_validation_model import CreateEndpointValidationModel
from models.endpoint_model_response import EndpointModelResponse
from models.endpoint_validation_model_response import EndpointValidationModelResponse
@@ -18,9 +19,8 @@ from models.update_endpoint_model import UpdateEndpointModel
from models.update_endpoint_validation_model import UpdateEndpointValidationModel
from services.endpoint_service import EndpointService
from utils.auth_util import auth_required
-from models.create_endpoint_model import CreateEndpointModel
-from utils.response_util import respond_rest, process_response
-from utils.constants import Headers, Roles, ErrorCodes, Messages
+from utils.constants import ErrorCodes, Headers, Messages, Roles
+from utils.response_util import process_response, respond_rest
from utils.role_util import platform_role_required_bool
endpoint_router = APIRouter()
@@ -36,57 +36,58 @@ Response:
{}
"""
-@endpoint_router.post('',
+
+@endpoint_router.post(
+ '',
description='Add endpoint',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Endpoint created successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Endpoint created successfully'}}
+ },
}
- }
+ },
)
-
async def create_endpoint(endpoint_data: CreateEndpointModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='END010',
- error_message='You do not have permission to create endpoints'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='END010',
+ error_message='You do not have permission to create endpoints',
+ )
+ )
return respond_rest(await EndpointService.create_endpoint(endpoint_data, request_id))
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update endpoint
@@ -96,57 +97,74 @@ Response:
{}
"""
-@endpoint_router.put('/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}',
+
+@endpoint_router.put(
+ '/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}',
description='Update endpoint',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Endpoint updated successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Endpoint updated successfully'}}
+ },
}
- }
+ },
)
-
-async def update_endpoint(endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, endpoint_data: UpdateEndpointModel, request: Request):
+async def update_endpoint(
+ endpoint_method: str,
+ api_name: str,
+ api_version: str,
+ endpoint_uri: str,
+ endpoint_data: UpdateEndpointModel,
+ request: Request,
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='END011',
- error_message='You do not have permission to update endpoints'
- ))
- return respond_rest(await EndpointService.update_endpoint(endpoint_method, api_name, api_version, '/' + endpoint_uri, endpoint_data, request_id))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='END011',
+ error_message='You do not have permission to update endpoints',
+ )
+ )
+ return respond_rest(
+ await EndpointService.update_endpoint(
+ endpoint_method,
+ api_name,
+ api_version,
+ '/' + endpoint_uri,
+ endpoint_data,
+ request_id,
+ )
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete endpoint
@@ -156,57 +174,64 @@ Response:
{}
"""
-@endpoint_router.delete('/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}',
+
+@endpoint_router.delete(
+ '/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}',
description='Delete endpoint',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Endpoint deleted successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Endpoint deleted successfully'}}
+ },
}
- }
+ },
)
-
-async def delete_endpoint(endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request):
+async def delete_endpoint(
+ endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='END012',
- error_message='You do not have permission to delete endpoints'
- ))
- return respond_rest(await EndpointService.delete_endpoint(endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='END012',
+ error_message='You do not have permission to delete endpoints',
+ )
+ )
+ return respond_rest(
+ await EndpointService.delete_endpoint(
+ endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id
+ )
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -216,36 +241,47 @@ Response:
{}
"""
-@endpoint_router.get('/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}',
- description='Get endpoint by API name, API version and endpoint uri',
- response_model=EndpointModelResponse
-)
-async def get_endpoint(endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request):
+@endpoint_router.get(
+ '/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}',
+ description='Get endpoint by API name, API version and endpoint uri',
+ response_model=EndpointModelResponse,
+)
+async def get_endpoint(
+ endpoint_method: str, api_name: str, api_version: str, endpoint_uri: str, request: Request
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- return respond_rest(await EndpointService.get_endpoint(endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id))
+ return respond_rest(
+ await EndpointService.get_endpoint(
+ endpoint_method, api_name, api_version, '/' + endpoint_uri, request_id
+ )
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -255,36 +291,43 @@ Response:
{}
"""
-@endpoint_router.get('/{api_name}/{api_version}',
- description='Get all endpoints for an API',
- response_model=List[EndpointModelResponse]
-)
+@endpoint_router.get(
+ '/{api_name}/{api_version}',
+ description='Get all endpoints for an API',
+ response_model=list[EndpointModelResponse],
+)
async def get_endpoints_by_name_version(api_name: str, api_version: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- return respond_rest(await EndpointService.get_endpoints_by_name_version(api_name, api_version, request_id))
+ return respond_rest(
+ await EndpointService.get_endpoints_by_name_version(api_name, api_version, request_id)
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Create a new endpoint validation
@@ -294,7 +337,9 @@ Response:
{}
"""
-@endpoint_router.post('/endpoint/validation',
+
+@endpoint_router.post(
+ '/endpoint/validation',
description='Create a new endpoint validation',
response_model=ResponseModel,
responses={
@@ -302,49 +347,54 @@ Response:
'description': 'Successful Response',
'content': {
'application/json': {
- 'example': {
- 'message': 'Endpoint validation created successfully'
- }
+ 'example': {'message': 'Endpoint validation created successfully'}
}
- }
+ },
}
- }
+ },
)
-
-async def create_endpoint_validation(endpoint_validation_data: CreateEndpointValidationModel, request: Request):
+async def create_endpoint_validation(
+ endpoint_validation_data: CreateEndpointValidationModel, request: Request
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='END013',
- error_message='You do not have permission to create endpoint validations'
- ))
- return respond_rest(await EndpointService.create_endpoint_validation(endpoint_validation_data, request_id))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='END013',
+ error_message='You do not have permission to create endpoint validations',
+ )
+ )
+ return respond_rest(
+ await EndpointService.create_endpoint_validation(endpoint_validation_data, request_id)
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -354,45 +404,56 @@ Response:
{}
"""
-@endpoint_router.put('/endpoint/validation/{endpoint_id}',
+
+@endpoint_router.put(
+ '/endpoint/validation/{endpoint_id}',
description='Update an endpoint validation by endpoint ID',
- response_model=ResponseModel
+ response_model=ResponseModel,
)
-
-async def update_endpoint_validation(endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request):
+async def update_endpoint_validation(
+ endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='END014',
- error_message='You do not have permission to update endpoint validations'
- ))
- return respond_rest(await EndpointService.update_endpoint_validation(endpoint_id, endpoint_validation_data, request_id))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='END014',
+ error_message='You do not have permission to update endpoint validations',
+ )
+ )
+ return respond_rest(
+ await EndpointService.update_endpoint_validation(
+ endpoint_id, endpoint_validation_data, request_id
+ )
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -402,45 +463,52 @@ Response:
{}
"""
-@endpoint_router.delete('/endpoint/validation/{endpoint_id}',
- description='Delete an endpoint validation by endpoint ID',
- response_model=ResponseModel
-)
+@endpoint_router.delete(
+ '/endpoint/validation/{endpoint_id}',
+ description='Delete an endpoint validation by endpoint ID',
+ response_model=ResponseModel,
+)
async def delete_endpoint_validation(endpoint_id: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ENDPOINTS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='END015',
- error_message='You do not have permission to delete endpoint validations'
- ))
- return respond_rest(await EndpointService.delete_endpoint_validation(endpoint_id, request_id))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='END015',
+ error_message='You do not have permission to delete endpoint validations',
+ )
+ )
+ return respond_rest(
+ await EndpointService.delete_endpoint_validation(endpoint_id, request_id)
+ )
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -450,36 +518,41 @@ Response:
{}
"""
-@endpoint_router.get('/endpoint/validation/{endpoint_id}',
- description='Get an endpoint validation by endpoint ID',
- response_model=EndpointValidationModelResponse
-)
+@endpoint_router.get(
+ '/endpoint/validation/{endpoint_id}',
+ description='Get an endpoint validation by endpoint ID',
+ response_model=EndpointValidationModelResponse,
+)
async def get_endpoint_validation(endpoint_id: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await EndpointService.get_endpoint_validation(endpoint_id, request_id))
except HTTPException as he:
raise he
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -489,13 +562,18 @@ Response:
{}
"""
-@endpoint_router.post('/validation',
- description='Create a new endpoint validation (alias)',
- response_model=ResponseModel)
-async def create_endpoint_validation_alias(endpoint_validation_data: CreateEndpointValidationModel, request: Request):
+@endpoint_router.post(
+ '/validation',
+ description='Create a new endpoint validation (alias)',
+ response_model=ResponseModel,
+)
+async def create_endpoint_validation_alias(
+ endpoint_validation_data: CreateEndpointValidationModel, request: Request
+):
return await create_endpoint_validation(endpoint_validation_data, request)
+
"""
Endpoint
@@ -505,13 +583,18 @@ Response:
{}
"""
-@endpoint_router.put('/validation/{endpoint_id}',
- description='Update endpoint validation by endpoint ID (alias)',
- response_model=ResponseModel)
-async def update_endpoint_validation_alias(endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request):
+@endpoint_router.put(
+ '/validation/{endpoint_id}',
+ description='Update endpoint validation by endpoint ID (alias)',
+ response_model=ResponseModel,
+)
+async def update_endpoint_validation_alias(
+ endpoint_id: str, endpoint_validation_data: UpdateEndpointValidationModel, request: Request
+):
return await update_endpoint_validation(endpoint_id, endpoint_validation_data, request)
+
"""
Endpoint
@@ -521,13 +604,16 @@ Response:
{}
"""
-@endpoint_router.delete('/validation/{endpoint_id}',
- description='Delete endpoint validation by endpoint ID (alias)',
- response_model=ResponseModel)
+@endpoint_router.delete(
+ '/validation/{endpoint_id}',
+ description='Delete endpoint validation by endpoint ID (alias)',
+ response_model=ResponseModel,
+)
async def delete_endpoint_validation_alias(endpoint_id: str, request: Request):
return await delete_endpoint_validation(endpoint_id, request)
+
"""
Endpoint
@@ -537,9 +623,11 @@ Response:
{}
"""
-@endpoint_router.get('/validation/{endpoint_id}',
- description='Get endpoint validation by endpoint ID (alias)',
- response_model=EndpointValidationModelResponse)
+@endpoint_router.get(
+ '/validation/{endpoint_id}',
+ description='Get endpoint validation by endpoint ID (alias)',
+ response_model=EndpointValidationModelResponse,
+)
async def get_endpoint_validation_alias(endpoint_id: str, request: Request):
return await get_endpoint_validation(endpoint_id, request)
diff --git a/backend-services/routes/gateway_routes.py b/backend-services/routes/gateway_routes.py
index 5d987f3..4840b2b 100644
--- a/backend-services/routes/gateway_routes.py
+++ b/backend-services/routes/gateway_routes.py
@@ -4,30 +4,36 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from fastapi import APIRouter, HTTPException, Request, Depends
-import os
-import uuid
-import time
-import logging
import json
+import logging
import re
+import time
+import uuid
from datetime import datetime
+from fastapi import APIRouter, Depends, HTTPException, Request
+
from models.response_model import ResponseModel
+from services.gateway_service import GatewayService
from utils import api_util
-from utils.doorman_cache_util import doorman_cache
-from utils.limit_throttle_util import limit_and_throttle
-from utils.bandwidth_util import enforce_pre_request_limit
+from utils.audit_util import audit
from utils.auth_util import auth_required
+from utils.bandwidth_util import enforce_pre_request_limit
+from utils.doorman_cache_util import doorman_cache
from utils.group_util import group_required
+from utils.health_check_util import (
+ check_mongodb,
+ check_redis,
+ get_active_connections,
+ get_memory_usage,
+ get_uptime,
+)
+from utils.ip_policy_util import enforce_api_ip_policy
+from utils.limit_throttle_util import limit_and_throttle
from utils.response_util import process_response
from utils.role_util import platform_role_required_bool
from utils.subscription_util import subscription_required
-from utils.health_check_util import check_mongodb, check_redis, get_memory_usage, get_active_connections, get_uptime
-from services.gateway_service import GatewayService
from utils.validation_util import validation_util
-from utils.audit_util import audit
-from utils.ip_policy_util import enforce_api_ip_policy
gateway_router = APIRouter()
@@ -42,10 +48,13 @@ Response:
{}
"""
-@gateway_router.api_route('/status', methods=['GET'],
- description='Gateway status (requires manage_gateway)',
- response_model=ResponseModel)
+@gateway_router.api_route(
+ '/status',
+ methods=['GET'],
+ description='Gateway status (requires manage_gateway)',
+ response_model=ResponseModel,
+)
async def status(request: Request):
"""Restricted status endpoint.
@@ -57,53 +66,67 @@ async def status(request: Request):
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='GTW013',
- error_message='Forbidden'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='GTW013',
+ error_message='Forbidden',
+ ).dict(),
+ 'rest',
+ )
mongodb_status = await check_mongodb()
redis_status = await check_redis()
memory_usage = get_memory_usage()
active_connections = get_active_connections()
uptime = get_uptime()
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={
- 'status': 'online',
- 'mongodb': mongodb_status,
- 'redis': redis_status,
- 'memory_usage': memory_usage,
- 'active_connections': active_connections,
- 'uptime': uptime
- }
- ).dict(), 'rest')
- except Exception as e:
- if hasattr(e, 'status_code') and getattr(e, 'status_code') == 401:
- return process_response(ResponseModel(
- status_code=401,
+ return process_response(
+ ResponseModel(
+ status_code=200,
response_headers={'request_id': request_id},
- error_code='GTW401',
- error_message='Unauthorized'
- ).dict(), 'rest')
+ response={
+ 'status': 'online',
+ 'mongodb': mongodb_status,
+ 'redis': redis_status,
+ 'memory_usage': memory_usage,
+ 'active_connections': active_connections,
+ 'uptime': uptime,
+ },
+ ).dict(),
+ 'rest',
+ )
+ except Exception as e:
+ if hasattr(e, 'status_code') and e.status_code == 401:
+ return process_response(
+ ResponseModel(
+ status_code=401,
+ response_headers={'request_id': request_id},
+ error_code='GTW401',
+ error_message='Unauthorized',
+ ).dict(),
+ 'rest',
+ )
logger.error(f'{request_id} | Status check failed: {str(e)}')
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW006',
- error_message='Internal server error'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW006',
+ error_message='Internal server error',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Status check time {end_time - start_time}ms')
+
@gateway_router.get('/health', description='Public health probe', include_in_schema=False)
async def health():
return {'status': 'online'}
+
"""
Clear all caches
@@ -122,26 +145,20 @@ Response:
{}
"""
-@gateway_router.api_route('/caches', methods=['DELETE', 'OPTIONS'],
+
+@gateway_router.api_route(
+ '/caches',
+ methods=['DELETE', 'OPTIONS'],
description='Clear all caches',
response_model=ResponseModel,
- dependencies=[
- Depends(auth_required)
- ],
+ dependencies=[Depends(auth_required)],
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'All caches cleared'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'All caches cleared'}}},
}
- }
+ },
)
-
async def clear_all_caches(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -149,35 +166,53 @@ async def clear_all_caches(request: Request):
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='GTW008',
- error_message='You do not have permission to clear caches'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='GTW008',
+ error_message='You do not have permission to clear caches',
+ ).dict(),
+ 'rest',
+ )
doorman_cache.clear_all_caches()
try:
from utils.limit_throttle_util import reset_counters as _reset_rate
+
_reset_rate()
except Exception:
pass
- audit(request, actor=username, action='gateway.clear_caches', target='all', status='success', details=None)
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message='All caches cleared'
- ).dict(), 'rest')
- except Exception as e:
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ audit(
+ request,
+ actor=username,
+ action='gateway.clear_caches',
+ target='all',
+ status='success',
+ details=None,
+ )
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='All caches cleared',
+ ).dict(),
+ 'rest',
+ )
+ except Exception:
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Clear caches took {end_time - start_time:.2f}ms')
+
"""
Endpoint
@@ -196,15 +231,22 @@ Response:
{}
"""
-@gateway_router.api_route('/rest/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'],
+
+@gateway_router.api_route(
+ '/rest/{path:path}',
+ methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'],
description='REST gateway endpoint',
response_model=ResponseModel,
- include_in_schema=False)
+ include_in_schema=False,
+)
async def gateway(request: Request, path: str):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
-
parts = [p for p in (path or '').split('/') if p]
api_public = False
api_auth_required = True
@@ -216,25 +258,45 @@ async def gateway(request: Request, path: str):
try:
enforce_api_ip_policy(request, resolved_api)
except HTTPException as e:
- return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ error_code=e.detail,
+ error_message='IP restricted',
+ ).dict(),
+ 'rest',
+ )
endpoint_uri = '/' + '/'.join(parts[2:]) if len(parts) > 2 else '/'
try:
endpoints = await api_util.get_api_endpoints(resolved_api.get('api_id'))
import re as _re
+
regex_pattern = _re.compile(r'\{[^/]+\}')
- method_to_match = 'GET' if str(request.method).upper() == 'HEAD' else request.method
+ method_to_match = (
+ 'GET' if str(request.method).upper() == 'HEAD' else request.method
+ )
composite = method_to_match + endpoint_uri
- if not any(_re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in (endpoints or [])):
- return process_response(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_code='GTW003',
- error_message='Endpoint does not exist for the requested API'
- ).dict(), 'rest')
+ if not any(
+ _re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite)
+ for ep in (endpoints or [])
+ ):
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_code='GTW003',
+ error_message='Endpoint does not exist for the requested API',
+ ).dict(),
+ 'rest',
+ )
except Exception:
pass
api_public = bool(resolved_api.get('api_public')) if resolved_api else False
- api_auth_required = bool(resolved_api.get('api_auth_required')) if resolved_api and resolved_api.get('api_auth_required') is not None else True
+ api_auth_required = (
+ bool(resolved_api.get('api_auth_required'))
+ if resolved_api and resolved_api.get('api_auth_required') is not None
+ else True
+ )
username = None
if not api_public:
if api_auth_required:
@@ -245,59 +307,104 @@ async def gateway(request: Request, path: str):
username = payload.get('sub')
await enforce_pre_request_limit(request, username)
else:
-
pass
- logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms")
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms'
+ )
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- return process_response(await GatewayService.rest_gateway(username, request, request_id, start_time, path), 'rest')
+ return process_response(
+ await GatewayService.rest_gateway(username, request, request_id, start_time, path),
+ 'rest',
+ )
except HTTPException as e:
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={
- 'request_id': request_id
- },
- error_code=e.detail,
- error_message=e.detail
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code=e.detail,
+ error_message=e.detail,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
-@gateway_router.get('/rest/{path:path}', description='REST gateway endpoint (GET)', response_model=ResponseModel, operation_id='rest_get')
+
+@gateway_router.get(
+ '/rest/{path:path}',
+ description='REST gateway endpoint (GET)',
+ response_model=ResponseModel,
+ operation_id='rest_get',
+)
async def rest_get(request: Request, path: str):
return await gateway(request, path)
-@gateway_router.post('/rest/{path:path}', description='REST gateway endpoint (POST)', response_model=ResponseModel, operation_id='rest_post')
+
+@gateway_router.post(
+ '/rest/{path:path}',
+ description='REST gateway endpoint (POST)',
+ response_model=ResponseModel,
+ operation_id='rest_post',
+)
async def rest_post(request: Request, path: str):
return await gateway(request, path)
-@gateway_router.put('/rest/{path:path}', description='REST gateway endpoint (PUT)', response_model=ResponseModel, operation_id='rest_put')
+
+@gateway_router.put(
+ '/rest/{path:path}',
+ description='REST gateway endpoint (PUT)',
+ response_model=ResponseModel,
+ operation_id='rest_put',
+)
async def rest_put(request: Request, path: str):
return await gateway(request, path)
-@gateway_router.patch('/rest/{path:path}', description='REST gateway endpoint (PATCH)', response_model=ResponseModel, operation_id='rest_patch')
+
+@gateway_router.patch(
+ '/rest/{path:path}',
+ description='REST gateway endpoint (PATCH)',
+ response_model=ResponseModel,
+ operation_id='rest_patch',
+)
async def rest_patch(request: Request, path: str):
return await gateway(request, path)
-@gateway_router.delete('/rest/{path:path}', description='REST gateway endpoint (DELETE)', response_model=ResponseModel, operation_id='rest_delete')
+
+@gateway_router.delete(
+ '/rest/{path:path}',
+ description='REST gateway endpoint (DELETE)',
+ response_model=ResponseModel,
+ operation_id='rest_delete',
+)
async def rest_delete(request: Request, path: str):
return await gateway(request, path)
-@gateway_router.head('/rest/{path:path}', description='REST gateway endpoint (HEAD)', response_model=ResponseModel, operation_id='rest_head')
+
+@gateway_router.head(
+ '/rest/{path:path}',
+ description='REST gateway endpoint (HEAD)',
+ response_model=ResponseModel,
+ operation_id='rest_head',
+)
async def rest_head(request: Request, path: str):
return await gateway(request, path)
+
"""
Endpoint
@@ -316,16 +423,22 @@ Response:
{}
"""
-@gateway_router.api_route('/rest/{path:path}', methods=['OPTIONS'],
- description='REST gateway CORS preflight', include_in_schema=False)
+@gateway_router.api_route(
+ '/rest/{path:path}',
+ methods=['OPTIONS'],
+ description='REST gateway CORS preflight',
+ include_in_schema=False,
+)
async def rest_preflight(request: Request, path: str):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
+ import os as _os
+
from utils import api_util as _api_util
from utils.doorman_cache_util import doorman_cache as _cache
- import os as _os
+
parts = [p for p in (path or '').split('/') if p]
name_ver = ''
if len(parts) >= 2 and parts[1].startswith('v') and parts[1][1:].isdigit():
@@ -334,15 +447,17 @@ async def rest_preflight(request: Request, path: str):
api = await _api_util.get_api(api_key, name_ver)
if not api:
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers={'request_id': request_id})
# Optionally enforce 405 for unregistered endpoints when requested
try:
- if _os.getenv('STRICT_OPTIONS_405', 'false').lower() in ('1','true','yes','on'):
+ if _os.getenv('STRICT_OPTIONS_405', 'false').lower() in ('1', 'true', 'yes', 'on'):
endpoint_uri = '/' + '/'.join(parts[2:]) if len(parts) > 2 else '/'
try:
endpoints = await _api_util.get_api_endpoints(api.get('api_id'))
import re as _re
+
regex_pattern = _re.compile(r'\{[^/]+\}')
# For preflight, only care that the endpoint exists for any method
exists = any(
@@ -356,15 +471,22 @@ async def rest_preflight(request: Request, path: str):
)
if not exists:
from fastapi.responses import Response as StarletteResponse
- return StarletteResponse(status_code=405, headers={'request_id': request_id})
+
+ return StarletteResponse(
+ status_code=405, headers={'request_id': request_id}
+ )
except Exception:
pass
except Exception:
pass
origin = request.headers.get('origin') or request.headers.get('Origin')
- req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method')
- req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers')
+ req_method = request.headers.get('access-control-request-method') or request.headers.get(
+ 'Access-Control-Request-Method'
+ )
+ req_headers = request.headers.get('access-control-request-headers') or request.headers.get(
+ 'Access-Control-Request-Headers'
+ )
ok, headers = GatewayService._compute_api_cors_headers(api, origin, req_method, req_headers)
# Deterministic: always decide ACAO here from API config, regardless of computation above.
# 1) Remove any existing ACAO/Vary from computed headers
@@ -385,14 +507,17 @@ async def rest_preflight(request: Request, path: str):
pass
headers = {**(headers or {}), 'request_id': request_id}
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers=headers)
except Exception:
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers={'request_id': request_id})
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -411,12 +536,19 @@ Response:
{}
"""
-@gateway_router.api_route('/soap/{path:path}', methods=['POST'],
- description='SOAP gateway endpoint',
- response_model=ResponseModel)
+@gateway_router.api_route(
+ '/soap/{path:path}',
+ methods=['POST'],
+ description='SOAP gateway endpoint',
+ response_model=ResponseModel,
+)
async def soap_gateway(request: Request, path: str):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
parts = [p for p in (path or '').split('/') if p]
@@ -426,12 +558,23 @@ async def soap_gateway(request: Request, path: str):
api_key = doorman_cache.get_cache('api_id_cache', f'/{parts[0]}/{parts[1]}')
api = await api_util.get_api(api_key, f'/{parts[0]}/{parts[1]}')
api_public = bool(api.get('api_public')) if api else False
- api_auth_required = bool(api.get('api_auth_required')) if api and api.get('api_auth_required') is not None else True
+ api_auth_required = (
+ bool(api.get('api_auth_required'))
+ if api and api.get('api_auth_required') is not None
+ else True
+ )
if api:
try:
enforce_api_ip_policy(request, api)
except HTTPException as e:
- return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'soap')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ error_code=e.detail,
+ error_message='IP restricted',
+ ).dict(),
+ 'soap',
+ )
username = None
if not api_public:
if api_auth_required:
@@ -443,33 +586,43 @@ async def soap_gateway(request: Request, path: str):
await enforce_pre_request_limit(request, username)
else:
pass
- logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms")
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms'
+ )
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- return process_response(await GatewayService.soap_gateway(username, request, request_id, start_time, path), 'soap')
+ return process_response(
+ await GatewayService.soap_gateway(username, request, request_id, start_time, path),
+ 'soap',
+ )
except HTTPException as e:
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={
- 'request_id': request_id
- },
- error_code=e.detail,
- error_message=e.detail
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code=e.detail,
+ error_message=e.detail,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'soap')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'soap',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -488,15 +641,20 @@ Response:
{}
"""
-@gateway_router.api_route('/soap/{path:path}', methods=['OPTIONS'],
- description='SOAP gateway CORS preflight', include_in_schema=False)
+@gateway_router.api_route(
+ '/soap/{path:path}',
+ methods=['OPTIONS'],
+ description='SOAP gateway CORS preflight',
+ include_in_schema=False,
+)
async def soap_preflight(request: Request, path: str):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
from utils import api_util as _api_util
from utils.doorman_cache_util import doorman_cache as _cache
+
parts = [p for p in (path or '').split('/') if p]
name_ver = ''
if len(parts) >= 2 and parts[1].startswith('v') and parts[1][1:].isdigit():
@@ -505,10 +663,15 @@ async def soap_preflight(request: Request, path: str):
api = await _api_util.get_api(api_key, name_ver)
if not api:
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers={'request_id': request_id})
origin = request.headers.get('origin') or request.headers.get('Origin')
- req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method')
- req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers')
+ req_method = request.headers.get('access-control-request-method') or request.headers.get(
+ 'Access-Control-Request-Method'
+ )
+ req_headers = request.headers.get('access-control-request-headers') or request.headers.get(
+ 'Access-Control-Request-Headers'
+ )
ok, headers = GatewayService._compute_api_cors_headers(api, origin, req_method, req_headers)
if not ok and headers:
try:
@@ -518,14 +681,17 @@ async def soap_preflight(request: Request, path: str):
pass
headers = {**(headers or {}), 'request_id': request_id}
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers=headers)
except Exception:
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers={'request_id': request_id})
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -544,27 +710,49 @@ Response:
{}
"""
-@gateway_router.api_route('/graphql/{path:path}', methods=['POST'],
- description='GraphQL gateway endpoint',
- response_model=ResponseModel)
+@gateway_router.api_route(
+ '/graphql/{path:path}',
+ methods=['POST'],
+ description='GraphQL gateway endpoint',
+ response_model=ResponseModel,
+)
async def graphql_gateway(request: Request, path: str):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
if not request.headers.get('X-API-Version'):
raise HTTPException(status_code=400, detail='X-API-Version header is required')
- api_name = re.sub(r'^.*/', '',request.url.path)
- api_key = doorman_cache.get_cache('api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0'))
- api = await api_util.get_api(api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0'))
+ api_name = re.sub(r'^.*/', '', request.url.path)
+ api_key = doorman_cache.get_cache(
+ 'api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0')
+ )
+ api = await api_util.get_api(
+ api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0')
+ )
if api:
try:
enforce_api_ip_policy(request, api)
except HTTPException as e:
- return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'graphql')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ error_code=e.detail,
+ error_message='IP restricted',
+ ).dict(),
+ 'graphql',
+ )
api_public = bool(api.get('api_public')) if api else False
- api_auth_required = bool(api.get('api_auth_required')) if api and api.get('api_auth_required') is not None else True
+ api_auth_required = (
+ bool(api.get('api_auth_required'))
+ if api and api.get('api_auth_required') is not None
+ else True
+ )
username = None
if not api_public:
if api_auth_required:
@@ -576,8 +764,12 @@ async def graphql_gateway(request: Request, path: str):
await enforce_pre_request_limit(request, username)
else:
pass
- logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms")
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms'
+ )
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if api and api.get('validation_enabled'):
body = await request.json()
@@ -586,36 +778,45 @@ async def graphql_gateway(request: Request, path: str):
try:
await validation_util.validate_graphql_request(api.get('api_id'), query, variables)
except Exception as e:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='GTW011',
- error_message=str(e)
- ).dict(), 'graphql')
- return process_response(await GatewayService.graphql_gateway(username, request, request_id, start_time, path), 'graphql')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='GTW011',
+ error_message=str(e),
+ ).dict(),
+ 'graphql',
+ )
+ return process_response(
+ await GatewayService.graphql_gateway(username, request, request_id, start_time, path),
+ 'graphql',
+ )
except HTTPException as e:
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={
- 'request_id': request_id
- },
- error_code=e.detail,
- error_message=e.detail
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code=e.detail,
+ error_message=e.detail,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'graphql')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'graphql',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -634,15 +835,24 @@ Response:
{}
"""
-@gateway_router.api_route('/graphql/{path:path}', methods=['OPTIONS'],
- description='GraphQL gateway CORS preflight', include_in_schema=False)
+@gateway_router.api_route(
+ '/graphql/{path:path}',
+ methods=['OPTIONS'],
+ description='GraphQL gateway CORS preflight',
+ include_in_schema=False,
+)
async def graphql_preflight(request: Request, path: str):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
from utils import api_util as _api_util
from utils.doorman_cache_util import doorman_cache as _cache
+
api_name = path.replace('graphql/', '')
api_version = request.headers.get('X-API-Version', 'v1')
api_path = f'/{api_name}/{api_version}'
@@ -650,10 +860,15 @@ async def graphql_preflight(request: Request, path: str):
api = await _api_util.get_api(api_key, f'{api_name}/{api_version}')
if not api:
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers={'request_id': request_id})
origin = request.headers.get('origin') or request.headers.get('Origin')
- req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method')
- req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers')
+ req_method = request.headers.get('access-control-request-method') or request.headers.get(
+ 'Access-Control-Request-Method'
+ )
+ req_headers = request.headers.get('access-control-request-headers') or request.headers.get(
+ 'Access-Control-Request-Headers'
+ )
ok, headers = GatewayService._compute_api_cors_headers(api, origin, req_method, req_headers)
if not ok and headers:
try:
@@ -663,14 +878,17 @@ async def graphql_preflight(request: Request, path: str):
pass
headers = {**(headers or {}), 'request_id': request_id}
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers=headers)
except Exception:
from fastapi.responses import Response as StarletteResponse
+
return StarletteResponse(status_code=204, headers={'request_id': request_id})
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -689,26 +907,48 @@ Response:
{}
"""
-@gateway_router.api_route('/grpc/{path:path}', methods=['POST'],
- description='gRPC gateway endpoint',
- response_model=ResponseModel)
+@gateway_router.api_route(
+ '/grpc/{path:path}',
+ methods=['POST'],
+ description='gRPC gateway endpoint',
+ response_model=ResponseModel,
+)
async def grpc_gateway(request: Request, path: str):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
if not request.headers.get('X-API-Version'):
raise HTTPException(status_code=400, detail='X-API-Version header is required')
api_name = re.sub(r'^.*/', '', request.url.path)
- api_key = doorman_cache.get_cache('api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0'))
- api = await api_util.get_api(api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0'))
+ api_key = doorman_cache.get_cache(
+ 'api_id_cache', api_name + '/' + request.headers.get('X-API-Version', 'v0')
+ )
+ api = await api_util.get_api(
+ api_key, api_name + '/' + request.headers.get('X-API-Version', 'v0')
+ )
if api:
try:
enforce_api_ip_policy(request, api)
except HTTPException as e:
- return process_response(ResponseModel(status_code=e.status_code, error_code=e.detail, error_message='IP restricted').dict(), 'grpc')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ error_code=e.detail,
+ error_message='IP restricted',
+ ).dict(),
+ 'grpc',
+ )
api_public = bool(api.get('api_public')) if api else False
- api_auth_required = bool(api.get('api_auth_required')) if api and api.get('api_auth_required') is not None else True
+ api_auth_required = (
+ bool(api.get('api_auth_required'))
+ if api and api.get('api_auth_required') is not None
+ else True
+ )
username = None
if not api_public:
if api_auth_required:
@@ -720,8 +960,12 @@ async def grpc_gateway(request: Request, path: str):
await enforce_pre_request_limit(request, username)
else:
pass
- logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms")
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")[:-3]}ms'
+ )
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if api and api.get('validation_enabled'):
body = await request.json()
@@ -729,40 +973,47 @@ async def grpc_gateway(request: Request, path: str):
try:
await validation_util.validate_grpc_request(api.get('api_id'), request_data)
except Exception as e:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='GTW011',
- error_message=str(e)
- ).dict(), 'grpc')
- svc_resp = await GatewayService.grpc_gateway(username, request, request_id, start_time, path)
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='GTW011',
+ error_message=str(e),
+ ).dict(),
+ 'grpc',
+ )
+ svc_resp = await GatewayService.grpc_gateway(
+ username, request, request_id, start_time, path
+ )
if not isinstance(svc_resp, dict):
svc_resp = ResponseModel(
status_code=500,
response_headers={'request_id': request_id},
error_code='GTW006',
- error_message='Internal server error'
+ error_message='Internal server error',
).dict()
return process_response(svc_resp, 'grpc')
except HTTPException as e:
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={
- 'request_id': request_id
- },
- error_code=e.detail,
- error_message=e.detail
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code=e.detail,
+ error_message=e.detail,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'grpc')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'grpc',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/group_routes.py b/backend-services/routes/group_routes.py
index 25ac0b6..dce49e3 100644
--- a/backend-services/routes/group_routes.py
+++ b/backend-services/routes/group_routes.py
@@ -4,20 +4,20 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-from fastapi import APIRouter, Depends, Request
-import uuid
-import time
import logging
+import time
+import uuid
+from fastapi import APIRouter, Request
+
+from models.create_group_model import CreateGroupModel
from models.group_model_response import GroupModelResponse
from models.response_model import ResponseModel
from models.update_group_model import UpdateGroupModel
from services.group_service import GroupService
from utils.auth_util import auth_required
-from models.create_group_model import CreateGroupModel
-from utils.response_util import respond_rest, process_response
-from utils.constants import Headers, Roles, ErrorCodes, Messages, Defaults
+from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles
+from utils.response_util import process_response, respond_rest
from utils.role_util import platform_role_required_bool
group_router = APIRouter()
@@ -33,55 +33,53 @@ Response:
{}
"""
-@group_router.post('',
+
+@group_router.post(
+ '',
description='Add group',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Group created successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Group created successfully'}}},
}
- }
+ },
)
-
async def create_group(api_data: CreateGroupModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_GROUPS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='GRP008',
- error_message='You do not have permission to create groups'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='GRP008',
+ error_message='You do not have permission to create groups',
+ )
+ )
return respond_rest(await GroupService.create_group(api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update group
@@ -91,55 +89,53 @@ Response:
{}
"""
-@group_router.put('/{group_name}',
+
+@group_router.put(
+ '/{group_name}',
description='Update group',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Group updated successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Group updated successfully'}}},
}
- }
+ },
)
-
async def update_group(group_name: str, api_data: UpdateGroupModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_GROUPS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='GRP009',
- error_message='You do not have permission to update groups'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='GRP009',
+ error_message='You do not have permission to update groups',
+ )
+ )
return respond_rest(await GroupService.update_group(group_name, api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete group
@@ -149,55 +145,53 @@ Response:
{}
"""
-@group_router.delete('/{group_name}',
+
+@group_router.delete(
+ '/{group_name}',
description='Delete group',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Group deleted successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Group deleted successfully'}}},
}
- }
+ },
)
-
async def delete_group(group_name: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_GROUPS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='GRP010',
- error_message='You do not have permission to delete groups'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='GRP010',
+ error_message='You do not have permission to delete groups',
+ )
+ )
return respond_rest(await GroupService.delete_group(group_name, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -207,34 +201,37 @@ Response:
{}
"""
-@group_router.get('/all',
- description='Get all groups',
- response_model=List[GroupModelResponse]
-)
-async def get_groups(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+@group_router.get('/all', description='Get all groups', response_model=list[GroupModelResponse])
+async def get_groups(
+ request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await GroupService.get_groups(page, page_size, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -244,30 +241,30 @@ Response:
{}
"""
-@group_router.get('/{group_name}',
- description='Get group',
- response_model=GroupModelResponse
-)
+@group_router.get('/{group_name}', description='Get group', response_model=GroupModelResponse)
async def get_group(group_name: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await GroupService.get_group(group_name, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/logging_routes.py b/backend-services/routes/logging_routes.py
index c73a985..9c65222 100644
--- a/backend-services/routes/logging_routes.py
+++ b/backend-services/routes/logging_routes.py
@@ -4,19 +4,19 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List, Optional
-from fastapi import APIRouter, Depends, Request, Query, HTTPException
-from fastapi.responses import StreamingResponse
-import uuid
-import time
-import logging
import io
+import logging
+import time
+import uuid
+
+from fastapi import APIRouter, HTTPException, Query, Request
+from fastapi.responses import StreamingResponse
from models.response_model import ResponseModel
from services.logging_service import LoggingService
from utils.auth_util import auth_required
-from utils.response_util import respond_rest, process_response
-from utils.constants import Headers, Roles, ErrorCodes, Messages
+from utils.constants import ErrorCodes, Headers, Messages, Roles
+from utils.response_util import process_response, respond_rest
from utils.role_util import platform_role_required_bool
logging_router = APIRouter()
@@ -32,7 +32,9 @@ Response:
{}
"""
-@logging_router.get('/logs',
+
+@logging_router.get(
+ '/logs',
description='Get logs with filtering',
response_model=ResponseModel,
responses={
@@ -55,34 +57,33 @@ Response:
'response_time': '150.5',
'ip_address': '192.168.1.1',
'protocol': 'HTTP/1.1',
- 'request_id': '123e4567-e89b-12d3-a456-426614174000'
+ 'request_id': '123e4567-e89b-12d3-a456-426614174000',
}
],
'total': 100,
- 'has_more': False
+ 'has_more': False,
}
}
- }
+ },
}
- }
+ },
)
-
async def get_logs(
request: Request,
- start_date: Optional[str] = Query(None, description='Start date (YYYY-MM-DD)'),
- end_date: Optional[str] = Query(None, description='End date (YYYY-MM-DD)'),
- start_time: Optional[str] = Query(None, description='Start time (HH:MM)'),
- end_time: Optional[str] = Query(None, description='End time (HH:MM)'),
- user: Optional[str] = Query(None, description='Filter by user'),
- endpoint: Optional[str] = Query(None, description='Filter by endpoint'),
- request_id: Optional[str] = Query(None, description='Filter by request ID'),
- method: Optional[str] = Query(None, description='Filter by HTTP method'),
- ip_address: Optional[str] = Query(None, description='Filter by IP address'),
- min_response_time: Optional[str] = Query(None, description='Minimum response time (ms)'),
- max_response_time: Optional[str] = Query(None, description='Maximum response time (ms)'),
- level: Optional[str] = Query(None, description='Filter by log level'),
+ start_date: str | None = Query(None, description='Start date (YYYY-MM-DD)'),
+ end_date: str | None = Query(None, description='End date (YYYY-MM-DD)'),
+ start_time: str | None = Query(None, description='Start time (HH:MM)'),
+ end_time: str | None = Query(None, description='End time (HH:MM)'),
+ user: str | None = Query(None, description='Filter by user'),
+ endpoint: str | None = Query(None, description='Filter by endpoint'),
+ request_id: str | None = Query(None, description='Filter by request ID'),
+ method: str | None = Query(None, description='Filter by HTTP method'),
+ ip_address: str | None = Query(None, description='Filter by IP address'),
+ min_response_time: str | None = Query(None, description='Minimum response time (ms)'),
+ max_response_time: str | None = Query(None, description='Maximum response time (ms)'),
+ level: str | None = Query(None, description='Filter by log level'),
limit: int = Query(100, description='Number of logs to return', ge=1, le=1000),
- offset: int = Query(0, description='Number of logs to skip', ge=0)
+ offset: int = Query(0, description='Number of logs to skip', ge=0),
):
request_id_param = str(uuid.uuid4())
start_time_param = time.time() * 1000
@@ -90,18 +91,20 @@ async def get_logs(
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id_param} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id_param} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id_param} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.VIEW_LOGS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id_param
- },
- error_code='LOG001',
- error_message='You do not have permission to view logs'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id_param},
+ error_code='LOG001',
+ error_message='You do not have permission to view logs',
+ )
+ )
logging_service = LoggingService()
result = await logging_service.get_logs(
@@ -119,33 +122,32 @@ async def get_logs(
level=level,
limit=limit,
offset=offset,
- request_id_param=request_id_param
+ request_id_param=request_id_param,
)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id_param
- },
- response=result
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id_param}, response=result
+ )
+ )
except HTTPException as e:
raise e
except Exception as e:
logger.critical(f'{request_id_param} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id_param
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id_param},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time_param = time.time() * 1000
logger.info(f'{request_id_param} | Total time: {str(end_time_param - start_time_param)}ms')
+
"""
Endpoint
@@ -155,60 +157,59 @@ Response:
{}
"""
-@logging_router.get('/logs/files',
- description='Get list of available log files',
- response_model=ResponseModel
-)
+@logging_router.get(
+ '/logs/files', description='Get list of available log files', response_model=ResponseModel
+)
async def get_log_files(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.VIEW_LOGS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='LOG005',
- error_message='You do not have permission to view log files'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='LOG005',
+ error_message='You do not have permission to view log files',
+ )
+ )
logging_service = LoggingService()
log_files = logging_service.get_available_log_files()
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id
- },
- response={
- 'log_files': log_files,
- 'count': len(log_files)
- }
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={'log_files': log_files, 'count': len(log_files)},
+ )
+ )
except HTTPException as e:
raise e
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Get log statistics for dashboard
@@ -218,7 +219,9 @@ Response:
{}
"""
-@logging_router.get('/logs/statistics',
+
+@logging_router.get(
+ '/logs/statistics',
description='Get log statistics for dashboard',
response_model=ResponseModel,
responses={
@@ -235,69 +238,70 @@ Response:
'avg_response_time': 150.5,
'top_apis': [
{'name': 'customer', 'count': 500},
- {'name': 'orders', 'count': 300}
+ {'name': 'orders', 'count': 300},
],
'top_users': [
{'name': 'john_doe', 'count': 200},
- {'name': 'jane_smith', 'count': 150}
+ {'name': 'jane_smith', 'count': 150},
],
'top_endpoints': [
{'name': '/api/customer/v1/users', 'count': 100},
- {'name': '/api/orders/v1/orders', 'count': 80}
- ]
+ {'name': '/api/orders/v1/orders', 'count': 80},
+ ],
}
}
- }
+ },
}
- }
+ },
)
-
async def get_log_statistics(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.VIEW_LOGS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='LOG002',
- error_message='You do not have permission to view log statistics'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='LOG002',
+ error_message='You do not have permission to view log statistics',
+ )
+ )
logging_service = LoggingService()
statistics = await logging_service.get_log_statistics(request_id)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- 'request_id': request_id
- },
- response=statistics
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=statistics
+ )
+ )
except HTTPException as e:
raise e
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Export logs in various formats
@@ -307,7 +311,9 @@ Response:
{}
"""
-@logging_router.get('/logs/export',
+
+@logging_router.get(
+ '/logs/export',
description='Export logs in various formats',
response_model=ResponseModel,
responses={
@@ -317,42 +323,44 @@ Response:
'application/json': {
'example': {
'format': 'json',
- 'data': '[{\"timestamp\": \"2024-01-01T12:00:00\", \"level\": \"INFO\"}]',
- 'filename': 'logs_export_20240101_120000.json'
+ 'data': '[{"timestamp": "2024-01-01T12:00:00", "level": "INFO"}]',
+ 'filename': 'logs_export_20240101_120000.json',
}
}
- }
+ },
}
- }
+ },
)
-
async def export_logs(
request: Request,
format: str = Query('json', description='Export format (json, csv)'),
- start_date: Optional[str] = Query(None, description='Start date (YYYY-MM-DD)'),
- end_date: Optional[str] = Query(None, description='End date (YYYY-MM-DD)'),
- user: Optional[str] = Query(None, description='Filter by user'),
- api: Optional[str] = Query(None, description='Filter by API'),
- endpoint: Optional[str] = Query(None, description='Filter by endpoint'),
- level: Optional[str] = Query(None, description='Filter by log level')
+ start_date: str | None = Query(None, description='Start date (YYYY-MM-DD)'),
+ end_date: str | None = Query(None, description='End date (YYYY-MM-DD)'),
+ user: str | None = Query(None, description='Filter by user'),
+ api: str | None = Query(None, description='Filter by API'),
+ endpoint: str | None = Query(None, description='Filter by endpoint'),
+ level: str | None = Query(None, description='Filter by log level'),
):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.EXPORT_LOGS):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='LOG003',
- error_message='You do not have permission to export logs'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='LOG003',
+ error_message='You do not have permission to export logs',
+ ).dict(),
+ 'rest',
+ )
logging_service = LoggingService()
@@ -371,31 +379,33 @@ async def export_logs(
start_date=start_date,
end_date=end_date,
filters=filters,
- request_id=request_id
+ request_id=request_id,
)
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- response=export_result
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=200,
+ response_headers={Headers.REQUEST_ID: request_id},
+ response=export_result,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Download logs as file
@@ -405,47 +415,45 @@ Response:
{}
"""
-@logging_router.get('/logs/download',
+
+@logging_router.get(
+ '/logs/download',
description='Download logs as file',
include_in_schema=False,
responses={
- 200: {
- 'description': 'File download',
- 'content': {
- 'application/json': {},
- 'text/csv': {}
- }
- }
- }
+ 200: {'description': 'File download', 'content': {'application/json': {}, 'text/csv': {}}}
+ },
)
-
async def download_logs(
request: Request,
format: str = Query('json', description='Export format (json, csv)'),
- start_date: Optional[str] = Query(None, description='Start date (YYYY-MM-DD)'),
- end_date: Optional[str] = Query(None, description='End date (YYYY-MM-DD)'),
- user: Optional[str] = Query(None, description='Filter by user'),
- api: Optional[str] = Query(None, description='Filter by API'),
- endpoint: Optional[str] = Query(None, description='Filter by endpoint'),
- level: Optional[str] = Query(None, description='Filter by log level')
+ start_date: str | None = Query(None, description='Start date (YYYY-MM-DD)'),
+ end_date: str | None = Query(None, description='End date (YYYY-MM-DD)'),
+ user: str | None = Query(None, description='Filter by user'),
+ api: str | None = Query(None, description='Filter by API'),
+ endpoint: str | None = Query(None, description='Filter by endpoint'),
+ level: str | None = Query(None, description='Filter by log level'),
):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.EXPORT_LOGS):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='LOG004',
- error_message='You do not have permission to download logs'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='LOG004',
+ error_message='You do not have permission to download logs',
+ ).dict(),
+ 'rest',
+ )
logging_service = LoggingService()
@@ -464,11 +472,11 @@ async def download_logs(
start_date=start_date,
end_date=end_date,
filters=filters,
- request_id=request_id
+ request_id=request_id,
)
file_data = export_result['data'].encode('utf-8')
- file_obj = io.BytesIO(file_data)
+ io.BytesIO(file_data)
content_type = 'application/json' if format.lower() == 'json' else 'text/csv'
@@ -476,21 +484,22 @@ async def download_logs(
io.BytesIO(file_data),
media_type=content_type,
headers={
- 'Content-Disposition': f"attachment; filename={export_result['filename']}",
- Headers.REQUEST_ID: request_id
- }
+ 'Content-Disposition': f'attachment; filename={export_result["filename"]}',
+ Headers.REQUEST_ID: request_id,
+ },
)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/memory_routes.py b/backend-services/routes/memory_routes.py
index 1fbedd5..01e46ab 100644
--- a/backend-services/routes/memory_routes.py
+++ b/backend-services/routes/memory_routes.py
@@ -2,26 +2,28 @@
Routes for dumping and restoring in-memory database state.
"""
-from fastapi import APIRouter, Request
-from typing import Optional
-from pydantic import BaseModel
-import os
-import uuid
-import time
import logging
+import os
+import time
+import uuid
+
+from fastapi import APIRouter, Request
+from pydantic import BaseModel
-from utils.response_util import process_response
from models.response_model import ResponseModel
from utils.auth_util import auth_required
-from utils.role_util import platform_role_required_bool
from utils.database import database
from utils.memory_dump_util import dump_memory_to_file, restore_memory_from_file
+from utils.response_util import process_response
+from utils.role_util import platform_role_required_bool
memory_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
+
class DumpRequest(BaseModel):
- path: Optional[str] = None
+ path: str | None = None
+
"""
Endpoint
@@ -32,68 +34,88 @@ Response:
{}
"""
-@memory_router.post('/memory/dump',
+
+@memory_router.post(
+ '/memory/dump',
description='Dump in-memory database to an encrypted file',
response_model=ResponseModel,
)
-
-async def memory_dump(request: Request, body: Optional[DumpRequest] = None):
+async def memory_dump(request: Request, body: DumpRequest | None = None):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_security'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='SEC003',
- error_message='You do not have permission to perform memory dump'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SEC003',
+ error_message='You do not have permission to perform memory dump',
+ ).dict(),
+ 'rest',
+ )
if not database.memory_only:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='MEM001',
- error_message='Memory dump available only in memory-only mode'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='MEM001',
+ error_message='Memory dump available only in memory-only mode',
+ ).dict(),
+ 'rest',
+ )
if not os.getenv('MEM_ENCRYPTION_KEY'):
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='MEM002',
- error_message='MEM_ENCRYPTION_KEY is not configured'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='MEM002',
+ error_message='MEM_ENCRYPTION_KEY is not configured',
+ ).dict(),
+ 'rest',
+ )
path = None
if body and body.path:
path = body.path
dump_path = dump_memory_to_file(path)
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message='Memory dump created successfully',
- response={'response': {'path': dump_path}}
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='Memory dump created successfully',
+ response={'response': {'path': dump_path}},
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
class RestoreRequest(BaseModel):
- path: Optional[str] = None
+ path: str | None = None
+
"""
Endpoint
@@ -104,69 +126,90 @@ Response:
{}
"""
-@memory_router.post('/memory/restore',
+
+@memory_router.post(
+ '/memory/restore',
description='Restore in-memory database from an encrypted file',
response_model=ResponseModel,
)
-
-async def memory_restore(request: Request, body: Optional[RestoreRequest] = None):
+async def memory_restore(request: Request, body: RestoreRequest | None = None):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_security'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='SEC004',
- error_message='You do not have permission to perform memory restore'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SEC004',
+ error_message='You do not have permission to perform memory restore',
+ ).dict(),
+ 'rest',
+ )
if not database.memory_only:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='MEM001',
- error_message='Memory restore available only in memory-only mode'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='MEM001',
+ error_message='Memory restore available only in memory-only mode',
+ ).dict(),
+ 'rest',
+ )
if not os.getenv('MEM_ENCRYPTION_KEY'):
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='MEM002',
- error_message='MEM_ENCRYPTION_KEY is not configured'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='MEM002',
+ error_message='MEM_ENCRYPTION_KEY is not configured',
+ ).dict(),
+ 'rest',
+ )
path = None
if body and body.path:
path = body.path
info = restore_memory_from_file(path)
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message='Memory restore completed',
- response={'response': info}
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='Memory restore completed',
+ response={'response': info},
+ ).dict(),
+ 'rest',
+ )
except FileNotFoundError as e:
- return process_response(ResponseModel(
- status_code=404,
- response_headers={'request_id': request_id},
- error_code='MEM003',
- error_message=str(e)
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={'request_id': request_id},
+ error_code='MEM003',
+ error_message=str(e),
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/monitor_routes.py b/backend-services/routes/monitor_routes.py
index 5d01460..46f046e 100644
--- a/backend-services/routes/monitor_routes.py
+++ b/backend-services/routes/monitor_routes.py
@@ -2,28 +2,31 @@
Routes to expose gateway metrics to the web client.
"""
-from fastapi import APIRouter, Request
-from pydantic import BaseModel
-import uuid
-import time
-import logging
-import io
import csv
+import io
+import logging
+import time
+import uuid
+
+from fastapi import APIRouter, Request
from fastapi.responses import Response as FastAPIResponse
+from pydantic import BaseModel
from models.response_model import ResponseModel
-from utils.response_util import process_response
-from utils.metrics_util import metrics_store
from services.logging_service import LoggingService
from utils.auth_util import auth_required
-from utils.role_util import platform_role_required_bool
-from utils.health_check_util import check_mongodb, check_redis
-from utils.doorman_cache_util import doorman_cache
from utils.database import database
+from utils.doorman_cache_util import doorman_cache
+from utils.health_check_util import check_mongodb, check_redis
+from utils.metrics_util import metrics_store
+from utils.response_util import process_response
+from utils.role_util import platform_role_required_bool
+
class LivenessResponse(BaseModel):
status: str
+
class ReadinessResponse(BaseModel):
status: str
mongodb: bool | None = None
@@ -31,6 +34,7 @@ class ReadinessResponse(BaseModel):
mode: str | None = None
cache_backend: str | None = None
+
monitor_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
@@ -43,26 +47,32 @@ Response:
{}
"""
-@monitor_router.get('/monitor/metrics',
- description='Get aggregated gateway metrics',
- response_model=ResponseModel,
-)
-async def get_metrics(request: Request, range: str = '24h', group: str = 'minute', sort: str = 'asc'):
+@monitor_router.get(
+ '/monitor/metrics', description='Get aggregated gateway metrics', response_model=ResponseModel
+)
+async def get_metrics(
+ request: Request, range: str = '24h', group: str = 'minute', sort: str = 'asc'
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='MON001',
- error_message='You do not have permission to view monitor metrics'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='MON001',
+ error_message='You do not have permission to view monitor metrics',
+ ).dict(),
+ 'rest',
+ )
grp = (group or 'minute').lower()
if grp not in ('minute', 'day'):
grp = 'minute'
@@ -70,23 +80,28 @@ async def get_metrics(request: Request, range: str = '24h', group: str = 'minute
if srt not in ('asc', 'desc'):
srt = 'asc'
snap = metrics_store.snapshot(range, group=grp, sort=srt)
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=snap
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=snap
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -96,12 +111,16 @@ Response:
{}
"""
-@monitor_router.get('/monitor/liveness',
+
+@monitor_router.get(
+ '/monitor/liveness',
description='Kubernetes liveness probe endpoint (no auth)',
- response_model=LivenessResponse)
+ response_model=LivenessResponse,
+)
async def liveness(request: Request):
return {'status': 'alive'}
+
"""
Endpoint
@@ -111,9 +130,12 @@ Response:
{}
"""
-@monitor_router.get('/monitor/readiness',
+
+@monitor_router.get(
+ '/monitor/readiness',
description='Kubernetes readiness probe endpoint. Detailed status requires manage_gateway permission.',
- response_model=ReadinessResponse)
+ response_model=ReadinessResponse,
+)
async def readiness(request: Request):
"""Readiness probe endpoint.
@@ -128,7 +150,9 @@ async def readiness(request: Request):
try:
payload = await auth_required(request)
username = payload.get('sub')
- authorized = await platform_role_required_bool(username, 'manage_gateway') if username else False
+ authorized = (
+ await platform_role_required_bool(username, 'manage_gateway') if username else False
+ )
except Exception:
authorized = False
@@ -150,6 +174,7 @@ async def readiness(request: Request):
except Exception:
return {'status': 'degraded'}
+
"""
Endpoint
@@ -159,10 +184,12 @@ Response:
{}
"""
-@monitor_router.get('/monitor/report',
- description='Generate a CSV report for a date range (requires manage_gateway)',
- include_in_schema=False)
+@monitor_router.get(
+ '/monitor/report',
+ description='Generate a CSV report for a date range (requires manage_gateway)',
+ include_in_schema=False,
+)
async def generate_report(request: Request, start: str, end: str):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -170,15 +197,19 @@ async def generate_report(request: Request, start: str, end: str):
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='MON002',
- error_message='You do not have permission to generate reports'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='MON002',
+ error_message='You do not have permission to generate reports',
+ ).dict(),
+ 'rest',
+ )
def _parse_ts(s: str) -> int:
from datetime import datetime
+
fmt_variants = ['%Y-%m-%dT%H:%M', '%Y-%m-%d']
for fmt in fmt_variants:
try:
@@ -199,9 +230,11 @@ async def generate_report(request: Request, start: str, end: str):
ls = LoggingService()
import datetime as _dt
+
def _to_date_time(ts: int):
dt = _dt.datetime.utcfromtimestamp(ts)
return dt.strftime('%Y-%m-%d'), dt.strftime('%H:%M')
+
start_date, start_time_str = _to_date_time(start_ts)
end_date, end_time_str = _to_date_time(end_ts)
@@ -211,13 +244,17 @@ async def generate_report(request: Request, start: str, end: str):
max_pages = 100
for _ in range(max_pages):
batch = await ls.get_logs(
- start_date=start_date, end_date=end_date,
- start_time=start_time_str, end_time=end_time_str,
- limit=page_limit, offset=offset, request_id_param=request_id
+ start_date=start_date,
+ end_date=end_date,
+ start_time=start_time_str,
+ end_time=end_time_str,
+ limit=page_limit,
+ offset=offset,
+ request_id_param=request_id,
)
chunk = batch.get('logs', [])
logs.extend(chunk)
- total_batch = batch.get('total', 0)
+ batch.get('total', 0)
offset += page_limit
if not batch.get('has_more') or not chunk:
break
@@ -228,12 +265,13 @@ async def generate_report(request: Request, start: str, end: str):
parts = ep.split('/')
return f'rest:{parts[3]}' if len(parts) > 3 else 'rest:unknown'
if ep.startswith('/api/graphql/'):
- return f"graphql:{ep.split('/')[-1] or 'unknown'}"
+ return f'graphql:{ep.split("/")[-1] or "unknown"}'
if ep.startswith('/api/soap/'):
- return f"soap:{ep.split('/')[-1] or 'unknown'}"
+ return f'soap:{ep.split("/")[-1] or "unknown"}'
return 'platform'
except Exception:
return 'unknown'
+
total = 0
errors = 0
total_ms = 0.0
@@ -244,17 +282,20 @@ async def generate_report(request: Request, start: str, end: str):
for e in logs:
ep = str(e.get('endpoint') or '')
if not ep:
-
continue
total += 1
status_code = None
try:
- status_code = int(e.get('status_code')) if e.get('status_code') is not None else None
+ status_code = (
+ int(e.get('status_code')) if e.get('status_code') is not None else None
+ )
except Exception:
status_code = None
level = (e.get('level') or '').upper()
- is_error = (status_code is not None and status_code >= 400) or (level not in ('INFO', 'DEBUG'))
+ is_error = (status_code is not None and status_code >= 400) or (
+ level not in ('INFO', 'DEBUG')
+ )
if is_error:
errors += 1
if status_code is not None:
@@ -292,6 +333,7 @@ async def generate_report(request: Request, start: str, end: str):
total_bytes_in = sum(getattr(b, 'bytes_in', 0) for b in sel)
total_bytes_out = sum(getattr(b, 'bytes_out', 0) for b in sel)
from collections import defaultdict
+
daily_bw = defaultdict(lambda: {'in': 0, 'out': 0})
for b in sel:
day_ts = int((b.start_ts // 86400) * 86400)
@@ -308,7 +350,9 @@ async def generate_report(request: Request, start: str, end: str):
w.writerow(['total_requests', total])
w.writerow(['total_errors', errors])
w.writerow(['successes', max(total - errors, 0)])
- w.writerow(['success_rate', f'{(0 if total == 0 else (100.0 * (total - errors) / total)):.2f}%'])
+ w.writerow(
+ ['success_rate', f'{(0 if total == 0 else (100.0 * (total - errors) / total)):.2f}%']
+ )
w.writerow(['avg_response_ms', f'{avg_ms:.2f}'])
w.writerow([])
w.writerow(['Bandwidth Overview'])
@@ -341,7 +385,6 @@ async def generate_report(request: Request, start: str, end: str):
w.writerow(['Bandwidth (per day, UTC)'])
w.writerow(['date', 'bytes_in', 'bytes_out', 'total'])
for day_ts in sorted(daily_bw.keys()):
- import datetime as _dt
date_str = _dt.datetime.utcfromtimestamp(day_ts).strftime('%Y-%m-%d')
bi = int(daily_bw[day_ts]['in'])
bo = int(daily_bw[day_ts]['out'])
@@ -349,24 +392,32 @@ async def generate_report(request: Request, start: str, end: str):
csv_bytes = buf.getvalue().encode('utf-8')
filename = f'doorman_report_{start}_to_{end}.csv'
- return FastAPIResponse(content=csv_bytes, media_type='text/csv', headers={
- 'Content-Disposition': f'attachment; filename={filename}'
- })
+ return FastAPIResponse(
+ content=csv_bytes,
+ media_type='text/csv',
+ headers={'Content-Disposition': f'attachment; filename={filename}'},
+ )
except ValueError as ve:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='MON003',
- error_message=str(ve)
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='MON003',
+ error_message=str(ve),
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error in report: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/proto_routes.py b/backend-services/routes/proto_routes.py
index cfa1946..cc481be 100644
--- a/backend-services/routes/proto_routes.py
+++ b/backend-services/routes/proto_routes.py
@@ -4,21 +4,21 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from fastapi import APIRouter, Depends, Request, UploadFile, File, HTTPException
-from pathlib import Path
+import logging
import os
import re
-import logging
-import uuid
-import time
-from datetime import datetime
-import sys
import subprocess
+import sys
+import time
+import uuid
+from pathlib import Path
+
+from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from models.response_model import ResponseModel
from utils.auth_util import auth_required
+from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles
from utils.response_util import process_response
-from utils.constants import Headers, Defaults, Roles, ErrorCodes, Messages
from utils.role_util import platform_role_required_bool
proto_router = APIRouter()
@@ -26,6 +26,7 @@ logger = logging.getLogger('doorman.gateway')
PROJECT_ROOT = Path(__file__).parent.resolve()
+
def sanitize_filename(filename: str):
"""Sanitize and validate filename with comprehensive security checks"""
if not filename:
@@ -55,10 +56,13 @@ def sanitize_filename(filename: str):
safe_pattern = re.compile(r'^[a-zA-Z0-9_\-\.]+$')
if not safe_pattern.match(sanitized):
- raise ValueError('Filename contains invalid characters (use only letters, numbers, underscore, dash, dot)')
+ raise ValueError(
+ 'Filename contains invalid characters (use only letters, numbers, underscore, dash, dot)'
+ )
return sanitized
+
def validate_path(base_path: Path, target_path: Path):
try:
base_path = Path(os.path.realpath(base_path))
@@ -71,6 +75,7 @@ def validate_path(base_path: Path, target_path: Path):
logger.error(f'Path validation error: {str(e)}')
return False
+
def validate_proto_content(content: bytes, max_size: int = 1024 * 1024) -> str:
"""Validate proto file content for security and correctness"""
if len(content) > max_size:
@@ -84,20 +89,21 @@ def validate_proto_content(content: bytes, max_size: int = 1024 * 1024) -> str:
except UnicodeDecodeError:
raise ValueError('Invalid proto file: not valid UTF-8')
- if 'syntax' not in content_str and 'message' not in content_str and 'service' not in content_str:
+ if (
+ 'syntax' not in content_str
+ and 'message' not in content_str
+ and 'service' not in content_str
+ ):
raise ValueError('Invalid proto file: missing proto syntax (syntax/message/service)')
- suspicious_patterns = [
- r'`',
- r'\$\(',
- r';\s*(?:rm|mv|cp|chmod|cat|wget|curl)',
- ]
+ suspicious_patterns = [r'`', r'\$\(', r';\s*(?:rm|mv|cp|chmod|cat|wget|curl)']
for pattern in suspicious_patterns:
if re.search(pattern, content_str):
raise ValueError('Invalid proto file: suspicious content detected')
return content_str
+
def get_safe_proto_path(api_name: str, api_version: str):
try:
safe_api_name = sanitize_filename(api_name)
@@ -108,19 +114,16 @@ def get_safe_proto_path(api_name: str, api_version: str):
proto_dir.mkdir(exist_ok=True)
generated_dir.mkdir(exist_ok=True)
proto_path = (proto_dir / f'{key}.proto').resolve()
- if not validate_path(PROJECT_ROOT, proto_path) or not validate_path(PROJECT_ROOT, generated_dir):
+ if not validate_path(PROJECT_ROOT, proto_path) or not validate_path(
+ PROJECT_ROOT, generated_dir
+ ):
raise ValueError('Invalid path detected')
return proto_path, generated_dir
except ValueError as e:
- raise HTTPException(
- status_code=400,
- detail=f'Path validation error: {str(e)}'
- )
+ raise HTTPException(status_code=400, detail=f'Path validation error: {str(e)}')
except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f'Failed to create safe paths: {str(e)}'
- )
+ raise HTTPException(status_code=500, detail=f'Failed to create safe paths: {str(e)}')
+
"""
Upload proto file
@@ -131,36 +134,43 @@ Response:
{}
"""
-@proto_router.post('/{api_name}/{api_version}',
+
+@proto_router.post(
+ '/{api_name}/{api_version}',
description='Upload proto file',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Proto file uploaded successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Proto file uploaded successfully'}}
+ },
}
- })
-
-async def upload_proto_file(api_name: str, api_version: str, file: UploadFile = File(...), request: Request = None):
+ },
+)
+async def upload_proto_file(
+ api_name: str, api_version: str, file: UploadFile = File(...), request: Request = None
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
- max_size = int(os.getenv(Defaults.MAX_MULTIPART_SIZE_BYTES_ENV, Defaults.MAX_MULTIPART_SIZE_BYTES_DEFAULT))
+ max_size = int(
+ os.getenv(
+ Defaults.MAX_MULTIPART_SIZE_BYTES_ENV, Defaults.MAX_MULTIPART_SIZE_BYTES_DEFAULT
+ )
+ )
cl = request.headers.get('content-length') if request else None
try:
if cl and int(cl) > max_size:
- return process_response(ResponseModel(
- status_code=413,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.REQUEST_TOO_LARGE,
- error_message=Messages.FILE_TOO_LARGE
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=413,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.REQUEST_TOO_LARGE,
+ error_message=Messages.FILE_TOO_LARGE,
+ ).dict(),
+ 'rest',
+ )
except Exception:
pass
payload = await auth_required(request)
@@ -168,20 +178,26 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
logger.info(f'{request_id} | Username: {username}')
logger.info(f'{request_id} | Endpoint: POST /proto/{api_name}/{api_version}')
if not await platform_role_required_bool(username, Roles.MANAGE_APIS):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.AUTH_REQUIRED,
- error_message=Messages.PERMISSION_MANAGE_APIS
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.AUTH_REQUIRED,
+ error_message=Messages.PERMISSION_MANAGE_APIS,
+ ).dict(),
+ 'rest',
+ )
original_name = file.filename or ''
if not original_name.lower().endswith('.proto'):
- return process_response(ResponseModel(
- status_code=400,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.REQUEST_FILE_TYPE,
- error_message=Messages.ONLY_PROTO_ALLOWED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.REQUEST_FILE_TYPE,
+ error_message=Messages.ONLY_PROTO_ALLOWED,
+ ).dict(),
+ 'rest',
+ )
proto_path, generated_dir = get_safe_proto_path(api_name, api_version)
content = await file.read()
@@ -189,40 +205,58 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024))
proto_content = validate_proto_content(content, max_size=max_proto_size)
except ValueError as e:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.REQUEST_FILE_TYPE,
- error_message=f'Invalid proto file: {str(e)}'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.REQUEST_FILE_TYPE,
+ error_message=f'Invalid proto file: {str(e)}',
+ ).dict(),
+ 'rest',
+ )
safe_api_name = sanitize_filename(api_name)
safe_api_version = sanitize_filename(api_version)
if 'package' in proto_content:
- proto_content = re.sub(r'package\s+[^;]+;', f'package {safe_api_name}_{safe_api_version};', proto_content)
+ proto_content = re.sub(
+ r'package\s+[^;]+;', f'package {safe_api_name}_{safe_api_version};', proto_content
+ )
else:
- proto_content = re.sub(r'syntax\s*=\s*"proto3";', f'syntax = "proto3";\n\npackage {safe_api_name}_{safe_api_version};', proto_content)
+ proto_content = re.sub(
+ r'syntax\s*=\s*"proto3";',
+ f'syntax = "proto3";\n\npackage {safe_api_name}_{safe_api_version};',
+ proto_content,
+ )
proto_path.write_text(proto_content)
try:
- subprocess.run([
- sys.executable, '-m', 'grpc_tools.protoc',
- f'--proto_path={proto_path.parent}',
- f'--python_out={generated_dir}',
- f'--grpc_python_out={generated_dir}',
- str(proto_path)
- ], check=True)
- logger.info(f"{request_id} | Proto compiled: src={proto_path} out={generated_dir}")
+ subprocess.run(
+ [
+ sys.executable,
+ '-m',
+ 'grpc_tools.protoc',
+ f'--proto_path={proto_path.parent}',
+ f'--python_out={generated_dir}',
+ f'--grpc_python_out={generated_dir}',
+ str(proto_path),
+ ],
+ check=True,
+ )
+ logger.info(f'{request_id} | Proto compiled: src={proto_path} out={generated_dir}')
init_path = (generated_dir / '__init__.py').resolve()
if not validate_path(generated_dir, init_path):
raise ValueError('Invalid init path')
if not init_path.exists():
init_path.write_text('"""Generated gRPC code."""\n')
- pb2_grpc_file = (generated_dir / f'{safe_api_name}_{safe_api_version}_pb2_grpc.py').resolve()
+ pb2_grpc_file = (
+ generated_dir / f'{safe_api_name}_{safe_api_version}_pb2_grpc.py'
+ ).resolve()
if not validate_path(generated_dir, pb2_grpc_file):
raise ValueError('Invalid grpc file path')
if pb2_grpc_file.exists():
content = pb2_grpc_file.read_text()
# Double-check sanitized values contain only safe characters before using in regex
- if not re.match(r'^[a-zA-Z0-9_\-\.]+$', safe_api_name) or not re.match(r'^[a-zA-Z0-9_\-\.]+$', safe_api_version):
+ if not re.match(r'^[a-zA-Z0-9_\-\.]+$', safe_api_name) or not re.match(
+ r'^[a-zA-Z0-9_\-\.]+$', safe_api_version
+ ):
raise ValueError('Invalid characters in sanitized API name or version')
escaped_mod = re.escape(f'{safe_api_name}_{safe_api_version}_pb2')
import_pattern = rf'^import {escaped_mod} as (.+)$'
@@ -231,22 +265,34 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
for i, line in enumerate(lines, 1):
if 'import' in line and 'pb2' in line:
logger.info(f'{request_id} | Line {i}: {repr(line)}')
- new_content = re.sub(import_pattern, rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \1', content, flags=re.MULTILINE)
+ new_content = re.sub(
+ import_pattern,
+ rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \1',
+ content,
+ flags=re.MULTILINE,
+ )
if new_content != content:
logger.info(f'{request_id} | Import fix applied successfully')
pb2_grpc_file.write_text(new_content)
- logger.info(f"{request_id} | Wrote fixed pb2_grpc at {pb2_grpc_file}")
+ logger.info(f'{request_id} | Wrote fixed pb2_grpc at {pb2_grpc_file}')
pycache_dir = (generated_dir / '__pycache__').resolve()
if not validate_path(generated_dir, pycache_dir):
- logger.warning(f'{request_id} | Unsafe pycache path detected. Skipping cache cleanup.')
+ logger.warning(
+ f'{request_id} | Unsafe pycache path detected. Skipping cache cleanup.'
+ )
elif pycache_dir.exists():
- for pyc_file in pycache_dir.glob(f'{safe_api_name}_{safe_api_version}*.pyc'):
+ for pyc_file in pycache_dir.glob(
+ f'{safe_api_name}_{safe_api_version}*.pyc'
+ ):
try:
pyc_file.unlink()
logger.info(f'{request_id} | Deleted cache file: {pyc_file.name}')
except Exception as e:
- logger.warning(f'{request_id} | Failed to delete cache file {pyc_file.name}: {e}')
+ logger.warning(
+ f'{request_id} | Failed to delete cache file {pyc_file.name}: {e}'
+ )
import sys as sys_import
+
pb2_module_name = f'{safe_api_name}_{safe_api_version}_pb2'
pb2_grpc_module_name = f'{safe_api_name}_{safe_api_version}_pb2_grpc'
if pb2_module_name in sys_import.modules:
@@ -254,51 +300,78 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
logger.info(f'{request_id} | Cleared {pb2_module_name} from sys.modules')
if pb2_grpc_module_name in sys_import.modules:
del sys_import.modules[pb2_grpc_module_name]
- logger.info(f'{request_id} | Cleared {pb2_grpc_module_name} from sys.modules')
+ logger.info(
+ f'{request_id} | Cleared {pb2_grpc_module_name} from sys.modules'
+ )
else:
- logger.warning(f'{request_id} | Import fix pattern did not match - no changes made')
+ logger.warning(
+ f'{request_id} | Import fix pattern did not match - no changes made'
+ )
try:
# Reuse escaped_mod which was already validated above
rel_pattern = rf'^from \\. import {escaped_mod} as (.+)$'
content2 = pb2_grpc_file.read_text()
- new2 = re.sub(rel_pattern, rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \\1', content2, flags=re.MULTILINE)
+ new2 = re.sub(
+ rel_pattern,
+ rf'from generated import {safe_api_name}_{safe_api_version}_pb2 as \\1',
+ content2,
+ flags=re.MULTILINE,
+ )
if new2 != content2:
pb2_grpc_file.write_text(new2)
- logger.info(f"{request_id} | Applied relative import rewrite for module {safe_api_name}_{safe_api_version}_pb2")
+ logger.info(
+ f'{request_id} | Applied relative import rewrite for module {safe_api_name}_{safe_api_version}_pb2'
+ )
except Exception as e:
- logger.warning(f"{request_id} | Failed relative import rewrite: {e}")
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message='Proto file uploaded and gRPC code generated successfully'
- ).dict(), 'rest')
+ logger.warning(f'{request_id} | Failed relative import rewrite: {e}')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='Proto file uploaded and gRPC code generated successfully',
+ ).dict(),
+ 'rest',
+ )
except subprocess.CalledProcessError as e:
logger.error(f'{request_id} | Failed to generate gRPC code: {str(e)}')
- return process_response(ResponseModel(
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.GRPC_GENERATION_FAILED,
+ error_message=f'{Messages.GRPC_GEN_FAILED}: {str(e)}',
+ ).dict(),
+ 'rest',
+ )
+ except HTTPException as e:
+ logger.error(f'{request_id} | Path validation error: {str(e)}')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.PATH_VALIDATION,
+ error_message=str(e.detail),
+ ).dict(),
+ 'rest',
+ )
+ except Exception as e:
+ logger.error(
+ f'{request_id} | Error uploading proto file: {type(e).__name__}: {str(e)}',
+ exc_info=True,
+ )
+ return process_response(
+ ResponseModel(
status_code=500,
response_headers={Headers.REQUEST_ID: request_id},
error_code=ErrorCodes.GRPC_GENERATION_FAILED,
- error_message=f'{Messages.GRPC_GEN_FAILED}: {str(e)}'
- ).dict(), 'rest')
- except HTTPException as e:
- logger.error(f'{request_id} | Path validation error: {str(e)}')
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.PATH_VALIDATION,
- error_message=str(e.detail)
- ).dict(), 'rest')
- except Exception as e:
- logger.error(f'{request_id} | Error uploading proto file: {type(e).__name__}: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.GRPC_GENERATION_FAILED,
- error_message=f'Failed to upload proto file: {str(e)}'
- ).dict(), 'rest')
+ error_message=f'Failed to upload proto file: {str(e)}',
+ ).dict(),
+ 'rest',
+ )
finally:
logger.info(f'{request_id} | Total time: {time.time() * 1000 - start_time}ms')
+
"""
Get proto file
@@ -308,23 +381,20 @@ Response:
{}
"""
-@proto_router.get('/{api_name}/{api_version}',
+
+@proto_router.get(
+ '/{api_name}/{api_version}',
description='Get proto file',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Proto file retrieved successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Proto file retrieved successfully'}}
+ },
}
- }
+ },
)
-
async def get_proto_file(api_name: str, api_version: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -334,46 +404,62 @@ async def get_proto_file(api_name: str, api_version: str, request: Request):
logger.info(f'{request_id} | Endpoint: {request.method} {request.url.path}')
try:
if not await platform_role_required_bool(username, Roles.MANAGE_APIS):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.AUTH_REQUIRED,
- error_message=Messages.PERMISSION_MANAGE_APIS
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.AUTH_REQUIRED,
+ error_message=Messages.PERMISSION_MANAGE_APIS,
+ ).dict(),
+ 'rest',
+ )
proto_path, _ = get_safe_proto_path(api_name, api_version)
if not proto_path.exists():
- return process_response(ResponseModel(
- status_code=404,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.API_NOT_FOUND,
- error_message=f'Proto file not found for API {api_name}/{api_version}'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.API_NOT_FOUND,
+ error_message=f'Proto file not found for API {api_name}/{api_version}',
+ ).dict(),
+ 'rest',
+ )
proto_content = proto_path.read_text()
- return process_response(ResponseModel(
- status_code=200,
- response_headers={Headers.REQUEST_ID: request_id},
- message='Proto file retrieved successfully',
- response={'content': proto_content}
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={Headers.REQUEST_ID: request_id},
+ message='Proto file retrieved successfully',
+ response={'content': proto_content},
+ ).dict(),
+ 'rest',
+ )
except HTTPException as e:
logger.error(f'{request_id} | Path validation error: {str(e)}')
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.PATH_VALIDATION,
- error_message=str(e.detail)
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.PATH_VALIDATION,
+ error_message=str(e.detail),
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.error(f'{request_id} | Failed to get proto file: {str(e)}')
- return process_response(ResponseModel(
- status_code=500,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.API_NOT_FOUND,
- error_message=f'Failed to get proto file: {str(e)}'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.API_NOT_FOUND,
+ error_message=f'Failed to get proto file: {str(e)}',
+ ).dict(),
+ 'rest',
+ )
finally:
logger.info(f'{request_id} | Total time: {time.time() * 1000 - start_time}ms')
+
"""
Update proto file
@@ -383,46 +469,53 @@ Response:
{}
"""
-@proto_router.put('/{api_name}/{api_version}',
+
+@proto_router.put(
+ '/{api_name}/{api_version}',
description='Update proto file',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Proto file updated successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Proto file updated successfully'}}
+ },
}
- }
+ },
)
-
-async def update_proto_file(api_name: str, api_version: str, request: Request, proto_file: UploadFile = File(...)):
+async def update_proto_file(
+ api_name: str, api_version: str, request: Request, proto_file: UploadFile = File(...)
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_APIS):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='API008',
- error_message='You do not have permission to update proto files'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API008',
+ error_message='You do not have permission to update proto files',
+ ).dict(),
+ 'rest',
+ )
original_name = proto_file.filename or ''
if not original_name.lower().endswith('.proto'):
- return process_response(ResponseModel(
- status_code=400,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.REQUEST_FILE_TYPE,
- error_message=Messages.ONLY_PROTO_ALLOWED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.REQUEST_FILE_TYPE,
+ error_message=Messages.ONLY_PROTO_ALLOWED,
+ ).dict(),
+ 'rest',
+ )
proto_path, generated_dir = get_safe_proto_path(api_name, api_version)
content = await proto_file.read()
@@ -430,55 +523,76 @@ async def update_proto_file(api_name: str, api_version: str, request: Request, p
max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024))
proto_content = validate_proto_content(content, max_size=max_proto_size)
except ValueError as e:
- return process_response(ResponseModel(
- status_code=400,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.REQUEST_FILE_TYPE,
- error_message=f'Invalid proto file: {str(e)}'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.REQUEST_FILE_TYPE,
+ error_message=f'Invalid proto file: {str(e)}',
+ ).dict(),
+ 'rest',
+ )
proto_path.write_text(proto_content)
try:
- subprocess.run([
- sys.executable, '-m', 'grpc_tools.protoc',
- f'--proto_path={proto_path.parent}',
- f'--python_out={generated_dir}',
- f'--grpc_python_out={generated_dir}',
- str(proto_path)
- ], check=True)
+ subprocess.run(
+ [
+ sys.executable,
+ '-m',
+ 'grpc_tools.protoc',
+ f'--proto_path={proto_path.parent}',
+ f'--python_out={generated_dir}',
+ f'--grpc_python_out={generated_dir}',
+ str(proto_path),
+ ],
+ check=True,
+ )
except subprocess.CalledProcessError as e:
logger.error(f'{request_id} | Failed to generate gRPC code: {str(e)}')
- return process_response(ResponseModel(
- status_code=500,
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API009',
+ error_message='Failed to generate gRPC code from proto file',
+ ).dict(),
+ 'rest',
+ )
+ return process_response(
+ ResponseModel(
+ status_code=200,
response_headers={Headers.REQUEST_ID: request_id},
- error_code='API009',
- error_message='Failed to generate gRPC code from proto file'
- ).dict(), 'rest')
- return process_response(ResponseModel(
- status_code=200,
- response_headers={Headers.REQUEST_ID: request_id},
- message='Proto file updated successfully'
- ).dict(), 'rest')
+ message='Proto file updated successfully',
+ ).dict(),
+ 'rest',
+ )
except HTTPException as e:
logger.error(f'{request_id} | Path validation error: {str(e)}')
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={'request_id': request_id},
- error_code='GTW013',
- error_message=str(e.detail)
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code='GTW013',
+ error_message=str(e.detail),
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete proto file
@@ -488,38 +602,40 @@ Response:
{}
"""
-@proto_router.delete('/{api_name}/{api_version}',
+
+@proto_router.delete(
+ '/{api_name}/{api_version}',
description='Delete proto file',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Proto file deleted successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Proto file deleted successfully'}}
+ },
}
- }
+ },
)
-
async def delete_proto_file(api_name: str, api_version: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_APIS):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='API008',
- error_message='You do not have permission to delete proto files'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='API008',
+ error_message='You do not have permission to delete proto files',
+ ).dict(),
+ 'rest',
+ )
proto_path, generated_dir = get_safe_proto_path(api_name, api_version)
safe_api_name = sanitize_filename(api_name)
safe_api_version = sanitize_filename(api_version)
@@ -529,36 +645,52 @@ async def delete_proto_file(api_name: str, api_version: str, request: Request):
raise ValueError('Unsafe proto file path detected')
proto_path.unlink()
logger.info(f'{request_id} | Deleted proto file: {proto_path}')
- generated_files = [f'{key}_pb2.py', f'{key}_pb2.pyc', f'{key}_pb2_grpc.py', f'{key}_pb2_grpc.pyc']
+ generated_files = [
+ f'{key}_pb2.py',
+ f'{key}_pb2.pyc',
+ f'{key}_pb2_grpc.py',
+ f'{key}_pb2_grpc.pyc',
+ ]
for file in generated_files:
file_path = (generated_dir / file).resolve()
if not validate_path(generated_dir, file_path):
- logger.warning(f'{request_id} | Unsafe file path detected: {file_path}. Skipping deletion.')
+ logger.warning(
+ f'{request_id} | Unsafe file path detected: {file_path}. Skipping deletion.'
+ )
continue
if file_path.exists():
file_path.unlink()
logger.info(f'{request_id} | Deleted generated file: {file_path}')
- return process_response(ResponseModel(
- status_code=200,
- response_headers={Headers.REQUEST_ID: request_id},
- message='Proto file and generated files deleted successfully'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={Headers.REQUEST_ID: request_id},
+ message='Proto file and generated files deleted successfully',
+ ).dict(),
+ 'rest',
+ )
except ValueError as e:
logger.error(f'{request_id} | Path validation error: {str(e)}')
- return process_response(ResponseModel(
- status_code=400,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.PATH_VALIDATION,
- error_message=str(e)
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.PATH_VALIDATION,
+ error_message=str(e),
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/quota_routes.py b/backend-services/routes/quota_routes.py
index f552cad..3cad36c 100644
--- a/backend-services/routes/quota_routes.py
+++ b/backend-services/routes/quota_routes.py
@@ -6,15 +6,15 @@ User-facing endpoints for checking current usage and limits.
"""
import logging
-from typing import Optional
from datetime import datetime
-from fastapi import APIRouter, HTTPException, Depends, status
+
+from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
-from models.rate_limit_models import QuotaType, TierLimits
+from models.rate_limit_models import QuotaType
from services.tier_service import TierService, get_tier_service
-from utils.quota_tracker import QuotaTracker, get_quota_tracker
from utils.database_async import async_database
+from utils.quota_tracker import QuotaTracker, get_quota_tracker
logger = logging.getLogger(__name__)
@@ -25,8 +25,10 @@ quota_router = APIRouter()
# RESPONSE MODELS
# ============================================================================
+
class QuotaStatusResponse(BaseModel):
"""Response model for quota status"""
+
quota_type: str
current_usage: int
limit: int
@@ -43,16 +45,18 @@ class QuotaStatusResponse(BaseModel):
class TierInfoResponse(BaseModel):
"""Response model for tier information"""
+
tier_id: str
tier_name: str
display_name: str
limits: dict
- price_monthly: Optional[float]
+ price_monthly: float | None
features: list
class QuotaDashboardResponse(BaseModel):
"""Complete quota dashboard response"""
+
user_id: str
tier_info: TierInfoResponse
quotas: list
@@ -63,6 +67,7 @@ class QuotaDashboardResponse(BaseModel):
# DEPENDENCY INJECTION
# ============================================================================
+
async def get_quota_tracker_dep() -> QuotaTracker:
"""Dependency to get quota tracker"""
return get_quota_tracker()
@@ -77,27 +82,28 @@ async def get_tier_service_dep() -> TierService:
def get_current_user_id() -> str:
"""
Get current user ID from request context
-
+
In production, this should extract from JWT token or session.
For now, returns a placeholder.
"""
# TODO: Extract from auth middleware
- return "current_user"
+ return 'current_user'
# ============================================================================
# QUOTA STATUS ENDPOINTS
# ============================================================================
-@quota_router.get("/status", response_model=QuotaDashboardResponse)
+
+@quota_router.get('/status', response_model=QuotaDashboardResponse)
async def get_quota_status(
user_id: str = Depends(get_current_user_id),
quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
Get complete quota status for current user
-
+
Returns:
- Current tier information
- All quota usages
@@ -106,82 +112,78 @@ async def get_quota_status(
try:
# Get user's tier
tier = await tier_service.get_user_tier(user_id)
-
+
if not tier:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="No tier assigned to user"
+ status_code=status.HTTP_404_NOT_FOUND, detail='No tier assigned to user'
)
-
+
# Get user's effective limits (including overrides)
limits = await tier_service.get_user_limits(user_id)
-
+
if not limits:
limits = tier.limits
-
+
# Check all quotas
quotas = []
-
+
# Monthly request quota
if limits.monthly_request_quota:
result = quota_tracker.check_quota(
- user_id,
- QuotaType.REQUESTS,
- limits.monthly_request_quota,
- 'month'
+ user_id, QuotaType.REQUESTS, limits.monthly_request_quota, 'month'
)
- quotas.append(QuotaStatusResponse(
- quota_type='monthly_requests',
- current_usage=result.current_usage,
- limit=result.limit,
- remaining=result.remaining,
- percentage_used=result.percentage_used,
- reset_at=result.reset_at.isoformat(),
- is_warning=result.is_warning,
- is_critical=result.is_critical,
- is_exhausted=result.is_exhausted
- ))
-
+ quotas.append(
+ QuotaStatusResponse(
+ quota_type='monthly_requests',
+ current_usage=result.current_usage,
+ limit=result.limit,
+ remaining=result.remaining,
+ percentage_used=result.percentage_used,
+ reset_at=result.reset_at.isoformat(),
+ is_warning=result.is_warning,
+ is_critical=result.is_critical,
+ is_exhausted=result.is_exhausted,
+ )
+ )
+
# Daily request quota
if limits.daily_request_quota:
result = quota_tracker.check_quota(
- user_id,
- QuotaType.REQUESTS,
- limits.daily_request_quota,
- 'day'
+ user_id, QuotaType.REQUESTS, limits.daily_request_quota, 'day'
)
- quotas.append(QuotaStatusResponse(
- quota_type='daily_requests',
- current_usage=result.current_usage,
- limit=result.limit,
- remaining=result.remaining,
- percentage_used=result.percentage_used,
- reset_at=result.reset_at.isoformat(),
- is_warning=result.is_warning,
- is_critical=result.is_critical,
- is_exhausted=result.is_exhausted
- ))
-
+ quotas.append(
+ QuotaStatusResponse(
+ quota_type='daily_requests',
+ current_usage=result.current_usage,
+ limit=result.limit,
+ remaining=result.remaining,
+ percentage_used=result.percentage_used,
+ reset_at=result.reset_at.isoformat(),
+ is_warning=result.is_warning,
+ is_critical=result.is_critical,
+ is_exhausted=result.is_exhausted,
+ )
+ )
+
# Monthly bandwidth quota
if limits.monthly_bandwidth_quota:
result = quota_tracker.check_quota(
- user_id,
- QuotaType.BANDWIDTH,
- limits.monthly_bandwidth_quota,
- 'month'
+ user_id, QuotaType.BANDWIDTH, limits.monthly_bandwidth_quota, 'month'
)
- quotas.append(QuotaStatusResponse(
- quota_type='monthly_bandwidth',
- current_usage=result.current_usage,
- limit=result.limit,
- remaining=result.remaining,
- percentage_used=result.percentage_used,
- reset_at=result.reset_at.isoformat(),
- is_warning=result.is_warning,
- is_critical=result.is_critical,
- is_exhausted=result.is_exhausted
- ))
-
+ quotas.append(
+ QuotaStatusResponse(
+ quota_type='monthly_bandwidth',
+ current_usage=result.current_usage,
+ limit=result.limit,
+ remaining=result.remaining,
+ percentage_used=result.percentage_used,
+ reset_at=result.reset_at.isoformat(),
+ is_warning=result.is_warning,
+ is_critical=result.is_critical,
+ is_exhausted=result.is_exhausted,
+ )
+ )
+
# Build tier info
tier_info = TierInfoResponse(
tier_id=tier.tier_id,
@@ -189,85 +191,79 @@ async def get_quota_status(
display_name=tier.display_name,
limits=limits.to_dict(),
price_monthly=tier.price_monthly,
- features=tier.features
+ features=tier.features,
)
-
+
# Build usage summary
total_usage = sum(q.current_usage for q in quotas if 'requests' in q.quota_type)
total_limit = sum(q.limit for q in quotas if 'requests' in q.quota_type)
-
+
usage_summary = {
'total_requests_used': total_usage,
'total_requests_limit': total_limit,
'has_warnings': any(q.is_warning for q in quotas),
'has_critical': any(q.is_critical for q in quotas),
- 'has_exhausted': any(q.is_exhausted for q in quotas)
+ 'has_exhausted': any(q.is_exhausted for q in quotas),
}
-
+
return QuotaDashboardResponse(
- user_id=user_id,
- tier_info=tier_info,
- quotas=quotas,
- usage_summary=usage_summary
+ user_id=user_id, tier_info=tier_info, quotas=quotas, usage_summary=usage_summary
)
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting quota status: {e}")
+ logger.error(f'Error getting quota status: {e}')
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to get quota status"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get quota status'
)
-@quota_router.get("/status/{quota_type}")
+@quota_router.get('/status/{quota_type}')
async def get_specific_quota_status(
quota_type: str,
user_id: str = Depends(get_current_user_id),
quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
Get status for a specific quota type
-
+
Args:
quota_type: Type of quota (monthly_requests, daily_requests, monthly_bandwidth)
"""
try:
# Get user's limits
limits = await tier_service.get_user_limits(user_id)
-
+
if not limits:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="No limits found for user"
+ status_code=status.HTTP_404_NOT_FOUND, detail='No limits found for user'
)
-
+
# Map quota type to limit and period
quota_mapping = {
'monthly_requests': (QuotaType.REQUESTS, limits.monthly_request_quota, 'month'),
'daily_requests': (QuotaType.REQUESTS, limits.daily_request_quota, 'day'),
- 'monthly_bandwidth': (QuotaType.BANDWIDTH, limits.monthly_bandwidth_quota, 'month')
+ 'monthly_bandwidth': (QuotaType.BANDWIDTH, limits.monthly_bandwidth_quota, 'month'),
}
-
+
if quota_type not in quota_mapping:
raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Invalid quota type: {quota_type}"
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f'Invalid quota type: {quota_type}'
)
-
+
q_type, limit, period = quota_mapping[quota_type]
-
+
if not limit:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
- detail=f"Quota {quota_type} not configured for user"
+ detail=f'Quota {quota_type} not configured for user',
)
-
+
# Check quota
result = quota_tracker.check_quota(user_id, q_type, limit, period)
-
+
return QuotaStatusResponse(
quota_type=quota_type,
current_usage=result.current_usage,
@@ -277,180 +273,159 @@ async def get_specific_quota_status(
reset_at=result.reset_at.isoformat(),
is_warning=result.is_warning,
is_critical=result.is_critical,
- is_exhausted=result.is_exhausted
+ is_exhausted=result.is_exhausted,
)
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting quota status: {e}")
+ logger.error(f'Error getting quota status: {e}')
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to get quota status"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get quota status'
)
-@quota_router.get("/usage/history")
+@quota_router.get('/usage/history')
async def get_usage_history(
user_id: str = Depends(get_current_user_id),
- quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep)
+ quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep),
):
"""
Get historical usage data
-
+
Returns usage history for the past 6 months.
"""
try:
# Get history from quota tracker
- history = quota_tracker.get_quota_history(
- user_id,
- QuotaType.REQUESTS,
- months=6
- )
-
+ history = quota_tracker.get_quota_history(user_id, QuotaType.REQUESTS, months=6)
+
return {
'user_id': user_id,
'history': history,
- 'note': 'Historical tracking not yet fully implemented'
+ 'note': 'Historical tracking not yet fully implemented',
}
-
+
except Exception as e:
- logger.error(f"Error getting usage history: {e}")
+ logger.error(f'Error getting usage history: {e}')
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to get usage history"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get usage history'
)
-@quota_router.post("/usage/export")
+@quota_router.post('/usage/export')
async def export_usage_data(
format: str = 'json',
user_id: str = Depends(get_current_user_id),
quota_tracker: QuotaTracker = Depends(get_quota_tracker_dep),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
Export usage data in JSON or CSV format
-
+
Args:
format: Export format ('json' or 'csv')
"""
try:
# Get current quota status
limits = await tier_service.get_user_limits(user_id)
-
+
if not limits:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="No limits found for user"
+ status_code=status.HTTP_404_NOT_FOUND, detail='No limits found for user'
)
-
+
# Collect all quota data
- export_data = {
- 'user_id': user_id,
- 'export_date': datetime.now().isoformat(),
- 'quotas': []
- }
-
+ export_data = {'user_id': user_id, 'export_date': datetime.now().isoformat(), 'quotas': []}
+
# Add monthly requests
if limits.monthly_request_quota:
result = quota_tracker.check_quota(
- user_id,
- QuotaType.REQUESTS,
- limits.monthly_request_quota,
- 'month'
+ user_id, QuotaType.REQUESTS, limits.monthly_request_quota, 'month'
)
- export_data['quotas'].append({
- 'type': 'monthly_requests',
- 'current_usage': result.current_usage,
- 'limit': result.limit,
- 'remaining': result.remaining,
- 'percentage_used': result.percentage_used,
- 'reset_at': result.reset_at.isoformat()
- })
-
+ export_data['quotas'].append(
+ {
+ 'type': 'monthly_requests',
+ 'current_usage': result.current_usage,
+ 'limit': result.limit,
+ 'remaining': result.remaining,
+ 'percentage_used': result.percentage_used,
+ 'reset_at': result.reset_at.isoformat(),
+ }
+ )
+
# Add daily requests
if limits.daily_request_quota:
result = quota_tracker.check_quota(
- user_id,
- QuotaType.REQUESTS,
- limits.daily_request_quota,
- 'day'
+ user_id, QuotaType.REQUESTS, limits.daily_request_quota, 'day'
)
- export_data['quotas'].append({
- 'type': 'daily_requests',
- 'current_usage': result.current_usage,
- 'limit': result.limit,
- 'remaining': result.remaining,
- 'percentage_used': result.percentage_used,
- 'reset_at': result.reset_at.isoformat()
- })
-
+ export_data['quotas'].append(
+ {
+ 'type': 'daily_requests',
+ 'current_usage': result.current_usage,
+ 'limit': result.limit,
+ 'remaining': result.remaining,
+ 'percentage_used': result.percentage_used,
+ 'reset_at': result.reset_at.isoformat(),
+ }
+ )
+
if format == 'csv':
# Convert to CSV format
csv_lines = ['Type,Current Usage,Limit,Remaining,Percentage Used,Reset At']
for quota in export_data['quotas']:
csv_lines.append(
- f"{quota['type']},{quota['current_usage']},{quota['limit']},"
- f"{quota['remaining']},{quota['percentage_used']:.2f},{quota['reset_at']}"
+ f'{quota["type"]},{quota["current_usage"]},{quota["limit"]},'
+ f'{quota["remaining"]},{quota["percentage_used"]:.2f},{quota["reset_at"]}'
)
-
- return {
- 'format': 'csv',
- 'data': '\n'.join(csv_lines)
- }
+
+ return {'format': 'csv', 'data': '\n'.join(csv_lines)}
else:
# Return JSON
- return {
- 'format': 'json',
- 'data': export_data
- }
-
+ return {'format': 'json', 'data': export_data}
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error exporting usage data: {e}")
+ logger.error(f'Error exporting usage data: {e}')
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to export usage data"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to export usage data'
)
-@quota_router.get("/tier/info")
+@quota_router.get('/tier/info')
async def get_tier_info(
user_id: str = Depends(get_current_user_id),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
Get current tier information for user
-
+
Returns tier details, benefits, and upgrade options.
"""
try:
# Get user's tier
tier = await tier_service.get_user_tier(user_id)
-
+
if not tier:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="No tier assigned to user"
+ status_code=status.HTTP_404_NOT_FOUND, detail='No tier assigned to user'
)
-
+
# Get all available tiers for upgrade options
all_tiers = await tier_service.list_tiers(enabled_only=True)
-
+
# Filter tiers higher than current (for upgrades)
upgrade_options = [
{
'tier_id': t.tier_id,
'display_name': t.display_name,
'price_monthly': t.price_monthly,
- 'features': t.features
+ 'features': t.features,
}
for t in all_tiers
if t.tier_id != tier.tier_id
]
-
+
return {
'current_tier': TierInfoResponse(
tier_id=tier.tier_id,
@@ -458,45 +433,43 @@ async def get_tier_info(
display_name=tier.display_name,
limits=tier.limits.to_dict(),
price_monthly=tier.price_monthly,
- features=tier.features
+ features=tier.features,
),
- 'upgrade_options': upgrade_options
+ 'upgrade_options': upgrade_options,
}
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting tier info: {e}")
+ logger.error(f'Error getting tier info: {e}')
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to get tier info"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get tier info'
)
-@quota_router.get("/burst/status")
+@quota_router.get('/burst/status')
async def get_burst_status(
user_id: str = Depends(get_current_user_id),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
Get burst usage status for current user
-
+
Returns burst token consumption across different time windows.
"""
try:
# Get user's tier and limits
tier = await tier_service.get_user_tier(user_id)
-
+
if not tier:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="No tier assigned to user"
+ status_code=status.HTTP_404_NOT_FOUND, detail='No tier assigned to user'
)
-
+
limits = await tier_service.get_user_limits(user_id)
if not limits:
limits = tier.limits
-
+
# TODO: Get actual burst usage from Redis
# For now, return structure with placeholder data
burst_status = {
@@ -504,23 +477,22 @@ async def get_burst_status(
'burst_limits': {
'per_minute': limits.burst_per_minute,
'per_hour': limits.burst_per_hour,
- 'per_second': limits.burst_per_second
+ 'per_second': limits.burst_per_second,
},
'burst_usage': {
'per_minute': 0, # TODO: Get from Redis
'per_hour': 0,
- 'per_second': 0
+ 'per_second': 0,
},
- 'note': 'Burst tracking requires rate limiter integration'
+ 'note': 'Burst tracking requires rate limiter integration',
}
-
+
return burst_status
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting burst status: {e}")
+ logger.error(f'Error getting burst status: {e}')
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to get burst status"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get burst status'
)
diff --git a/backend-services/routes/rate_limit_rule_routes.py b/backend-services/routes/rate_limit_rule_routes.py
index 914fad6..b61f7a7 100644
--- a/backend-services/routes/rate_limit_rule_routes.py
+++ b/backend-services/routes/rate_limit_rule_routes.py
@@ -5,8 +5,8 @@ FastAPI routes for managing rate limit rules.
"""
import logging
-from typing import List, Optional
-from fastapi import APIRouter, HTTPException, Depends, status, Query
+
+from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow
@@ -22,58 +22,69 @@ rate_limit_rule_router = APIRouter()
# REQUEST/RESPONSE MODELS
# ============================================================================
+
class RuleCreateRequest(BaseModel):
"""Request model for creating a rule"""
- rule_id: str = Field(..., description="Unique rule identifier")
- rule_type: str = Field(..., description="Rule type (per_user, per_api, per_endpoint, per_ip, global)")
- time_window: str = Field(..., description="Time window (second, minute, hour, day, month)")
- limit: int = Field(..., gt=0, description="Maximum requests allowed")
- target_identifier: Optional[str] = Field(None, description="Target (user ID, API name, endpoint, IP)")
- burst_allowance: int = Field(0, ge=0, description="Additional burst requests")
- priority: int = Field(0, description="Rule priority (higher = checked first)")
- enabled: bool = Field(True, description="Whether rule is enabled")
- description: Optional[str] = Field(None, description="Rule description")
+
+ rule_id: str = Field(..., description='Unique rule identifier')
+ rule_type: str = Field(
+ ..., description='Rule type (per_user, per_api, per_endpoint, per_ip, global)'
+ )
+ time_window: str = Field(..., description='Time window (second, minute, hour, day, month)')
+ limit: int = Field(..., gt=0, description='Maximum requests allowed')
+ target_identifier: str | None = Field(
+ None, description='Target (user ID, API name, endpoint, IP)'
+ )
+ burst_allowance: int = Field(0, ge=0, description='Additional burst requests')
+ priority: int = Field(0, description='Rule priority (higher = checked first)')
+ enabled: bool = Field(True, description='Whether rule is enabled')
+ description: str | None = Field(None, description='Rule description')
class RuleUpdateRequest(BaseModel):
"""Request model for updating a rule"""
- limit: Optional[int] = Field(None, gt=0)
- target_identifier: Optional[str] = None
- burst_allowance: Optional[int] = Field(None, ge=0)
- priority: Optional[int] = None
- enabled: Optional[bool] = None
- description: Optional[str] = None
+
+ limit: int | None = Field(None, gt=0)
+ target_identifier: str | None = None
+ burst_allowance: int | None = Field(None, ge=0)
+ priority: int | None = None
+ enabled: bool | None = None
+ description: str | None = None
class BulkRuleRequest(BaseModel):
"""Request model for bulk operations"""
- rule_ids: List[str]
+
+ rule_ids: list[str]
class RuleDuplicateRequest(BaseModel):
"""Request model for duplicating a rule"""
+
new_rule_id: str
class RuleResponse(BaseModel):
"""Response model for rule"""
+
rule_id: str
rule_type: str
time_window: str
limit: int
- target_identifier: Optional[str]
+ target_identifier: str | None
burst_allowance: int
priority: int
enabled: bool
- description: Optional[str]
- created_at: Optional[str]
- updated_at: Optional[str]
+ description: str | None
+ created_at: str | None
+ updated_at: str | None
# ============================================================================
# DEPENDENCY INJECTION
# ============================================================================
+
async def get_rule_service_dep() -> RateLimitRuleService:
"""Dependency to get rule service"""
return get_rate_limit_rule_service(async_database.db)
@@ -83,10 +94,10 @@ async def get_rule_service_dep() -> RateLimitRuleService:
# RULE CRUD ENDPOINTS
# ============================================================================
-@rate_limit_rule_router.post("/", response_model=RuleResponse, status_code=status.HTTP_201_CREATED)
+
+@rate_limit_rule_router.post('/', response_model=RuleResponse, status_code=status.HTTP_201_CREATED)
async def create_rule(
- request: RuleCreateRequest,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ request: RuleCreateRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Create a new rate limit rule"""
try:
@@ -99,281 +110,309 @@ async def create_rule(
burst_allowance=request.burst_allowance,
priority=request.priority,
enabled=request.enabled,
- description=request.description
+ description=request.description,
)
-
+
# Validate rule
errors = rule_service.validate_rule(rule)
if errors:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={"errors": errors})
-
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={'errors': errors})
+
created_rule = await rule_service.create_rule(rule)
-
+
return RuleResponse(**created_rule.to_dict())
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error creating rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create rule")
+ logger.error(f'Error creating rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to create rule'
+ )
-@rate_limit_rule_router.get("/", response_model=List[RuleResponse])
+@rate_limit_rule_router.get('/', response_model=list[RuleResponse])
async def list_rules(
- rule_type: Optional[str] = Query(None, description="Filter by rule type"),
- enabled_only: bool = Query(False, description="Only return enabled rules"),
+ rule_type: str | None = Query(None, description='Filter by rule type'),
+ enabled_only: bool = Query(False, description='Only return enabled rules'),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_service: RateLimitRuleService = Depends(get_rule_service_dep),
):
"""List all rate limit rules"""
try:
rule_type_enum = RuleType(rule_type) if rule_type else None
rules = await rule_service.list_rules(
- rule_type=rule_type_enum,
- enabled_only=enabled_only,
- skip=skip,
- limit=limit
+ rule_type=rule_type_enum, enabled_only=enabled_only, skip=skip, limit=limit
)
-
+
return [RuleResponse(**rule.to_dict()) for rule in rules]
-
+
except Exception as e:
- logger.error(f"Error listing rules: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to list rules")
+ logger.error(f'Error listing rules: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to list rules'
+ )
-@rate_limit_rule_router.get("/search")
+@rate_limit_rule_router.get('/search')
async def search_rules(
- q: str = Query(..., description="Search term"),
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ q: str = Query(..., description='Search term'),
+ rule_service: RateLimitRuleService = Depends(get_rule_service_dep),
):
"""Search rules by ID, description, or target"""
try:
rules = await rule_service.search_rules(q)
return [RuleResponse(**rule.to_dict()) for rule in rules]
-
+
except Exception as e:
- logger.error(f"Error searching rules: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to search rules")
+ logger.error(f'Error searching rules: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to search rules'
+ )
-@rate_limit_rule_router.get("/{rule_id}", response_model=RuleResponse)
+@rate_limit_rule_router.get('/{rule_id}', response_model=RuleResponse)
async def get_rule(
- rule_id: str,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Get a specific rule by ID"""
try:
rule = await rule_service.get_rule(rule_id)
-
+
if not rule:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found'
+ )
+
return RuleResponse(**rule.to_dict())
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get rule")
+ logger.error(f'Error getting rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get rule'
+ )
-@rate_limit_rule_router.put("/{rule_id}", response_model=RuleResponse)
+@rate_limit_rule_router.put('/{rule_id}', response_model=RuleResponse)
async def update_rule(
rule_id: str,
request: RuleUpdateRequest,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_service: RateLimitRuleService = Depends(get_rule_service_dep),
):
"""Update a rate limit rule"""
try:
updates = {k: v for k, v in request.dict().items() if v is not None}
-
+
updated_rule = await rule_service.update_rule(rule_id, updates)
-
+
if not updated_rule:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found'
+ )
+
return RuleResponse(**updated_rule.to_dict())
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error updating rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update rule")
+ logger.error(f'Error updating rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to update rule'
+ )
-@rate_limit_rule_router.delete("/{rule_id}", status_code=status.HTTP_204_NO_CONTENT)
+@rate_limit_rule_router.delete('/{rule_id}', status_code=status.HTTP_204_NO_CONTENT)
async def delete_rule(
- rule_id: str,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Delete a rate limit rule"""
try:
deleted = await rule_service.delete_rule(rule_id)
-
+
if not deleted:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found'
+ )
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error deleting rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete rule")
+ logger.error(f'Error deleting rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to delete rule'
+ )
-@rate_limit_rule_router.post("/{rule_id}/enable", response_model=RuleResponse)
+@rate_limit_rule_router.post('/{rule_id}/enable', response_model=RuleResponse)
async def enable_rule(
- rule_id: str,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Enable a rule"""
try:
rule = await rule_service.enable_rule(rule_id)
-
+
if not rule:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found'
+ )
+
return RuleResponse(**rule.to_dict())
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error enabling rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to enable rule")
+ logger.error(f'Error enabling rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to enable rule'
+ )
-@rate_limit_rule_router.post("/{rule_id}/disable", response_model=RuleResponse)
+@rate_limit_rule_router.post('/{rule_id}/disable', response_model=RuleResponse)
async def disable_rule(
- rule_id: str,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_id: str, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Disable a rule"""
try:
rule = await rule_service.disable_rule(rule_id)
-
+
if not rule:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Rule {rule_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Rule {rule_id} not found'
+ )
+
return RuleResponse(**rule.to_dict())
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error disabling rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to disable rule")
+ logger.error(f'Error disabling rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to disable rule'
+ )
# ============================================================================
# BULK OPERATIONS
# ============================================================================
-@rate_limit_rule_router.post("/bulk/delete")
+
+@rate_limit_rule_router.post('/bulk/delete')
async def bulk_delete_rules(
- request: BulkRuleRequest,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ request: BulkRuleRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Delete multiple rules at once"""
try:
count = await rule_service.bulk_delete_rules(request.rule_ids)
- return {"deleted_count": count}
-
+ return {'deleted_count': count}
+
except Exception as e:
- logger.error(f"Error bulk deleting rules: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete rules")
+ logger.error(f'Error bulk deleting rules: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to delete rules'
+ )
-@rate_limit_rule_router.post("/bulk/enable")
+@rate_limit_rule_router.post('/bulk/enable')
async def bulk_enable_rules(
- request: BulkRuleRequest,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ request: BulkRuleRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Enable multiple rules at once"""
try:
count = await rule_service.bulk_enable_rules(request.rule_ids)
- return {"enabled_count": count}
-
+ return {'enabled_count': count}
+
except Exception as e:
- logger.error(f"Error bulk enabling rules: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to enable rules")
+ logger.error(f'Error bulk enabling rules: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to enable rules'
+ )
-@rate_limit_rule_router.post("/bulk/disable")
+@rate_limit_rule_router.post('/bulk/disable')
async def bulk_disable_rules(
- request: BulkRuleRequest,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ request: BulkRuleRequest, rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
):
"""Disable multiple rules at once"""
try:
count = await rule_service.bulk_disable_rules(request.rule_ids)
- return {"disabled_count": count}
-
+ return {'disabled_count': count}
+
except Exception as e:
- logger.error(f"Error bulk disabling rules: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to disable rules")
+ logger.error(f'Error bulk disabling rules: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to disable rules'
+ )
# ============================================================================
# RULE DUPLICATION
# ============================================================================
-@rate_limit_rule_router.post("/{rule_id}/duplicate", response_model=RuleResponse, status_code=status.HTTP_201_CREATED)
+
+@rate_limit_rule_router.post(
+ '/{rule_id}/duplicate', response_model=RuleResponse, status_code=status.HTTP_201_CREATED
+)
async def duplicate_rule(
rule_id: str,
request: RuleDuplicateRequest,
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
+ rule_service: RateLimitRuleService = Depends(get_rule_service_dep),
):
"""Duplicate an existing rule"""
try:
new_rule = await rule_service.duplicate_rule(rule_id, request.new_rule_id)
return RuleResponse(**new_rule.to_dict())
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error duplicating rule: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to duplicate rule")
+ logger.error(f'Error duplicating rule: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to duplicate rule'
+ )
# ============================================================================
# STATISTICS
# ============================================================================
-@rate_limit_rule_router.get("/statistics/summary")
-async def get_rule_statistics(
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
-):
+
+@rate_limit_rule_router.get('/statistics/summary')
+async def get_rule_statistics(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)):
"""Get statistics about rate limit rules"""
try:
stats = await rule_service.get_rule_statistics()
return stats
-
+
except Exception as e:
- logger.error(f"Error getting rule statistics: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get statistics")
+ logger.error(f'Error getting rule statistics: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get statistics'
+ )
# ============================================================================
# RATE LIMIT STATUS ENDPOINT (User-facing)
# ============================================================================
-@rate_limit_rule_router.get("/status")
-async def get_rate_limit_status(
- rule_service: RateLimitRuleService = Depends(get_rule_service_dep)
-):
+
+@rate_limit_rule_router.get('/status')
+async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)):
"""
Get current rate limit status for the authenticated user
-
+
Returns applicable rate limit rules and current usage.
This is a user-facing endpoint showing their current limits.
"""
try:
# TODO: Get user_id from auth middleware
- user_id = "current_user"
-
+ user_id = 'current_user'
+
# Get applicable rules for user
rules = await rule_service.get_applicable_rules(user_id=user_id)
-
+
# Format response
status_info = {
'user_id': user_id,
@@ -384,19 +423,18 @@ async def get_rate_limit_status(
'time_window': rule.time_window.value,
'limit': rule.limit,
'burst_allowance': rule.burst_allowance,
- 'description': rule.description
+ 'description': rule.description,
}
for rule in rules
],
- 'note': 'Use /platform/quota/status for detailed usage information'
+ 'note': 'Use /platform/quota/status for detailed usage information',
}
-
+
return status_info
-
+
except Exception as e:
- logger.error(f"Error getting rate limit status: {e}")
+ logger.error(f'Error getting rate limit status: {e}')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to get rate limit status"
+ detail='Failed to get rate limit status',
)
-
diff --git a/backend-services/routes/role_routes.py b/backend-services/routes/role_routes.py
index d633101..51ca0b5 100644
--- a/backend-services/routes/role_routes.py
+++ b/backend-services/routes/role_routes.py
@@ -4,21 +4,21 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-from fastapi import APIRouter, Depends, Request
-import uuid
-import time
import logging
+import time
+import uuid
+from fastapi import APIRouter, Request
+
+from models.create_role_model import CreateRoleModel
from models.response_model import ResponseModel
from models.role_model_response import RoleModelResponse
from models.update_role_model import UpdateRoleModel
from services.role_service import RoleService
from utils.auth_util import auth_required
-from models.create_role_model import CreateRoleModel
+from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles
from utils.response_util import respond_rest
-from utils.constants import Headers, Roles, ErrorCodes, Messages, Defaults
-from utils.role_util import platform_role_required_bool, is_admin_role, is_admin_user
+from utils.role_util import is_admin_role, is_admin_user, platform_role_required_bool
role_router = APIRouter()
@@ -33,68 +33,70 @@ Response:
{}
"""
-@role_router.post('',
+
+@role_router.post(
+ '',
description='Add role',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Role created successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Role created successfully'}}},
}
- }
+ },
)
-
async def create_role(api_data: CreateRoleModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ROLES):
logger.error(f'{request_id} | User does not have permission to create roles')
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='ROLE009',
- error_message='You do not have permission to create roles'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='ROLE009',
+ error_message='You do not have permission to create roles',
+ )
+ )
try:
- incoming_is_admin = bool(getattr(api_data, 'platform_admin', False)) or api_data.role_name.strip().lower() in ('admin', 'platform admin')
+ incoming_is_admin = bool(
+ getattr(api_data, 'platform_admin', False)
+ ) or api_data.role_name.strip().lower() in ('admin', 'platform admin')
except Exception:
incoming_is_admin = False
if incoming_is_admin:
if not await is_admin_user(username):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='ROLE013',
- error_message='Only admin may create the admin role'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='ROLE013',
+ error_message='Only admin may create the admin role',
+ )
+ )
return respond_rest(await RoleService.create_role(api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update role
@@ -104,76 +106,80 @@ Response:
{}
"""
-@role_router.put('/{role_name}',
+
+@role_router.put(
+ '/{role_name}',
description='Update role',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Role updated successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Role updated successfully'}}},
}
- }
+ },
)
-
async def update_role(role_name: str, api_data: UpdateRoleModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
# DEBUG: Log the incoming data
logger.info(f'{request_id} | DEBUG: Received model data: {api_data.dict()}')
-
+
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ROLES):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='ROLE010',
- error_message='You do not have permission to update roles'
- ))
- target_is_admin = await is_admin_role(role_name)
- if target_is_admin and not await is_admin_user(username):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='ROLE014',
- error_message='Only admin may modify the admin role'
- ))
- try:
- if getattr(api_data, 'platform_admin', None) is not None and not await is_admin_user(username):
- return respond_rest(ResponseModel(
+ return respond_rest(
+ ResponseModel(
status_code=403,
response_headers={Headers.REQUEST_ID: request_id},
- error_code='ROLE015',
- error_message='Only admin may change admin designation'
- ))
+ error_code='ROLE010',
+ error_message='You do not have permission to update roles',
+ )
+ )
+ target_is_admin = await is_admin_role(role_name)
+ if target_is_admin and not await is_admin_user(username):
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='ROLE014',
+ error_message='Only admin may modify the admin role',
+ )
+ )
+ try:
+ if getattr(api_data, 'platform_admin', None) is not None and not await is_admin_user(
+ username
+ ):
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='ROLE015',
+ error_message='Only admin may change admin designation',
+ )
+ )
except Exception:
pass
return respond_rest(await RoleService.update_role(role_name, api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete role
@@ -183,63 +189,63 @@ Response:
{}
"""
-@role_router.delete('/{role_name}',
+
+@role_router.delete(
+ '/{role_name}',
description='Delete role',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'Role deleted successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'Role deleted successfully'}}},
}
- }
+ },
)
-
async def delete_role(role_name: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_ROLES):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='ROLE011',
- error_message='You do not have permission to delete roles'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='ROLE011',
+ error_message='You do not have permission to delete roles',
+ )
+ )
target_is_admin = await is_admin_role(role_name)
if target_is_admin and not await is_admin_user(username):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={Headers.REQUEST_ID: request_id},
- error_code='ROLE016',
- error_message='Only admin may delete the admin role'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='ROLE016',
+ error_message='Only admin may delete the admin role',
+ )
+ )
return respond_rest(await RoleService.delete_role(role_name, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -249,18 +255,19 @@ Response:
{}
"""
-@role_router.get('/all',
- description='Get all roles',
- response_model=List[RoleModelResponse]
-)
-async def get_roles(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+@role_router.get('/all', description='Get all roles', response_model=list[RoleModelResponse])
+async def get_roles(
+ request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
data = await RoleService.get_roles(page, page_size, request_id)
try:
@@ -283,18 +290,19 @@ async def get_roles(request: Request, page: int = Defaults.PAGE, page_size: int
return respond_rest(data)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -304,25 +312,26 @@ Response:
{}
"""
-@role_router.get('/{role_name}',
- description='Get role',
- response_model=RoleModelResponse
-)
+@role_router.get('/{role_name}', description='Get role', response_model=RoleModelResponse)
async def get_role(role_name: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if await is_admin_role(role_name) and not await is_admin_user(username):
- return respond_rest(ResponseModel(
- status_code=404,
- response_headers={Headers.REQUEST_ID: request_id},
- error_message='Role not found'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=404,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_message='Role not found',
+ )
+ )
data = await RoleService.get_role(role_name, request_id)
try:
if data.get('status_code') == 200 and not await is_admin_user(username):
@@ -337,22 +346,25 @@ async def get_role(role_name: str, request: Request):
return respond_rest(data)
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
-@role_router.get('',
- description='Get all roles (base path)',
- response_model=List[RoleModelResponse]
+
+
+@role_router.get(
+ '', description='Get all roles (base path)', response_model=list[RoleModelResponse]
)
-async def get_roles_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+async def get_roles_base(
+ request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE
+):
"""Convenience alias for GET /platform/role/all to support clients/tests
that expect listing at the base collection path.
"""
diff --git a/backend-services/routes/routing_routes.py b/backend-services/routes/routing_routes.py
index 940b23e..59ea70e 100644
--- a/backend-services/routes/routing_routes.py
+++ b/backend-services/routes/routing_routes.py
@@ -4,11 +4,11 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-from fastapi import APIRouter, Depends, Request
-import uuid
-import time
import logging
+import time
+import uuid
+
+from fastapi import APIRouter, Request
from models.create_routing_model import CreateRoutingModel
from models.response_model import ResponseModel
@@ -16,7 +16,7 @@ from models.routing_model_response import RoutingModelResponse
from models.update_routing_model import UpdateRoutingModel
from services.routing_service import RoutingService
from utils.auth_util import auth_required
-from utils.response_util import respond_rest, process_response
+from utils.response_util import process_response, respond_rest
from utils.role_util import platform_role_required_bool
routing_router = APIRouter()
@@ -32,55 +32,56 @@ Response:
{}
"""
-@routing_router.post('',
+
+@routing_router.post(
+ '',
description='Add routing',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Routing created successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Routing created successfully'}}
+ },
}
- }
+ },
)
-
async def create_routing(api_data: CreateRoutingModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_routings'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='RTG009',
- error_message='You do not have permission to create routings'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='RTG009',
+ error_message='You do not have permission to create routings',
+ )
+ )
return respond_rest(await RoutingService.create_routing(api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update routing
@@ -90,55 +91,56 @@ Response:
{}
"""
-@routing_router.put('/{client_key}',
+
+@routing_router.put(
+ '/{client_key}',
description='Update routing',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Routing updated successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Routing updated successfully'}}
+ },
}
- }
+ },
)
-
async def update_routing(client_key: str, api_data: UpdateRoutingModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_routings'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='RTG010',
- error_message='You do not have permission to update routings'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='RTG010',
+ error_message='You do not have permission to update routings',
+ )
+ )
return respond_rest(await RoutingService.update_routing(client_key, api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete routing
@@ -148,55 +150,56 @@ Response:
{}
"""
-@routing_router.delete('/{client_key}',
+
+@routing_router.delete(
+ '/{client_key}',
description='Delete routing',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Routing deleted successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Routing deleted successfully'}}
+ },
}
- }
+ },
)
-
async def delete_routing(client_key: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_routings'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='RTG011',
- error_message='You do not have permission to delete routings'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='RTG011',
+ error_message='You do not have permission to delete routings',
+ )
+ )
return respond_rest(await RoutingService.delete_routing(client_key, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -206,43 +209,46 @@ Response:
{}
"""
-@routing_router.get('/all',
- description='Get all routings',
- response_model=List[RoutingModelResponse]
-)
+@routing_router.get(
+ '/all', description='Get all routings', response_model=list[RoutingModelResponse]
+)
async def get_routings(request: Request, page: int = 1, page_size: int = 10):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_routings'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='RTG012',
- error_message='You do not have permission to get routings'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='RTG012',
+ error_message='You do not have permission to get routings',
+ )
+ )
return respond_rest(await RoutingService.get_routings(page, page_size, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -252,39 +258,39 @@ Response:
{}
"""
-@routing_router.get('/{client_key}',
- description='Get routing',
- response_model=RoutingModelResponse
-)
+@routing_router.get('/{client_key}', description='Get routing', response_model=RoutingModelResponse)
async def get_routing(client_key: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_routings'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='RTG013',
- error_message='You do not have permission to get routings'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='RTG013',
+ error_message='You do not have permission to get routings',
+ )
+ )
return respond_rest(await RoutingService.get_routing(client_key, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/security_routes.py b/backend-services/routes/security_routes.py
index c8e2c26..7026fee 100644
--- a/backend-services/routes/security_routes.py
+++ b/backend-services/routes/security_routes.py
@@ -2,22 +2,22 @@
Routes for managing security settings.
"""
-from fastapi import APIRouter, Request
-from typing import Optional
-import os
-import sys
-import subprocess
-import uuid
-import time
import logging
+import os
+import subprocess
+import sys
+import time
+import uuid
+
+from fastapi import APIRouter, Request
from models.response_model import ResponseModel
from models.security_settings_model import SecuritySettingsModel
-from utils.response_util import process_response
+from utils.audit_util import audit
from utils.auth_util import auth_required
+from utils.response_util import process_response
from utils.role_util import platform_role_required_bool
from utils.security_settings_util import load_settings, save_settings
-from utils.audit_util import audit
security_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
@@ -31,26 +31,34 @@ Response:
{}
"""
-@security_router.get('/security/settings',
- description='Get security settings',
- response_model=ResponseModel,
-)
+@security_router.get(
+ '/security/settings', description='Get security settings', response_model=ResponseModel
+)
async def get_security_settings(request: Request):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_security'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='SEC001',
- error_message='You do not have permission to view security settings'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SEC001',
+ error_message='You do not have permission to view security settings',
+ ).dict(),
+ 'rest',
+ )
settings = await load_settings()
try:
client_ip = request.client.host if request.client else None
@@ -63,44 +71,57 @@ async def get_security_settings(request: Request):
settings_with_mode = dict(settings)
try:
from utils.database import database
+
settings_with_mode['memory_only'] = bool(database.memory_only)
except Exception:
settings_with_mode['memory_only'] = False
try:
import os
+
env_val = os.getenv('LOCAL_HOST_IP_BYPASS')
locked = isinstance(env_val, str) and env_val.strip() != ''
if locked:
- settings_with_mode['allow_localhost_bypass'] = (env_val.lower() == 'true')
+ settings_with_mode['allow_localhost_bypass'] = env_val.lower() == 'true'
settings_with_mode['allow_localhost_bypass_locked'] = locked
except Exception:
settings_with_mode['allow_localhost_bypass_locked'] = False
try:
warnings = []
- if settings_with_mode.get('trust_x_forwarded_for') and not (settings_with_mode.get('xff_trusted_proxies') or []):
- warnings.append('Trust X-Forwarded-For is enabled, but no trusted proxies are configured. Set xff_trusted_proxies to avoid header spoofing.')
+ if settings_with_mode.get('trust_x_forwarded_for') and not (
+ settings_with_mode.get('xff_trusted_proxies') or []
+ ):
+ warnings.append(
+ 'Trust X-Forwarded-For is enabled, but no trusted proxies are configured. Set xff_trusted_proxies to avoid header spoofing.'
+ )
settings_with_mode['security_warnings'] = warnings
except Exception:
pass
settings_with_mode['client_ip'] = client_ip
settings_with_mode['client_ip_xff'] = client_ip_xff
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=settings_with_mode
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response=settings_with_mode,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -110,57 +131,85 @@ Response:
{}
"""
-@security_router.put('/security/settings',
- description='Update security settings',
- response_model=ResponseModel,
+
+@security_router.put(
+ '/security/settings', description='Update security settings', response_model=ResponseModel
)
-async def update_security_settings(request: Request, body: Optional[SecuritySettingsModel] = None):
- request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
+async def update_security_settings(request: Request, body: SecuritySettingsModel | None = None):
+ request_id = (
+ getattr(request.state, 'request_id', None)
+ or request.headers.get('X-Request-ID')
+ or str(uuid.uuid4())
+ )
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_security'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='SEC002',
- error_message='You do not have permission to update security settings'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SEC002',
+ error_message='You do not have permission to update security settings',
+ ).dict(),
+ 'rest',
+ )
payload_dict = {}
try:
- payload_dict = (body.dict(exclude_none=True) if body is not None else {})
+ payload_dict = body.dict(exclude_none=True) if body is not None else {}
except Exception:
payload_dict = {}
new_settings = await save_settings(payload_dict)
- audit(request, actor=username, action='security.update_settings', target='security_settings', status='success', details={k: new_settings.get(k) for k in ('enable_auto_save','auto_save_frequency_seconds','dump_path')}, request_id=request_id)
+ audit(
+ request,
+ actor=username,
+ action='security.update_settings',
+ target='security_settings',
+ status='success',
+ details={
+ k: new_settings.get(k)
+ for k in ('enable_auto_save', 'auto_save_frequency_seconds', 'dump_path')
+ },
+ request_id=request_id,
+ )
settings_with_mode = dict(new_settings)
try:
from utils.database import database
+
settings_with_mode['memory_only'] = bool(database.memory_only)
except Exception:
settings_with_mode['memory_only'] = False
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- message='Security settings updated',
- response=settings_with_mode
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ message='Security settings updated',
+ response=settings_with_mode,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -170,69 +219,103 @@ Response:
{}
"""
-@security_router.post('/security/restart',
+
+@security_router.post(
+ '/security/restart',
description='Schedule a safe gateway restart (PID-based)',
response_model=ResponseModel,
)
-
async def restart_gateway(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_security'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='SEC003',
- error_message='You do not have permission to restart the gateway'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SEC003',
+ error_message='You do not have permission to restart the gateway',
+ ).dict(),
+ 'rest',
+ )
pid_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'doorman.pid')
pid_file = os.path.abspath(pid_file)
if not os.path.exists(pid_file):
- return process_response(ResponseModel(
- status_code=409,
- response_headers={'request_id': request_id},
- error_code='SEC004',
- error_message="Restart not supported: no PID file found (run using 'doorman start' or contact your admin)"
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=409,
+ response_headers={'request_id': request_id},
+ error_code='SEC004',
+ error_message="Restart not supported: no PID file found (run using 'doorman start' or contact your admin)",
+ ).dict(),
+ 'rest',
+ )
- doorman_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'doorman.py'))
+ doorman_path = os.path.abspath(
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'doorman.py')
+ )
try:
if os.name == 'nt':
- subprocess.Popen([sys.executable, doorman_path, 'restart'],
- creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS,
- stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ subprocess.Popen(
+ [sys.executable, doorman_path, 'restart'],
+ creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ )
else:
- subprocess.Popen([sys.executable, doorman_path, 'restart'],
- stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
- preexec_fn=os.setsid)
+ subprocess.Popen(
+ [sys.executable, doorman_path, 'restart'],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ preexec_fn=os.setsid,
+ )
except Exception as e:
logger.error(f'{request_id} | Failed to spawn restarter: {e}')
- return process_response(ResponseModel(
- status_code=500,
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='SEC005',
+ error_message='Failed to schedule restart',
+ ).dict(),
+ 'rest',
+ )
+ audit(
+ request,
+ actor=username,
+ action='security.restart',
+ target='gateway',
+ status='scheduled',
+ details=None,
+ request_id=request_id,
+ )
+ return process_response(
+ ResponseModel(
+ status_code=202,
response_headers={'request_id': request_id},
- error_code='SEC005',
- error_message='Failed to schedule restart'
- ).dict(), 'rest')
- audit(request, actor=username, action='security.restart', target='gateway', status='scheduled', details=None, request_id=request_id)
- return process_response(ResponseModel(
- status_code=202,
- response_headers={'request_id': request_id},
- message='Restart scheduled'
- ).dict(), 'rest')
+ message='Restart scheduled',
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/subscription_routes.py b/backend-services/routes/subscription_routes.py
index 1f8398e..2d65699 100644
--- a/backend-services/routes/subscription_routes.py
+++ b/backend-services/routes/subscription_routes.py
@@ -4,20 +4,21 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from fastapi import APIRouter, HTTPException, Depends, Request
-import uuid
-import time
import logging
+import time
+import uuid
+
+from fastapi import APIRouter, HTTPException, Request
from models.response_model import ResponseModel
-from services.subscription_service import SubscriptionService
-from utils.auth_util import auth_required
from models.subscribe_model import SubscribeModel
-from utils.group_util import group_required
-from utils.role_util import platform_role_required_bool
-from utils.database import api_collection
-from utils.response_util import respond_rest
+from services.subscription_service import SubscriptionService
from utils.audit_util import audit
+from utils.auth_util import auth_required
+from utils.database import api_collection
+from utils.group_util import group_required
+from utils.response_util import respond_rest
+from utils.role_util import platform_role_required_bool
subscription_router = APIRouter()
@@ -32,72 +33,84 @@ Response:
{}
"""
-@subscription_router.post('/subscribe',
+
+@subscription_router.post(
+ '/subscribe',
description='Subscribe to API',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Subscription created successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Subscription created successfully'}}
+ },
}
- }
+ },
)
-
async def subscribe_api(api_data: SubscribeModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- if not await group_required(request, api_data.api_name + '/' + api_data.api_version, api_data.username):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='SUB007',
- error_message='You do not have the correct group access'
- ))
+ if not await group_required(
+ request, api_data.api_name + '/' + api_data.api_version, api_data.username
+ ):
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SUB007',
+ error_message='You do not have the correct group access',
+ )
+ )
target_user = api_data.username or username
- logger.info(f'{request_id} | Actor: {username} | Action: subscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}')
+ logger.info(
+ f'{request_id} | Actor: {username} | Action: subscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}'
+ )
result = await SubscriptionService.subscribe(api_data, request_id)
actor_user = username
target_user = api_data.username or username
- audit(request, actor=actor_user, action='subscription.subscribe', target=f'{target_user}:{api_data.api_name}/{api_data.api_version}', status=result.get('status_code'), details=None, request_id=request_id)
+ audit(
+ request,
+ actor=actor_user,
+ action='subscription.subscribe',
+ target=f'{target_user}:{api_data.api_name}/{api_data.api_version}',
+ status=result.get('status_code'),
+ details=None,
+ request_id=request_id,
+ )
return respond_rest(result)
except HTTPException as e:
- return respond_rest(ResponseModel(
- status_code=e.status_code,
- response_headers={
- 'request_id': request_id
- },
- error_code='GEN001',
- error_message=e.detail
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code='GEN001',
+ error_message=e.detail,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Unsubscribe from API
@@ -107,72 +120,84 @@ Response:
{}
"""
-@subscription_router.post('/unsubscribe',
+
+@subscription_router.post(
+ '/unsubscribe',
description='Unsubscribe from API',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Subscription deleted successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Subscription deleted successfully'}}
+ },
}
- }
+ },
)
-
async def unsubscribe_api(api_data: SubscribeModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
- if not await group_required(request, api_data.api_name + '/' + api_data.api_version, api_data.username):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- 'request_id': request_id
- },
- error_code='SUB008',
- error_message='You do not have the correct group access'
- ))
+ if not await group_required(
+ request, api_data.api_name + '/' + api_data.api_version, api_data.username
+ ):
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='SUB008',
+ error_message='You do not have the correct group access',
+ )
+ )
target_user = api_data.username or username
- logger.info(f'{request_id} | Actor: {username} | Action: unsubscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}')
+ logger.info(
+ f'{request_id} | Actor: {username} | Action: unsubscribe | Target: {target_user} | API: {api_data.api_name}/{api_data.api_version}'
+ )
result = await SubscriptionService.unsubscribe(api_data, request_id)
actor_user = username
target_user = api_data.username or username
- audit(request, actor=actor_user, action='subscription.unsubscribe', target=f'{target_user}:{api_data.api_name}/{api_data.api_version}', status=result.get('status_code'), details=None, request_id=request_id)
+ audit(
+ request,
+ actor=actor_user,
+ action='subscription.unsubscribe',
+ target=f'{target_user}:{api_data.api_name}/{api_data.api_version}',
+ status=result.get('status_code'),
+ details=None,
+ request_id=request_id,
+ )
return respond_rest(result)
except HTTPException as e:
- return respond_rest(ResponseModel(
- status_code=e.status_code,
- response_headers={
- 'request_id': request_id
- },
- error_code='GEN002',
- error_message=e.detail
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={'request_id': request_id},
+ error_code='GEN002',
+ error_message=e.detail,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Get current user's subscriptions
@@ -182,49 +207,44 @@ Response:
{}
"""
-@subscription_router.get('/subscriptions',
+
+@subscription_router.get(
+ '/subscriptions',
description="Get current user's subscriptions",
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'apis': [
- 'customer/v1',
- 'orders/v1'
- ]
- }
- }
- }
+ 'content': {'application/json': {'example': {'apis': ['customer/v1', 'orders/v1']}}},
}
- }
+ },
)
-
async def subscriptions_for_current_user(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await SubscriptionService.get_user_subscriptions(username, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Get user's subscriptions
@@ -234,49 +254,44 @@ Response:
{}
"""
-@subscription_router.get('/subscriptions/{user_id}',
+
+@subscription_router.get(
+ '/subscriptions/{user_id}',
description="Get user's subscriptions",
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'apis': [
- 'customer/v1',
- 'orders/v1'
- ]
- }
- }
- }
+ 'content': {'application/json': {'example': {'apis': ['customer/v1', 'orders/v1']}}},
}
- }
+ },
)
-
async def subscriptions_for_user_by_id(user_id: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await SubscriptionService.get_user_subscriptions(user_id, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- 'request_id': request_id
- },
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -286,17 +301,21 @@ Response:
{}
"""
-@subscription_router.get('/available-apis/{username}',
- description='List available APIs for subscription based on permissions and groups',
- response_model=ResponseModel)
+@subscription_router.get(
+ '/available-apis/{username}',
+ description='List available APIs for subscription based on permissions and groups',
+ response_model=ResponseModel,
+)
async def available_apis(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
actor = payload.get('sub')
- logger.info(f'{request_id} | Username: {actor} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {actor} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
accesses = (payload or {}).get('accesses') or {}
@@ -307,19 +326,30 @@ async def available_apis(username: str, request: Request):
cursor = api_collection.find().sort('api_name', 1)
apis = cursor.to_list(length=None)
for a in apis:
- if a.get('_id'): del a['_id']
+ if a.get('_id'):
+ del a['_id']
if can_manage:
- data = [{ 'api_name': a.get('api_name'), 'api_version': a.get('api_version'), 'api_description': a.get('api_description') } for a in apis]
+ data = [
+ {
+ 'api_name': a.get('api_name'),
+ 'api_version': a.get('api_version'),
+ 'api_description': a.get('api_description'),
+ }
+ for a in apis
+ ]
return respond_rest(ResponseModel(status_code=200, response={'apis': data}))
if username != actor:
- return respond_rest(ResponseModel(
- status_code=403,
- error_code='SUB009',
- error_message='You do not have permission to view available APIs for this user'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ error_code='SUB009',
+ error_message='You do not have permission to view available APIs for this user',
+ )
+ )
try:
from services.user_service import UserService
+
user = await UserService.get_user_by_username_helper(actor)
user_groups = set(user.get('groups') or [])
except Exception:
@@ -328,16 +358,24 @@ async def available_apis(username: str, request: Request):
for a in apis:
api_groups = set(a.get('api_allowed_groups') or [])
if user_groups.intersection(api_groups):
- allowed.append({ 'api_name': a.get('api_name'), 'api_version': a.get('api_version'), 'api_description': a.get('api_description') })
+ allowed.append(
+ {
+ 'api_name': a.get('api_name'),
+ 'api_version': a.get('api_version'),
+ 'api_description': a.get('api_description'),
+ }
+ )
return respond_rest(ResponseModel(status_code=200, response={'apis': allowed}))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='GTW999',
- error_message='An unexpected error occurred'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='GTW999',
+ error_message='An unexpected error occurred',
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/tier_routes.py b/backend-services/routes/tier_routes.py
index b9cbafe..dcfb843 100644
--- a/backend-services/routes/tier_routes.py
+++ b/backend-services/routes/tier_routes.py
@@ -5,20 +5,20 @@ FastAPI routes for managing tiers, plans, and user assignments.
"""
import logging
-from typing import List, Optional
-from datetime import datetime
-from fastapi import APIRouter, HTTPException, Depends, status, Query, Request
-from pydantic import BaseModel, Field
-import uuid
import time
+import uuid
+from datetime import datetime
-from models.rate_limit_models import Tier, TierLimits, TierName, UserTierAssignment
+from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
+from pydantic import BaseModel, Field
+
+from models.rate_limit_models import Tier, TierLimits, TierName
from models.response_model import ResponseModel
from services.tier_service import TierService, get_tier_service
-from utils.database_async import async_database
from utils.auth_util import auth_required
-from utils.role_util import platform_role_required_bool
+from utils.database_async import async_database
from utils.response_util import respond_rest
+from utils.role_util import platform_role_required_bool
logger = logging.getLogger(__name__)
@@ -29,69 +29,76 @@ tier_router = APIRouter()
# REQUEST/RESPONSE MODELS
# ============================================================================
+
class TierLimitsRequest(BaseModel):
"""Request model for tier limits"""
- requests_per_second: Optional[int] = None
- requests_per_minute: Optional[int] = None
- requests_per_hour: Optional[int] = None
- requests_per_day: Optional[int] = None
- requests_per_month: Optional[int] = None
+
+ requests_per_second: int | None = None
+ requests_per_minute: int | None = None
+ requests_per_hour: int | None = None
+ requests_per_day: int | None = None
+ requests_per_month: int | None = None
burst_per_second: int = 0
burst_per_minute: int = 0
burst_per_hour: int = 0
- monthly_request_quota: Optional[int] = None
- daily_request_quota: Optional[int] = None
- monthly_bandwidth_quota: Optional[int] = None
+ monthly_request_quota: int | None = None
+ daily_request_quota: int | None = None
+ monthly_bandwidth_quota: int | None = None
enable_throttling: bool = False
max_queue_time_ms: int = 5000
class TierCreateRequest(BaseModel):
"""Request model for creating a tier"""
- tier_id: str = Field(..., description="Unique tier identifier")
- name: str = Field(..., description="Tier name (free, pro, enterprise, custom)")
- display_name: str = Field(..., description="Display name for tier")
- description: Optional[str] = None
+
+ tier_id: str = Field(..., description='Unique tier identifier')
+ name: str = Field(..., description='Tier name (free, pro, enterprise, custom)')
+ display_name: str = Field(..., description='Display name for tier')
+ description: str | None = None
limits: TierLimitsRequest
- price_monthly: Optional[float] = None
- price_yearly: Optional[float] = None
- features: List[str] = []
+ price_monthly: float | None = None
+ price_yearly: float | None = None
+ features: list[str] = []
is_default: bool = False
enabled: bool = True
class TierUpdateRequest(BaseModel):
"""Request model for updating a tier"""
- display_name: Optional[str] = None
- description: Optional[str] = None
- limits: Optional[TierLimitsRequest] = None
- price_monthly: Optional[float] = None
- price_yearly: Optional[float] = None
- features: Optional[List[str]] = None
- is_default: Optional[bool] = None
- enabled: Optional[bool] = None
+
+ display_name: str | None = None
+ description: str | None = None
+ limits: TierLimitsRequest | None = None
+ price_monthly: float | None = None
+ price_yearly: float | None = None
+ features: list[str] | None = None
+ is_default: bool | None = None
+ enabled: bool | None = None
class UserAssignmentRequest(BaseModel):
"""Request model for assigning user to tier"""
+
user_id: str
tier_id: str
- effective_from: Optional[datetime] = None
- effective_until: Optional[datetime] = None
- override_limits: Optional[TierLimitsRequest] = None
- notes: Optional[str] = None
+ effective_from: datetime | None = None
+ effective_until: datetime | None = None
+ override_limits: TierLimitsRequest | None = None
+ notes: str | None = None
class TierUpgradeRequest(BaseModel):
"""Request model for tier upgrade"""
+
user_id: str
new_tier_id: str
immediate: bool = True
- scheduled_date: Optional[datetime] = None
+ scheduled_date: datetime | None = None
class TierDowngradeRequest(BaseModel):
"""Request model for tier downgrade"""
+
user_id: str
new_tier_id: str
grace_period_days: int = 0
@@ -99,6 +106,7 @@ class TierDowngradeRequest(BaseModel):
class TemporaryUpgradeRequest(BaseModel):
"""Request model for temporary tier upgrade"""
+
user_id: str
temp_tier_id: str
duration_days: int
@@ -106,24 +114,26 @@ class TemporaryUpgradeRequest(BaseModel):
class TierResponse(BaseModel):
"""Response model for tier"""
+
tier_id: str
name: str
display_name: str
- description: Optional[str]
+ description: str | None
limits: dict
- price_monthly: Optional[float]
- price_yearly: Optional[float]
- features: List[str]
+ price_monthly: float | None
+ price_yearly: float | None
+ features: list[str]
is_default: bool
enabled: bool
- created_at: Optional[str]
- updated_at: Optional[str]
+ created_at: str | None
+ updated_at: str | None
# ============================================================================
# DEPENDENCY INJECTION
# ============================================================================
+
async def get_tier_service_dep() -> TierService:
"""Dependency to get tier service"""
return get_tier_service(async_database.db)
@@ -133,14 +143,14 @@ async def get_tier_service_dep() -> TierService:
# TIER CRUD ENDPOINTS
# ============================================================================
-@tier_router.post("/", response_model=TierResponse, status_code=status.HTTP_201_CREATED)
+
+@tier_router.post('/', response_model=TierResponse, status_code=status.HTTP_201_CREATED)
async def create_tier(
- request: TierCreateRequest,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: TierCreateRequest, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Create a new tier
-
+
Requires admin permissions.
"""
try:
@@ -155,128 +165,134 @@ async def create_tier(
price_yearly=request.price_yearly,
features=request.features,
is_default=request.is_default,
- enabled=request.enabled
+ enabled=request.enabled,
)
-
+
created_tier = await tier_service.create_tier(tier)
-
+
return TierResponse(
**created_tier.to_dict(),
created_at=created_tier.created_at.isoformat() if created_tier.created_at else None,
- updated_at=created_tier.updated_at.isoformat() if created_tier.updated_at else None
+ updated_at=created_tier.updated_at.isoformat() if created_tier.updated_at else None,
)
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error creating tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create tier")
+ logger.error(f'Error creating tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to create tier'
+ )
-@tier_router.get("/")
+@tier_router.get('/')
async def list_tiers(
request: Request,
- enabled_only: bool = Query(False, description="Only return enabled tiers"),
- search: Optional[str] = Query(None, description="Search tiers by name or description"),
+ enabled_only: bool = Query(False, description='Only return enabled tiers'),
+ search: str | None = Query(None, description='Search tiers by name or description'),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
List all tiers
-
+
Can filter by enabled status and paginate results.
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
- logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
- if not await platform_role_required_bool(username, 'manage_tiers'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='TIER001',
- error_message='You do not have permission to manage tiers'
- ))
-
- tiers = await tier_service.list_tiers(
- enabled_only=enabled_only,
- search_term=search,
- skip=skip,
- limit=limit
+
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
)
-
+ logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
+
+ if not await platform_role_required_bool(username, 'manage_tiers'):
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='TIER001',
+ error_message='You do not have permission to manage tiers',
+ )
+ )
+
+ tiers = await tier_service.list_tiers(
+ enabled_only=enabled_only, search_term=search, skip=skip, limit=limit
+ )
+
tier_list = [
TierResponse(
**tier.to_dict(),
created_at=tier.created_at.isoformat() if tier.created_at else None,
- updated_at=tier.updated_at.isoformat() if tier.updated_at else None
+ updated_at=tier.updated_at.isoformat() if tier.updated_at else None,
).dict()
for tier in tiers
]
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=tier_list
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=tier_list
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error listing tiers: {e}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='TIER999',
- error_message='Failed to list tiers'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='TIER999',
+ error_message='Failed to list tiers',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
-@tier_router.get("/{tier_id}", response_model=TierResponse)
-async def get_tier(
- tier_id: str,
- tier_service: TierService = Depends(get_tier_service_dep)
-):
+@tier_router.get('/{tier_id}', response_model=TierResponse)
+async def get_tier(tier_id: str, tier_service: TierService = Depends(get_tier_service_dep)):
"""
Get a specific tier by ID
"""
try:
tier = await tier_service.get_tier(tier_id)
-
+
if not tier:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Tier {tier_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Tier {tier_id} not found'
+ )
+
return TierResponse(
**tier.to_dict(),
created_at=tier.created_at.isoformat() if tier.created_at else None,
- updated_at=tier.updated_at.isoformat() if tier.updated_at else None
+ updated_at=tier.updated_at.isoformat() if tier.updated_at else None,
)
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get tier")
+ logger.error(f'Error getting tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get tier'
+ )
-@tier_router.put("/{tier_id}", response_model=TierResponse)
+@tier_router.put('/{tier_id}', response_model=TierResponse)
async def update_tier(
tier_id: str,
request: TierUpdateRequest,
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
Update a tier
-
+
Requires admin permissions.
"""
try:
@@ -298,63 +314,68 @@ async def update_tier(
updates['is_default'] = request.is_default
if request.enabled is not None:
updates['enabled'] = request.enabled
-
+
updated_tier = await tier_service.update_tier(tier_id, updates)
-
+
if not updated_tier:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Tier {tier_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Tier {tier_id} not found'
+ )
+
return TierResponse(
**updated_tier.to_dict(),
created_at=updated_tier.created_at.isoformat() if updated_tier.created_at else None,
- updated_at=updated_tier.updated_at.isoformat() if updated_tier.updated_at else None
+ updated_at=updated_tier.updated_at.isoformat() if updated_tier.updated_at else None,
)
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error updating tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update tier")
+ logger.error(f'Error updating tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to update tier'
+ )
-@tier_router.delete("/{tier_id}", status_code=status.HTTP_204_NO_CONTENT)
-async def delete_tier(
- tier_id: str,
- tier_service: TierService = Depends(get_tier_service_dep)
-):
+@tier_router.delete('/{tier_id}', status_code=status.HTTP_204_NO_CONTENT)
+async def delete_tier(tier_id: str, tier_service: TierService = Depends(get_tier_service_dep)):
"""
Delete a tier
-
+
Requires admin permissions.
Cannot delete tier if users are assigned to it.
"""
try:
deleted = await tier_service.delete_tier(tier_id)
-
+
if not deleted:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Tier {tier_id} not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'Tier {tier_id} not found'
+ )
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error deleting tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete tier")
+ logger.error(f'Error deleting tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to delete tier'
+ )
# ============================================================================
# USER ASSIGNMENT ENDPOINTS
# ============================================================================
-@tier_router.post("/assignments", status_code=status.HTTP_201_CREATED)
+
+@tier_router.post('/assignments', status_code=status.HTTP_201_CREATED)
async def assign_user_to_tier(
- request: UserAssignmentRequest,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: UserAssignmentRequest, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Assign a user to a tier
-
+
Requires admin permissions.
"""
try:
@@ -362,134 +383,147 @@ async def assign_user_to_tier(
override_limits = None
if request.override_limits:
override_limits = TierLimits(**request.override_limits.dict())
-
+
assignment = await tier_service.assign_user_to_tier(
user_id=request.user_id,
tier_id=request.tier_id,
effective_from=request.effective_from,
effective_until=request.effective_until,
override_limits=override_limits,
- notes=request.notes
+ notes=request.notes,
)
-
+
return assignment.to_dict()
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error assigning user to tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to assign user")
+ logger.error(f'Error assigning user to tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to assign user'
+ )
-@tier_router.get("/assignments/{user_id}")
+@tier_router.get('/assignments/{user_id}')
async def get_user_assignment(
- user_id: str,
- tier_service: TierService = Depends(get_tier_service_dep)
+ user_id: str, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Get a user's tier assignment
"""
try:
assignment = await tier_service.get_user_assignment(user_id)
-
+
if not assignment:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No assignment found for user {user_id}")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f'No assignment found for user {user_id}',
+ )
+
return assignment.to_dict()
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting user assignment: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get assignment")
+ logger.error(f'Error getting user assignment: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get assignment'
+ )
-@tier_router.get("/assignments/{user_id}/tier", response_model=TierResponse)
-async def get_user_tier(
- user_id: str,
- tier_service: TierService = Depends(get_tier_service_dep)
-):
+@tier_router.get('/assignments/{user_id}/tier', response_model=TierResponse)
+async def get_user_tier(user_id: str, tier_service: TierService = Depends(get_tier_service_dep)):
"""
Get the effective tier for a user
-
+
Considers effective dates and returns current tier.
"""
try:
tier = await tier_service.get_user_tier(user_id)
-
+
if not tier:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No tier found for user {user_id}")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f'No tier found for user {user_id}'
+ )
+
return TierResponse(
**tier.to_dict(),
created_at=tier.created_at.isoformat() if tier.created_at else None,
- updated_at=tier.updated_at.isoformat() if tier.updated_at else None
+ updated_at=tier.updated_at.isoformat() if tier.updated_at else None,
)
-
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error getting user tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get user tier")
+ logger.error(f'Error getting user tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to get user tier'
+ )
-@tier_router.delete("/assignments/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
+@tier_router.delete('/assignments/{user_id}', status_code=status.HTTP_204_NO_CONTENT)
async def remove_user_assignment(
- user_id: str,
- tier_service: TierService = Depends(get_tier_service_dep)
+ user_id: str, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Remove a user's tier assignment
-
+
Requires admin permissions.
"""
try:
removed = await tier_service.remove_user_assignment(user_id)
-
+
if not removed:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No assignment found for user {user_id}")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f'No assignment found for user {user_id}',
+ )
+
except HTTPException:
raise
except Exception as e:
- logger.error(f"Error removing user assignment: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to remove assignment")
+ logger.error(f'Error removing user assignment: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to remove assignment'
+ )
-@tier_router.get("/{tier_id}/users")
+@tier_router.get('/{tier_id}/users')
async def list_users_in_tier(
tier_id: str,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_service: TierService = Depends(get_tier_service_dep),
):
"""
List all users assigned to a tier
-
+
Requires admin permissions.
"""
try:
assignments = await tier_service.list_users_in_tier(tier_id, skip=skip, limit=limit)
-
+
return [assignment.to_dict() for assignment in assignments]
-
+
except Exception as e:
- logger.error(f"Error listing users in tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to list users")
+ logger.error(f'Error listing users in tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to list users'
+ )
# ============================================================================
# TIER UPGRADE/DOWNGRADE ENDPOINTS
# ============================================================================
-@tier_router.post("/upgrade")
+
+@tier_router.post('/upgrade')
async def upgrade_user_tier(
- request: TierUpgradeRequest,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: TierUpgradeRequest, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Upgrade a user to a higher tier
-
+
Requires admin permissions.
"""
try:
@@ -497,78 +531,83 @@ async def upgrade_user_tier(
user_id=request.user_id,
new_tier_id=request.new_tier_id,
immediate=request.immediate,
- scheduled_date=request.scheduled_date
+ scheduled_date=request.scheduled_date,
)
-
+
return assignment.to_dict()
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error upgrading user tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to upgrade tier")
+ logger.error(f'Error upgrading user tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to upgrade tier'
+ )
-@tier_router.post("/downgrade")
+@tier_router.post('/downgrade')
async def downgrade_user_tier(
- request: TierDowngradeRequest,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: TierDowngradeRequest, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Downgrade a user to a lower tier
-
+
Requires admin permissions.
"""
try:
assignment = await tier_service.downgrade_user_tier(
user_id=request.user_id,
new_tier_id=request.new_tier_id,
- grace_period_days=request.grace_period_days
+ grace_period_days=request.grace_period_days,
)
-
+
return assignment.to_dict()
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error downgrading user tier: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to downgrade tier")
+ logger.error(f'Error downgrading user tier: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to downgrade tier'
+ )
-@tier_router.post("/temporary-upgrade")
+@tier_router.post('/temporary-upgrade')
async def temporary_tier_upgrade(
- request: TemporaryUpgradeRequest,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: TemporaryUpgradeRequest, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Temporarily upgrade a user to a higher tier
-
+
Requires admin permissions.
"""
try:
assignment = await tier_service.temporary_tier_upgrade(
user_id=request.user_id,
temp_tier_id=request.temp_tier_id,
- duration_days=request.duration_days
+ duration_days=request.duration_days,
)
-
+
return assignment.to_dict()
-
+
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
- logger.error(f"Error creating temporary upgrade: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create temporary upgrade")
+ logger.error(f'Error creating temporary upgrade: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail='Failed to create temporary upgrade',
+ )
# ============================================================================
# TIER COMPARISON & ANALYTICS ENDPOINTS
# ============================================================================
-@tier_router.post("/compare")
+
+@tier_router.post('/compare')
async def compare_tiers(
- tier_ids: List[str],
- tier_service: TierService = Depends(get_tier_service_dep)
+ tier_ids: list[str], tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Compare multiple tiers side-by-side
@@ -576,108 +615,119 @@ async def compare_tiers(
try:
comparison = await tier_service.compare_tiers(tier_ids)
return comparison
-
+
except Exception as e:
- logger.error(f"Error comparing tiers: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to compare tiers")
+ logger.error(f'Error comparing tiers: {e}')
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to compare tiers'
+ )
-@tier_router.get("/statistics/all")
+@tier_router.get('/statistics/all')
async def get_all_tier_statistics(
- request: Request,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: Request, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Get statistics for all tiers
-
+
Requires admin permissions.
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
if not await platform_role_required_bool(username, 'manage_tiers'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='TIER001',
- error_message='You do not have permission to manage tiers'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='TIER001',
+ error_message='You do not have permission to manage tiers',
+ )
+ )
+
stats = await tier_service.get_all_tier_statistics()
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=stats
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=stats
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error getting all tier statistics: {e}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='TIER999',
- error_message='Failed to get statistics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='TIER999',
+ error_message='Failed to get statistics',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
-@tier_router.get("/{tier_id}/statistics")
+@tier_router.get('/{tier_id}/statistics')
async def get_tier_statistics(
- request: Request,
- tier_id: str,
- tier_service: TierService = Depends(get_tier_service_dep)
+ request: Request, tier_id: str, tier_service: TierService = Depends(get_tier_service_dep)
):
"""
Get statistics for a tier
-
+
Requires admin permissions.
"""
request_id = str(uuid.uuid4())
start_time = time.time()
-
+
try:
payload = await auth_required(request)
username = payload.get('sub')
-
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
if not await platform_role_required_bool(username, 'manage_tiers'):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='TIER001',
- error_message='You do not have permission to manage tiers'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='TIER001',
+ error_message='You do not have permission to manage tiers',
+ )
+ )
+
stats = await tier_service.get_tier_statistics(tier_id)
-
- return respond_rest(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=stats
- ))
-
+
+ return respond_rest(
+ ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response=stats
+ )
+ )
+
except Exception as e:
logger.error(f'{request_id} | Error getting tier statistics: {e}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='TIER999',
- error_message='Failed to get statistics'
- ))
-
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='TIER999',
+ error_message='Failed to get statistics',
+ )
+ )
+
finally:
end_time = time.time()
logger.info(f'{request_id} | Total time: {(end_time - start_time) * 1000:.2f}ms')
diff --git a/backend-services/routes/tools_routes.py b/backend-services/routes/tools_routes.py
index 5979308..67e5652 100644
--- a/backend-services/routes/tools_routes.py
+++ b/backend-services/routes/tools_routes.py
@@ -2,30 +2,38 @@
Tools and diagnostics routes (e.g., CORS checker).
"""
+import logging
+import os
+import time
+import uuid
+from typing import Any
+
from fastapi import APIRouter, Request
from pydantic import BaseModel, Field
-from typing import List, Optional, Dict, Any
-import os
-import uuid
-import time
-import logging
from models.response_model import ResponseModel
-from utils.response_util import process_response
-from utils.auth_util import auth_required
-from utils.role_util import platform_role_required_bool
from utils import chaos_util
+from utils.auth_util import auth_required
+from utils.response_util import process_response
+from utils.role_util import platform_role_required_bool
tools_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
+
class CorsCheckRequest(BaseModel):
origin: str = Field(..., description='Origin to evaluate, e.g., https://localhost:3000')
method: str = Field(..., description='Intended request method, e.g., GET/POST/PUT')
- request_headers: Optional[List[str]] = Field(default=None, description='Requested headers from Access-Control-Request-Headers')
- with_credentials: Optional[bool] = Field(default=None, description='Whether credentials will be sent; defaults to ALLOW_CREDENTIALS env if omitted')
+ request_headers: list[str] | None = Field(
+ default=None, description='Requested headers from Access-Control-Request-Headers'
+ )
+ with_credentials: bool | None = Field(
+ default=None,
+ description='Whether credentials will be sent; defaults to ALLOW_CREDENTIALS env if omitted',
+ )
-def _compute_cors_config() -> Dict[str, Any]:
+
+def _compute_cors_config() -> dict[str, Any]:
origins_env = os.getenv('ALLOWED_ORIGINS', 'http://localhost:3000')
if not (origins_env or '').strip():
@@ -67,6 +75,7 @@ def _compute_cors_config() -> Dict[str, Any]:
'cors_strict': cors_strict,
}
+
"""
Endpoint
@@ -76,67 +85,93 @@ Response:
{}
"""
-@tools_router.post('/cors/check', description='Simulate CORS preflight/actual decisions against current gateway config', response_model=ResponseModel)
+@tools_router.post(
+ '/cors/check',
+ description='Simulate CORS preflight/actual decisions against current gateway config',
+ response_model=ResponseModel,
+)
async def cors_check(request: Request, body: CorsCheckRequest):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, 'manage_security'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='TLS001',
- error_message='You do not have permission to use tools'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='TLS001',
+ error_message='You do not have permission to use tools',
+ ).dict(),
+ 'rest',
+ )
cfg = _compute_cors_config()
origin = (body.origin or '').strip()
method = (body.method or '').strip().upper()
requested_headers = [h.strip() for h in (body.request_headers or []) if h.strip()]
- with_credentials = cfg['credentials'] if body.with_credentials is None else bool(body.with_credentials)
+ with_credentials = (
+ cfg['credentials'] if body.with_credentials is None else bool(body.with_credentials)
+ )
- origin_allowed = origin in cfg['safe_origins'] or (not cfg['cors_strict'] and '*' in cfg['origins'])
+ origin_allowed = origin in cfg['safe_origins'] or (
+ not cfg['cors_strict'] and '*' in cfg['origins']
+ )
method_allowed = method in cfg['methods']
allowed_headers_lower = {h.lower() for h in cfg['headers']}
requested_lower = [h.lower() for h in requested_headers]
- headers_not_allowed = [h for h in requested_headers if h.lower() not in allowed_headers_lower]
+ headers_not_allowed = [
+ h for h in requested_headers if h.lower() not in allowed_headers_lower
+ ]
headers_allowed = len(headers_not_allowed) == 0
preflight_allowed = origin_allowed and method_allowed and headers_allowed
preflight_headers = {
-
'Access-Control-Allow-Origin': origin if origin_allowed else None,
'Access-Control-Allow-Methods': ', '.join(cfg['methods']),
- 'Access-Control-Allow-Headers': ', '.join(cfg['headers']) if requested_headers else ', '.join(cfg['headers']),
- 'Access-Control-Allow-Credentials': 'true' if with_credentials and cfg['credentials'] else 'false',
+ 'Access-Control-Allow-Headers': ', '.join(cfg['headers'])
+ if requested_headers
+ else ', '.join(cfg['headers']),
+ 'Access-Control-Allow-Credentials': 'true'
+ if with_credentials and cfg['credentials']
+ else 'false',
'Vary': 'Origin',
}
actual_allowed = origin_allowed
actual_headers = {
'Access-Control-Allow-Origin': origin if origin_allowed else None,
- 'Access-Control-Allow-Credentials': 'true' if with_credentials and cfg['credentials'] else 'false',
+ 'Access-Control-Allow-Credentials': 'true'
+ if with_credentials and cfg['credentials']
+ else 'false',
'Vary': 'Origin',
}
- notes: List[str] = []
+ notes: list[str] = []
if cfg['credentials'] and ('*' in cfg['origins']) and not cfg['cors_strict']:
- notes.append('Wildcard origins with credentials can be rejected by browsers; prefer explicit origins or set CORS_STRICT=true.')
+ notes.append(
+ 'Wildcard origins with credentials can be rejected by browsers; prefer explicit origins or set CORS_STRICT=true.'
+ )
if any(h == '*' for h in os.getenv('ALLOW_HEADERS', '*').split(',')):
- notes.append("ALLOW_HEADERS='*' replaced with a conservative default set to satisfy credentialed requests.")
+ notes.append(
+ "ALLOW_HEADERS='*' replaced with a conservative default set to satisfy credentialed requests."
+ )
if not origin_allowed:
notes.append('Origin is not allowed based on current configuration.')
if not method_allowed:
notes.append('Requested method is not in ALLOW_METHODS.')
if not headers_allowed and headers_not_allowed:
- notes.append(f"Some requested headers are not allowed: {', '.join(headers_not_allowed)}")
+ notes.append(
+ f'Some requested headers are not allowed: {", ".join(headers_not_allowed)}'
+ )
response_payload = {
'config': {
@@ -162,37 +197,48 @@ async def cors_check(request: Request, body: CorsCheckRequest):
'not_allowed_headers': headers_not_allowed,
'response_headers': preflight_headers,
},
- 'actual': {
- 'allowed': actual_allowed,
- 'response_headers': actual_headers,
- },
+ 'actual': {'allowed': actual_allowed, 'response_headers': actual_headers},
'notes': notes,
}
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=response_payload
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response=response_payload,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='TLS999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='TLS999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
class ChaosToggleRequest(BaseModel):
backend: str = Field(..., description='Backend to toggle (redis|mongo)')
enabled: bool = Field(..., description='Enable or disable outage simulation')
- duration_ms: Optional[int] = Field(default=None, description='Optional duration for outage before auto-disable')
+ duration_ms: int | None = Field(
+ default=None, description='Optional duration for outage before auto-disable'
+ )
-@tools_router.post('/chaos/toggle', description='Toggle simulated backend outages (redis|mongo)', response_model=ResponseModel)
+
+@tools_router.post(
+ '/chaos/toggle',
+ description='Toggle simulated backend outages (redis|mongo)',
+ response_model=ResponseModel,
+)
async def chaos_toggle(request: Request, body: ChaosToggleRequest):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -200,42 +246,57 @@ async def chaos_toggle(request: Request, body: ChaosToggleRequest):
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(
- status_code=403,
- response_headers={'request_id': request_id},
- error_code='TLS001',
- error_message='You do not have permission to use tools'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='TLS001',
+ error_message='You do not have permission to use tools',
+ ).dict(),
+ 'rest',
+ )
backend = (body.backend or '').strip().lower()
if backend not in ('redis', 'mongo'):
- return process_response(ResponseModel(
- status_code=400,
- response_headers={'request_id': request_id},
- error_code='TLS002',
- error_message='backend must be redis or mongo'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=400,
+ response_headers={'request_id': request_id},
+ error_code='TLS002',
+ error_message='backend must be redis or mongo',
+ ).dict(),
+ 'rest',
+ )
if body.duration_ms and int(body.duration_ms) > 0:
chaos_util.enable_for(backend, int(body.duration_ms))
else:
chaos_util.enable(backend, bool(body.enabled))
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response={'backend': backend, 'enabled': chaos_util.should_fail(backend)}
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=200,
+ response_headers={'request_id': request_id},
+ response={'backend': backend, 'enabled': chaos_util.should_fail(backend)},
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='TLS999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='TLS999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
-@tools_router.get('/chaos/stats', description='Get chaos simulation stats', response_model=ResponseModel)
+
+@tools_router.get(
+ '/chaos/stats', description='Get chaos simulation stats', response_model=ResponseModel
+)
async def chaos_stats(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
@@ -243,25 +304,34 @@ async def chaos_stats(request: Request):
payload = await auth_required(request)
username = payload.get('sub')
if not await platform_role_required_bool(username, 'manage_gateway'):
- return process_response(ResponseModel(
- status_code=403,
+ return process_response(
+ ResponseModel(
+ status_code=403,
+ response_headers={'request_id': request_id},
+ error_code='TLS001',
+ error_message='You do not have permission to use tools',
+ ).dict(),
+ 'rest',
+ )
+ return process_response(
+ ResponseModel(
+ status_code=200,
response_headers={'request_id': request_id},
- error_code='TLS001',
- error_message='You do not have permission to use tools'
- ).dict(), 'rest')
- return process_response(ResponseModel(
- status_code=200,
- response_headers={'request_id': request_id},
- response=chaos_util.stats()
- ).dict(), 'rest')
+ response=chaos_util.stats(),
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={'request_id': request_id},
- error_code='TLS999',
- error_message='An unexpected error occurred'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={'request_id': request_id},
+ error_code='TLS999',
+ error_message='An unexpected error occurred',
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
diff --git a/backend-services/routes/user_routes.py b/backend-services/routes/user_routes.py
index 8dd41cb..3d1e100 100644
--- a/backend-services/routes/user_routes.py
+++ b/backend-services/routes/user_routes.py
@@ -4,41 +4,43 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-import os
-from fastapi import APIRouter, Request, HTTPException
-import uuid
-import time
import logging
+import os
+import time
+import uuid
+from fastapi import APIRouter, HTTPException, Request
+
+from models.create_user_model import CreateUserModel
from models.response_model import ResponseModel
+from models.update_password_model import UpdatePasswordModel
+from models.update_user_model import UpdateUserModel
from models.user_model_response import UserModelResponse
from services.user_service import UserService
from utils.auth_util import auth_required
-from utils.response_util import respond_rest, process_response
-from utils.role_util import platform_role_required_bool, is_admin_user, is_admin_role
-from utils.constants import ErrorCodes, Messages, Defaults, Roles, Headers
-from utils.database import role_collection
-from models.create_user_model import CreateUserModel
-from models.update_user_model import UpdateUserModel
-from models.update_password_model import UpdatePasswordModel
+from utils.constants import Defaults, ErrorCodes, Headers, Messages, Roles
+from utils.response_util import process_response, respond_rest
+from utils.role_util import is_admin_role, is_admin_user, platform_role_required_bool
user_router = APIRouter()
logger = logging.getLogger('doorman.gateway')
+
async def _safe_is_admin_user(username: str) -> bool:
try:
return await is_admin_user(username)
except Exception:
return False
+
async def _safe_is_admin_role(role: str) -> bool:
try:
return await is_admin_role(role)
except Exception:
return False
+
"""
Add user
@@ -48,61 +50,62 @@ Response:
{}
"""
-@user_router.post('',
+
+@user_router.post(
+ '',
description='Add user',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'User created successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'User created successfully'}}},
}
- }
+ },
)
-
async def create_user(user_data: CreateUserModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
if not await platform_role_required_bool(username, Roles.MANAGE_USERS):
return respond_rest(
ResponseModel(
status_code=403,
error_code='USR006',
- error_message='Can only update your own information'
- ))
+ error_message='Can only update your own information',
+ )
+ )
if user_data.role and await _safe_is_admin_role(user_data.role):
if not await _safe_is_admin_user(username):
return respond_rest(
ResponseModel(
status_code=403,
error_code='USR015',
- error_message='Only admin may create users with the admin role'
- ))
+ error_message='Only admin may create users with the admin role',
+ )
+ )
return respond_rest(await UserService.create_user(user_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update user
@@ -112,41 +115,47 @@ Response:
{}
"""
-@user_router.put('/{username}',
+
+@user_router.put(
+ '/{username}',
description='Update user',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'User updated successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'User updated successfully'}}},
}
- }
+ },
)
-
async def update_user(username: str, api_data: UpdateUserModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
auth_username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Block modifications to bootstrap admin user, except for limited operational fields
if username == 'admin':
allowed_keys = {
- 'bandwidth_limit_bytes', 'bandwidth_limit_window',
- 'rate_limit_duration', 'rate_limit_duration_type', 'rate_limit_enabled',
- 'throttle_duration', 'throttle_duration_type', 'throttle_wait_duration',
- 'throttle_wait_duration_type', 'throttle_queue_limit', 'throttle_enabled',
+ 'bandwidth_limit_bytes',
+ 'bandwidth_limit_window',
+ 'rate_limit_duration',
+ 'rate_limit_duration_type',
+ 'rate_limit_enabled',
+ 'throttle_duration',
+ 'throttle_duration_type',
+ 'throttle_wait_duration',
+ 'throttle_wait_duration_type',
+ 'throttle_queue_limit',
+ 'throttle_enabled',
}
try:
- incoming = {k for k, v in (api_data.dict(exclude_unset=True) or {}).items() if v is not None}
+ incoming = {
+ k for k, v in (api_data.dict(exclude_unset=True) or {}).items() if v is not None
+ }
except Exception:
incoming = set()
if not incoming.issubset(allowed_keys):
@@ -154,22 +163,27 @@ async def update_user(username: str, api_data: UpdateUserModel, request: Request
ResponseModel(
status_code=403,
error_code='USR020',
- error_message='Super admin user cannot be modified'
- ))
- if not auth_username == username and not await platform_role_required_bool(auth_username, Roles.MANAGE_USERS):
+ error_message='Super admin user cannot be modified',
+ )
+ )
+ if not auth_username == username and not await platform_role_required_bool(
+ auth_username, Roles.MANAGE_USERS
+ ):
return respond_rest(
ResponseModel(
status_code=403,
error_code='USR006',
- error_message='Can only update your own information'
- ))
+ error_message='Can only update your own information',
+ )
+ )
if await _safe_is_admin_user(username) and not await _safe_is_admin_user(auth_username):
return respond_rest(
ResponseModel(
status_code=403,
error_code='USR012',
- error_message='Only admin may modify admin users'
- ))
+ error_message='Only admin may modify admin users',
+ )
+ )
new_role = api_data.role
if new_role is not None:
target_is_admin = await _safe_is_admin_user(username)
@@ -179,23 +193,26 @@ async def update_user(username: str, api_data: UpdateUserModel, request: Request
ResponseModel(
status_code=403,
error_code='USR013',
- error_message='Only admin may change admin role assignments'
- ))
+ error_message='Only admin may change admin role assignments',
+ )
+ )
return respond_rest(await UserService.update_user(username, api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Delete user
@@ -205,30 +222,27 @@ Response:
{}
"""
-@user_router.delete('/{username}',
+
+@user_router.delete(
+ '/{username}',
description='Delete user',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
- 'content': {
- 'application/json': {
- 'example': {
- 'message': 'User deleted successfully'
- }
- }
- }
+ 'content': {'application/json': {'example': {'message': 'User deleted successfully'}}},
}
- }
+ },
)
-
async def delete_user(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
auth_username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Block any deletion of bootstrap admin user
if username == 'admin':
@@ -236,37 +250,44 @@ async def delete_user(username: str, request: Request):
ResponseModel(
status_code=403,
error_code='USR021',
- error_message='Super admin user cannot be deleted'
- ))
- if not auth_username == username and not await platform_role_required_bool(auth_username, Roles.MANAGE_USERS):
+ error_message='Super admin user cannot be deleted',
+ )
+ )
+ if not auth_username == username and not await platform_role_required_bool(
+ auth_username, Roles.MANAGE_USERS
+ ):
return respond_rest(
ResponseModel(
status_code=403,
error_code='USR007',
- error_message='Can only delete your own account'
- ))
+ error_message='Can only delete your own account',
+ )
+ )
if await _safe_is_admin_user(username) and not await _safe_is_admin_user(auth_username):
return respond_rest(
ResponseModel(
status_code=403,
error_code='USR014',
- error_message='Only admin may delete admin users'
- ))
+ error_message='Only admin may delete admin users',
+ )
+ )
return respond_rest(await UserService.delete_user(username, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Update user password
@@ -276,65 +297,68 @@ Response:
{}
"""
-@user_router.put('/{username}/update-password',
+
+@user_router.put(
+ '/{username}/update-password',
description='Update user password',
response_model=ResponseModel,
responses={
200: {
'description': 'Successful Response',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Password updated successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Password updated successfully'}}
+ },
}
- }
+ },
)
-
async def update_user_password(username: str, api_data: UpdatePasswordModel, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
auth_username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Block any password changes to bootstrap admin user
if username == 'admin':
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='USR022',
- error_message='Super admin password cannot be changed via UI'
- ))
- if not auth_username == username and not await platform_role_required_bool(auth_username, Roles.MANAGE_USERS):
- return respond_rest(ResponseModel(
- status_code=403,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code='USR006',
- error_message='Can only update your own password'
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='USR022',
+ error_message='Super admin password cannot be changed via UI',
+ )
+ )
+ if not auth_username == username and not await platform_role_required_bool(
+ auth_username, Roles.MANAGE_USERS
+ ):
+ return respond_rest(
+ ResponseModel(
+ status_code=403,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code='USR006',
+ error_message='Can only update your own password',
+ )
+ )
return respond_rest(await UserService.update_password(username, api_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -344,43 +368,43 @@ Response:
{}
"""
-@user_router.get('/me',
- description='Get user by username',
- response_model=UserModelResponse
- )
+@user_router.get('/me', description='Get user by username', response_model=UserModelResponse)
async def get_user_by_username(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
auth_username = payload.get('sub')
- logger.info(f'{request_id} | Username: {auth_username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {auth_username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
return respond_rest(await UserService.get_user_by_username(auth_username, request_id))
except HTTPException as e:
- return respond_rest(ResponseModel(
- status_code=e.status_code,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.HTTP_EXCEPTION,
- error_message=e.detail
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.HTTP_EXCEPTION,
+ error_message=e.detail,
+ )
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return respond_rest(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return respond_rest(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -390,18 +414,19 @@ Response:
{}
"""
-@user_router.get('/all',
- description='Get all users',
- response_model=List[UserModelResponse]
-)
-async def get_all_users(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+@user_router.get('/all', description='Get all users', response_model=list[UserModelResponse])
+async def get_all_users(
+ request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE
+):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
data = await UserService.get_all_users(page, page_size, request_id)
if data.get('status_code') == 200 and isinstance(data.get('response'), dict):
@@ -412,35 +437,40 @@ async def get_all_users(request: Request, page: int = Defaults.PAGE, page_size:
if u.get('username') == 'admin':
continue
# Hide other admin role users from non-admin users
- if not await _safe_is_admin_user(username) and await _safe_is_admin_role(u.get('role')):
+ if not await _safe_is_admin_user(username) and await _safe_is_admin_role(
+ u.get('role')
+ ):
continue
filtered.append(u)
data = dict(data)
data['response'] = {'users': filtered}
return process_response(data, 'rest')
except HTTPException as e:
- return process_response(ResponseModel(
- status_code=e.status_code,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.HTTP_EXCEPTION,
- error_message=e.detail
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=e.status_code,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.HTTP_EXCEPTION,
+ error_message=e.detail,
+ ).dict(),
+ 'rest',
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -450,55 +480,72 @@ Response:
{}
"""
-@user_router.get('/{username}',
- description='Get user by username',
- response_model=UserModelResponse
-)
+@user_router.get(
+ '/{username}', description='Get user by username', response_model=UserModelResponse
+)
async def get_user_by_username(username: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
auth_username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Block access to bootstrap admin user for ALL users (including admin),
# except when STRICT_RESPONSE_ENVELOPE=true (envelope-shape tests)
- if username == 'admin' and not (os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'):
- return process_response(ResponseModel(
- status_code=404,
- response_headers={Headers.REQUEST_ID: request_id},
- error_message='User not found'
- ).dict(), 'rest')
- if not auth_username == username and not await platform_role_required_bool(auth_username, 'manage_users'):
+ if username == 'admin' and not (
+ os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'
+ ):
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_message='User not found',
+ ).dict(),
+ 'rest',
+ )
+ if not auth_username == username and not await platform_role_required_bool(
+ auth_username, 'manage_users'
+ ):
return process_response(
ResponseModel(
status_code=403,
error_code='USR008',
error_message='Unable to retrieve information for user',
- ).dict(), 'rest')
+ ).dict(),
+ 'rest',
+ )
if not await _safe_is_admin_user(auth_username) and await _safe_is_admin_user(username):
- return process_response(ResponseModel(
- status_code=404,
- response_headers={Headers.REQUEST_ID: request_id},
- error_message='User not found'
- ).dict(), 'rest')
- return process_response(await UserService.get_user_by_username(username, request_id), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_message='User not found',
+ ).dict(),
+ 'rest',
+ )
+ return process_response(
+ await UserService.get_user_by_username(username, request_id), 'rest'
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
+
"""
Endpoint
@@ -508,55 +555,66 @@ Response:
{}
"""
-@user_router.get('/email/{email}',
- description='Get user by email',
- response_model=List[UserModelResponse]
-)
+@user_router.get(
+ '/email/{email}', description='Get user by email', response_model=list[UserModelResponse]
+)
async def get_user_by_email(email: str, request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
data = await UserService.get_user_by_email(username, email, request_id)
if data.get('status_code') == 200 and isinstance(data.get('response'), dict):
u = data.get('response')
# Block access to bootstrap admin user for ALL users
if u.get('username') == 'admin':
- return process_response(ResponseModel(
- status_code=404,
- response_headers={Headers.REQUEST_ID: request_id},
- error_message='User not found'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_message='User not found',
+ ).dict(),
+ 'rest',
+ )
# Block access to other admin users for non-admin users
if not await _safe_is_admin_user(username) and await _safe_is_admin_role(u.get('role')):
- return process_response(ResponseModel(
- status_code=404,
- response_headers={Headers.REQUEST_ID: request_id},
- error_message='User not found'
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=404,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_message='User not found',
+ ).dict(),
+ 'rest',
+ )
return process_response(data, 'rest')
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ).dict(), 'rest')
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ ).dict(),
+ 'rest',
+ )
finally:
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
-@user_router.get('',
- description='Get all users (base path)',
- response_model=List[UserModelResponse]
+
+
+@user_router.get(
+ '', description='Get all users (base path)', response_model=list[UserModelResponse]
)
-async def get_all_users_base(request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE):
+async def get_all_users_base(
+ request: Request, page: int = Defaults.PAGE, page_size: int = Defaults.PAGE_SIZE
+):
"""Convenience alias for GET /platform/user/all to support clients
and tests that expect listing at the base collection path.
"""
diff --git a/backend-services/routes/vault_routes.py b/backend-services/routes/vault_routes.py
index 9a0afac..b1fcd86 100644
--- a/backend-services/routes/vault_routes.py
+++ b/backend-services/routes/vault_routes.py
@@ -4,20 +4,19 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-import uuid
-import time
import logging
-from fastapi import APIRouter, Request, HTTPException
+import time
+import uuid
+
+from fastapi import APIRouter, Request
-from models.response_model import ResponseModel
from models.create_vault_entry_model import CreateVaultEntryModel
+from models.response_model import ResponseModel
from models.update_vault_entry_model import UpdateVaultEntryModel
-from models.vault_entry_model_response import VaultEntryModelResponse
from services.vault_service import VaultService
from utils.auth_util import auth_required
-from utils.response_util import respond_rest, process_response
-from utils.constants import ErrorCodes, Messages, Headers
+from utils.constants import ErrorCodes, Headers, Messages
+from utils.response_util import process_response, respond_rest
vault_router = APIRouter()
@@ -35,12 +34,12 @@ logger = logging.getLogger('doorman.gateway')
'application/json': {
'example': {
'message': 'Vault entry created successfully',
- 'data': {'key_name': 'api_key_production'}
+ 'data': {'key_name': 'api_key_production'},
}
}
- }
+ },
}
- }
+ },
)
async def create_vault_entry(entry_data: CreateVaultEntryModel, request: Request):
"""
@@ -52,20 +51,22 @@ async def create_vault_entry(entry_data: CreateVaultEntryModel, request: Request
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
return respond_rest(await VaultService.create_vault_entry(username, entry_data, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
elapsed = time.time() * 1000 - start_time
logger.info(f'{request_id} | Total time: {elapsed:.2f}ms')
@@ -88,16 +89,16 @@ async def create_vault_entry(entry_data: CreateVaultEntryModel, request: Request
'username': 'john_doe',
'description': 'Production API key',
'created_at': '2024-11-22T10:15:30Z',
- 'updated_at': '2024-11-22T10:15:30Z'
+ 'updated_at': '2024-11-22T10:15:30Z',
}
],
- 'count': 1
+ 'count': 1,
}
}
}
- }
+ },
}
- }
+ },
)
async def list_vault_entries(request: Request):
"""
@@ -109,20 +110,22 @@ async def list_vault_entries(request: Request):
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
return respond_rest(await VaultService.list_vault_entries(username, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
elapsed = time.time() * 1000 - start_time
logger.info(f'{request_id} | Total time: {elapsed:.2f}ms')
@@ -143,13 +146,13 @@ async def list_vault_entries(request: Request):
'username': 'john_doe',
'description': 'Production API key',
'created_at': '2024-11-22T10:15:30Z',
- 'updated_at': '2024-11-22T10:15:30Z'
+ 'updated_at': '2024-11-22T10:15:30Z',
}
}
}
- }
+ },
}
- }
+ },
)
async def get_vault_entry(key_name: str, request: Request):
"""
@@ -161,20 +164,22 @@ async def get_vault_entry(key_name: str, request: Request):
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
return respond_rest(await VaultService.get_vault_entry(username, key_name, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
elapsed = time.time() * 1000 - start_time
logger.info(f'{request_id} | Total time: {elapsed:.2f}ms')
@@ -188,14 +193,10 @@ async def get_vault_entry(key_name: str, request: Request):
200: {
'description': 'Vault entry updated successfully',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Vault entry updated successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Vault entry updated successfully'}}
+ },
}
- }
+ },
)
async def update_vault_entry(key_name: str, update_data: UpdateVaultEntryModel, request: Request):
"""
@@ -207,20 +208,24 @@ async def update_vault_entry(key_name: str, update_data: UpdateVaultEntryModel,
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
- return respond_rest(await VaultService.update_vault_entry(username, key_name, update_data, request_id))
+
+ return respond_rest(
+ await VaultService.update_vault_entry(username, key_name, update_data, request_id)
+ )
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
elapsed = time.time() * 1000 - start_time
logger.info(f'{request_id} | Total time: {elapsed:.2f}ms')
@@ -234,14 +239,10 @@ async def update_vault_entry(key_name: str, update_data: UpdateVaultEntryModel,
200: {
'description': 'Vault entry deleted successfully',
'content': {
- 'application/json': {
- 'example': {
- 'message': 'Vault entry deleted successfully'
- }
- }
- }
+ 'application/json': {'example': {'message': 'Vault entry deleted successfully'}}
+ },
}
- }
+ },
)
async def delete_vault_entry(key_name: str, request: Request):
"""
@@ -253,20 +254,22 @@ async def delete_vault_entry(key_name: str, request: Request):
try:
payload = await auth_required(request)
username = payload.get('sub')
- logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
+ logger.info(
+ f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}'
+ )
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
-
+
return respond_rest(await VaultService.delete_vault_entry(username, key_name, request_id))
except Exception as e:
logger.critical(f'{request_id} | Unexpected error: {str(e)}', exc_info=True)
- return process_response(ResponseModel(
- status_code=500,
- response_headers={
- Headers.REQUEST_ID: request_id
- },
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
- ))
+ return process_response(
+ ResponseModel(
+ status_code=500,
+ response_headers={Headers.REQUEST_ID: request_id},
+ error_code=ErrorCodes.UNEXPECTED,
+ error_message=Messages.UNEXPECTED,
+ )
+ )
finally:
elapsed = time.time() * 1000 - start_time
logger.info(f'{request_id} | Total time: {elapsed:.2f}ms')
diff --git a/backend-services/services/api_service.py b/backend-services/services/api_service.py
index 1bbf4c2..d7e47c4 100644
--- a/backend-services/services/api_service.py
+++ b/backend-services/services/api_service.py
@@ -4,23 +4,22 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-import uuid
import logging
+import uuid
+from models.create_api_model import CreateApiModel
from models.response_model import ResponseModel
from models.update_api_model import UpdateApiModel
+from utils.async_db import db_delete_one, db_find_one, db_insert_one, db_update_one
+from utils.constants import ErrorCodes, Messages
from utils.database_async import api_collection
-from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one
-from utils.cache_manager_util import cache_manager
from utils.doorman_cache_util import doorman_cache
-from models.create_api_model import CreateApiModel
from utils.paging_util import validate_page_params
-from utils.constants import ErrorCodes, Messages, Defaults
logger = logging.getLogger('doorman.gateway')
-class ApiService:
+class ApiService:
@staticmethod
async def create_api(data: CreateApiModel, request_id):
"""
@@ -33,16 +32,17 @@ class ApiService:
return ResponseModel(
status_code=400,
error_code='API013',
- error_message='Public API cannot have credits enabled'
+ error_message='Public API cannot have credits enabled',
).dict()
except Exception:
pass
cache_key = f'{data.api_name}/{data.api_version}'
existing = doorman_cache.get_cache('api_cache', cache_key)
if not existing:
- existing = await db_find_one(api_collection, {'api_name': data.api_name, 'api_version': data.api_version})
+ existing = await db_find_one(
+ api_collection, {'api_name': data.api_name, 'api_version': data.api_version}
+ )
if existing:
-
try:
if existing.get('_id'):
existing = {k: v for k, v in existing.items() if k != '_id'}
@@ -50,7 +50,9 @@ class ApiService:
if not existing.get('api_id'):
existing['api_id'] = str(uuid.uuid4())
if not existing.get('api_path'):
- existing['api_path'] = f"/{existing.get('api_name')}/{existing.get('api_version')}"
+ existing['api_path'] = (
+ f'/{existing.get("api_name")}/{existing.get("api_version")}'
+ )
doorman_cache.set_cache('api_cache', cache_key, existing)
doorman_cache.set_cache('api_id_cache', existing['api_path'], existing['api_id'])
except Exception:
@@ -58,11 +60,9 @@ class ApiService:
logger.info(request_id + ' | API already exists; returning success')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='API already exists'
- ).dict()
+ response_headers={'request_id': request_id},
+ message='API already exists',
+ ).dict()
data.api_path = f'/{data.api_name}/{data.api_version}'
data.api_id = str(uuid.uuid4())
api_dict = data.dict()
@@ -70,10 +70,8 @@ class ApiService:
if not insert_result.acknowledged:
logger.error(request_id + ' | API creation failed with code API002')
return ResponseModel(
- status_code=400,
- error_code='API002',
- error_message='Unable to insert endpoint'
- ).dict()
+ status_code=400, error_code='API002', error_message='Unable to insert endpoint'
+ ).dict()
api_dict['_id'] = str(insert_result.inserted_id)
# Cache by both api_id and canonical path for consistent lookups
doorman_cache.set_cache('api_cache', data.api_id, api_dict)
@@ -89,12 +87,10 @@ class ApiService:
api_copy = {'api_name': data.api_name, 'api_version': data.api_version}
return ResponseModel(
status_code=201,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
response={'api': api_copy},
- message='API created successfully'
- ).dict()
+ message='API created successfully',
+ ).dict()
@staticmethod
async def update_api(api_name, api_version, data: UpdateApiModel, request_id):
@@ -102,36 +98,49 @@ class ApiService:
Update an API on the platform.
"""
logger.info(request_id + ' | Updating API: ' + api_name + ' ' + api_version)
- if data.api_name and data.api_name != api_name or data.api_version and data.api_version != api_version or data.api_path and data.api_path != f'/{api_name}/{api_version}':
+ if (
+ data.api_name
+ and data.api_name != api_name
+ or data.api_version
+ and data.api_version != api_version
+ or data.api_path
+ and data.api_path != f'/{api_name}/{api_version}'
+ ):
logger.error(request_id + ' | API update failed with code API005')
return ResponseModel(
status_code=400,
error_code='API005',
- error_message='API name and version cannot be updated'
- ).dict()
+ error_message='API name and version cannot be updated',
+ ).dict()
api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}')
if not api:
- api = await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version})
+ api = await db_find_one(
+ api_collection, {'api_name': api_name, 'api_version': api_version}
+ )
if not api:
logger.error(request_id + ' | API update failed with code API003')
return ResponseModel(
status_code=400,
error_code='API003',
- error_message='API does not exist for the requested name and version'
- ).dict()
+ error_message='API does not exist for the requested name and version',
+ ).dict()
else:
- doorman_cache.delete_cache('api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}'))
+ doorman_cache.delete_cache(
+ 'api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}')
+ )
doorman_cache.delete_cache('api_id_cache', f'/{api_name}/{api_version}')
not_null_data = {k: v for k, v in data.dict().items() if v is not None}
try:
desired_public = bool(not_null_data.get('api_public', api.get('api_public')))
- desired_credits = bool(not_null_data.get('api_credits_enabled', api.get('api_credits_enabled')))
+ desired_credits = bool(
+ not_null_data.get('api_credits_enabled', api.get('api_credits_enabled'))
+ )
if desired_public and desired_credits:
return ResponseModel(
status_code=400,
error_code='API013',
- error_message='Public API cannot have credits enabled'
+ error_message='Public API cannot have credits enabled',
).dict()
except Exception:
pass
@@ -140,7 +149,7 @@ class ApiService:
update_result = await db_update_one(
api_collection,
{'api_name': api_name, 'api_version': api_version},
- {'$set': not_null_data}
+ {'$set': not_null_data},
)
if update_result.modified_count > 0:
cache_key = f'{api_name}/{api_version}'
@@ -149,28 +158,23 @@ class ApiService:
if not update_result.acknowledged or update_result.modified_count == 0:
logger.error(request_id + ' | API update failed with code API002')
return ResponseModel(
- status_code=400,
- error_code='API002',
- error_message='Unable to update api'
- ).dict()
+ status_code=400, error_code='API002', error_message='Unable to update api'
+ ).dict()
except Exception as e:
cache_key = f'{api_name}/{api_version}'
doorman_cache.delete_cache('api_cache', cache_key)
doorman_cache.delete_cache('api_id_cache', f'/{api_name}/{api_version}')
- logger.error(request_id + ' | API update failed with exception: ' + str(e), exc_info=True)
+ logger.error(
+ request_id + ' | API update failed with exception: ' + str(e), exc_info=True
+ )
raise
logger.info(request_id + ' | API updated successful')
- return ResponseModel(
- status_code=200,
- message='API updated successfully'
- ).dict()
+ return ResponseModel(status_code=200, message='API updated successfully').dict()
else:
logger.error(request_id + ' | API update failed with code API006')
return ResponseModel(
- status_code=400,
- error_code='API006',
- error_message='No data to update'
- ).dict()
+ status_code=400, error_code='API006', error_message='No data to update'
+ ).dict()
@staticmethod
async def delete_api(api_name, api_version, request_id):
@@ -180,32 +184,34 @@ class ApiService:
logger.info(request_id + ' | Deleting API: ' + api_name + ' ' + api_version)
api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}')
if not api:
- api = await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version})
+ api = await db_find_one(
+ api_collection, {'api_name': api_name, 'api_version': api_version}
+ )
if not api:
logger.error(request_id + ' | API deletion failed with code API003')
return ResponseModel(
status_code=400,
error_code='API003',
- error_message='API does not exist for the requested name and version'
- ).dict()
- delete_result = await db_delete_one(api_collection, {'api_name': api_name, 'api_version': api_version})
+ error_message='API does not exist for the requested name and version',
+ ).dict()
+ delete_result = await db_delete_one(
+ api_collection, {'api_name': api_name, 'api_version': api_version}
+ )
if not delete_result.acknowledged:
logger.error(request_id + ' | API deletion failed with code API002')
return ResponseModel(
- status_code=400,
- error_code='API002',
- error_message='Unable to delete endpoint'
- ).dict()
- doorman_cache.delete_cache('api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}'))
+ status_code=400, error_code='API002', error_message='Unable to delete endpoint'
+ ).dict()
+ doorman_cache.delete_cache(
+ 'api_cache', doorman_cache.get_cache('api_id_cache', f'/{api_name}/{api_version}')
+ )
doorman_cache.delete_cache('api_id_cache', f'/{api_name}/{api_version}')
logger.info(request_id + ' | API deletion successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='API deleted successfully'
- ).dict()
+ response_headers={'request_id': request_id},
+ message='API deleted successfully',
+ ).dict()
@staticmethod
async def get_api_by_name_version(api_name, api_version, request_id):
@@ -215,52 +221,52 @@ class ApiService:
logger.info(request_id + ' | Getting API: ' + api_name + ' ' + api_version)
api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}')
if not api:
- api = api_collection.find_one({'api_name': api_name, 'api_version': api_version})
+ api = await db_find_one(
+ api_collection, {'api_name': api_name, 'api_version': api_version}
+ )
if not api:
logger.error(request_id + ' | API retrieval failed with code API003')
return ResponseModel(
status_code=400,
error_code='API003',
- error_message='API does not exist for the requested name and version'
- ).dict()
- if api.get('_id'): del api['_id']
+ error_message='API does not exist for the requested name and version',
+ ).dict()
+ if api.get('_id'):
+ del api['_id']
doorman_cache.set_cache('api_cache', f'{api_name}/{api_version}', api)
if '_id' in api:
del api['_id']
logger.info(request_id + ' | API retrieval successful')
- return ResponseModel(
- status_code=200,
- response=api
- ).dict()
+ return ResponseModel(status_code=200, response=api).dict()
@staticmethod
async def get_apis(page, page_size, request_id):
"""
Get all APIs that a user has access to with pagination.
"""
- logger.info(request_id + ' | Getting APIs: Page=' + str(page) + ' Page Size=' + str(page_size))
+ logger.info(
+ request_id + ' | Getting APIs: Page=' + str(page) + ' Page Size=' + str(page_size)
+ )
try:
page, page_size = validate_page_params(page, page_size)
except Exception as e:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING
+ ),
).dict()
skip = (page - 1) * page_size
cursor = api_collection.find().sort('api_name', 1).skip(skip).limit(page_size)
apis = cursor.to_list(length=None)
for api in apis:
- if api.get('_id'): del api['_id']
+ if api.get('_id'):
+ del api['_id']
logger.info(request_id + ' | APIs retrieval successful')
- meta = {
- 'total': len(apis),
- 'page': page,
- 'page_size': page_size,
- }
+ meta = {'total': len(apis), 'page': page, 'page_size': page_size}
# Add a message to keep payload sizes above compression overhead for tests
message = 'API list retrieved successfully. This message also helps ensure the response has a reasonable size for compression benchmarks.'
return ResponseModel(
- status_code=200,
- response={'apis': apis, 'meta': meta, 'message': message}
- ).dict()
+ status_code=200, response={'apis': apis, 'meta': meta, 'message': message}
+ ).dict()
diff --git a/backend-services/services/credit_service.py b/backend-services/services/credit_service.py
index 6e8f84a..463bda0 100644
--- a/backend-services/services/credit_service.py
+++ b/backend-services/services/credit_service.py
@@ -4,39 +4,37 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from pymongo.errors import PyMongoError
import logging
-from typing import Optional
import secrets
-from models.response_model import ResponseModel
+from pymongo.errors import PyMongoError
+
from models.credit_model import CreditModel
+from models.response_model import ResponseModel
from models.user_credits_model import UserCreditModel
-from utils.database_async import credit_def_collection, user_credit_collection
-from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list
-from utils.encryption_util import encrypt_value, decrypt_value
-from utils.doorman_cache_util import doorman_cache
-from utils.paging_util import validate_page_params
+from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one
from utils.constants import ErrorCodes, Messages
+from utils.database_async import credit_def_collection, user_credit_collection
+from utils.doorman_cache_util import doorman_cache
+from utils.encryption_util import decrypt_value, encrypt_value
+from utils.paging_util import validate_page_params
logger = logging.getLogger('doorman.gateway')
-class CreditService:
+class CreditService:
@staticmethod
- def _validate_credit_data(data: CreditModel) -> Optional[ResponseModel]:
+ def _validate_credit_data(data: CreditModel) -> ResponseModel | None:
"""Validate credit definition data before creation or update."""
if not data.api_credit_group:
return ResponseModel(
- status_code=400,
- error_code='CRD009',
- error_message='Credit group name is required'
+ status_code=400, error_code='CRD009', error_message='Credit group name is required'
)
if not data.api_key or not data.api_key_header:
return ResponseModel(
status_code=400,
error_code='CRD010',
- error_message='API key and header are required'
+ error_message='API key and header are required',
)
return None
@@ -46,15 +44,21 @@ class CreditService:
logger.info(request_id + ' | Creating credit definition')
validation_error = CreditService._validate_credit_data(data)
if validation_error:
- logger.error(request_id + f' | Credit creation failed with code {validation_error.error_code}')
+ logger.error(
+ request_id + f' | Credit creation failed with code {validation_error.error_code}'
+ )
return validation_error.dict()
try:
- if doorman_cache.get_cache('credit_def_cache', data.api_credit_group) or await db_find_one(credit_def_collection, {'api_credit_group': data.api_credit_group}):
+ if doorman_cache.get_cache(
+ 'credit_def_cache', data.api_credit_group
+ ) or await db_find_one(
+ credit_def_collection, {'api_credit_group': data.api_credit_group}
+ ):
logger.error(request_id + ' | Credit creation failed with code CRD001')
return ResponseModel(
status_code=400,
error_code='CRD001',
- error_message='Credit group already exists'
+ error_message='Credit group already exists',
).dict()
credit_data = data.dict()
if credit_data.get('api_key') is not None:
@@ -67,7 +71,7 @@ class CreditService:
return ResponseModel(
status_code=400,
error_code='CRD002',
- error_message='Unable to insert credit definition'
+ error_message='Unable to insert credit definition',
).dict()
credit_data['_id'] = str(insert_result.inserted_id)
doorman_cache.set_cache('credit_def_cache', data.api_credit_group, credit_data)
@@ -75,14 +79,14 @@ class CreditService:
return ResponseModel(
status_code=201,
response_headers={'request_id': request_id},
- message='Credit definition created successfully'
+ message='Credit definition created successfully',
).dict()
except PyMongoError as e:
logger.error(request_id + f' | Credit creation failed with database error: {str(e)}')
return ResponseModel(
status_code=500,
error_code='CRD011',
- error_message='Database error occurred while creating credit definition'
+ error_message='Database error occurred while creating credit definition',
).dict()
@staticmethod
@@ -91,7 +95,9 @@ class CreditService:
logger.info(request_id + ' | Updating credit definition')
validation_error = CreditService._validate_credit_data(data)
if validation_error:
- logger.error(request_id + f' | Credit update failed with code {validation_error.error_code}')
+ logger.error(
+ request_id + f' | Credit update failed with code {validation_error.error_code}'
+ )
return validation_error.dict()
try:
if data.api_credit_group and data.api_credit_group != api_credit_group:
@@ -99,17 +105,19 @@ class CreditService:
return ResponseModel(
status_code=400,
error_code='CRD003',
- error_message='Credit group name cannot be updated'
+ error_message='Credit group name cannot be updated',
).dict()
doc = doorman_cache.get_cache('credit_def_cache', api_credit_group)
if not doc:
- doc = await db_find_one(credit_def_collection, {'api_credit_group': api_credit_group})
+ doc = await db_find_one(
+ credit_def_collection, {'api_credit_group': api_credit_group}
+ )
if not doc:
logger.error(request_id + ' | Credit update failed with code CRD004')
return ResponseModel(
status_code=400,
error_code='CRD004',
- error_message='Credit definition does not exist for the requested group'
+ error_message='Credit definition does not exist for the requested group',
).dict()
else:
doorman_cache.delete_cache('credit_def_cache', api_credit_group)
@@ -119,22 +127,34 @@ class CreditService:
if 'api_key_new' in not_null:
not_null['api_key_new'] = encrypt_value(not_null['api_key_new'])
if not_null:
- update_result = await db_update_one(credit_def_collection, {'api_credit_group': api_credit_group}, {'$set': not_null})
+ update_result = await db_update_one(
+ credit_def_collection,
+ {'api_credit_group': api_credit_group},
+ {'$set': not_null},
+ )
if not update_result.acknowledged or update_result.modified_count == 0:
logger.error(request_id + ' | Credit update failed with code CRD005')
return ResponseModel(
status_code=400,
error_code='CRD005',
- error_message='Unable to update credit definition'
+ error_message='Unable to update credit definition',
).dict()
logger.info(request_id + ' | Credit update successful')
- return ResponseModel(status_code=200, message='Credit definition updated successfully').dict()
+ return ResponseModel(
+ status_code=200, message='Credit definition updated successfully'
+ ).dict()
else:
logger.error(request_id + ' | Credit update failed with code CRD006')
- return ResponseModel(status_code=400, error_code='CRD006', error_message='No data to update').dict()
+ return ResponseModel(
+ status_code=400, error_code='CRD006', error_message='No data to update'
+ ).dict()
except PyMongoError as e:
logger.error(request_id + f' | Credit update failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD012', error_message='Database error occurred while updating credit definition').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD012',
+ error_message='Database error occurred while updating credit definition',
+ ).dict()
@staticmethod
async def delete_credit(api_credit_group: str, request_id):
@@ -143,21 +163,39 @@ class CreditService:
try:
doc = doorman_cache.get_cache('credit_def_cache', api_credit_group)
if not doc:
- doc = await db_find_one(credit_def_collection, {'api_credit_group': api_credit_group})
+ doc = await db_find_one(
+ credit_def_collection, {'api_credit_group': api_credit_group}
+ )
if not doc:
logger.error(request_id + ' | Credit deletion failed with code CRD007')
- return ResponseModel(status_code=400, error_code='CRD007', error_message='Credit definition does not exist for the requested group').dict()
+ return ResponseModel(
+ status_code=400,
+ error_code='CRD007',
+ error_message='Credit definition does not exist for the requested group',
+ ).dict()
else:
doorman_cache.delete_cache('credit_def_cache', api_credit_group)
- delete_result = await db_delete_one(credit_def_collection, {'api_credit_group': api_credit_group})
+ delete_result = await db_delete_one(
+ credit_def_collection, {'api_credit_group': api_credit_group}
+ )
if not delete_result.acknowledged or delete_result.deleted_count == 0:
logger.error(request_id + ' | Credit deletion failed with code CRD008')
- return ResponseModel(status_code=400, error_code='CRD008', error_message='Unable to delete credit definition').dict()
+ return ResponseModel(
+ status_code=400,
+ error_code='CRD008',
+ error_message='Unable to delete credit definition',
+ ).dict()
logger.info(request_id + ' | Credit deletion successful')
- return ResponseModel(status_code=200, message='Credit definition deleted successfully').dict()
+ return ResponseModel(
+ status_code=200, message='Credit definition deleted successfully'
+ ).dict()
except PyMongoError as e:
logger.error(request_id + f' | Credit deletion failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD013', error_message='Database error occurred while deleting credit definition').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD013',
+ error_message='Database error occurred while deleting credit definition',
+ ).dict()
@staticmethod
async def list_credit_defs(page: int, page_size: int, request_id):
@@ -170,7 +208,11 @@ class CreditService:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE
+ if 'page_size' in str(e)
+ else Messages.INVALID_PAGING
+ ),
).dict()
all_defs = await db_find_list(credit_def_collection, {})
all_defs.sort(key=lambda d: d.get('api_credit_group'))
@@ -180,25 +222,35 @@ class CreditService:
for doc in all_defs[start:end]:
if doc.get('_id'):
del doc['_id']
- items.append({
- 'api_credit_group': doc.get('api_credit_group'),
- 'api_key_header': doc.get('api_key_header'),
- 'api_key_present': bool(doc.get('api_key')),
- 'credit_tiers': doc.get('credit_tiers', []),
- })
+ items.append(
+ {
+ 'api_credit_group': doc.get('api_credit_group'),
+ 'api_key_header': doc.get('api_key_header'),
+ 'api_key_present': bool(doc.get('api_key')),
+ 'credit_tiers': doc.get('credit_tiers', []),
+ }
+ )
return ResponseModel(status_code=200, response={'items': items}).dict()
except PyMongoError as e:
logger.error(request_id + f' | Credit list failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD020', error_message='Database error occurred while listing credit definitions').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD020',
+ error_message='Database error occurred while listing credit definitions',
+ ).dict()
@staticmethod
async def get_credit_def(api_credit_group: str, request_id):
"""Get a single credit definition (masked)."""
logger.info(request_id + ' | Getting credit definition')
try:
- doc = credit_def_collection.find_one({'api_credit_group': api_credit_group})
+ doc = await db_find_one(credit_def_collection, {'api_credit_group': api_credit_group})
if not doc:
- return ResponseModel(status_code=404, error_code='CRD021', error_message='Credit definition not found').dict()
+ return ResponseModel(
+ status_code=404,
+ error_code='CRD021',
+ error_message='Credit definition not found',
+ ).dict()
if doc.get('_id'):
del doc['_id']
masked = {
@@ -210,7 +262,11 @@ class CreditService:
return ResponseModel(status_code=200, response=masked).dict()
except PyMongoError as e:
logger.error(request_id + f' | Credit fetch failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD022', error_message='Database error occurred while retrieving credit definition').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD022',
+ error_message='Database error occurred while retrieving credit definition',
+ ).dict()
@staticmethod
async def add_credits(username: str, data: UserCreditModel, request_id):
@@ -218,7 +274,11 @@ class CreditService:
logger.info(request_id + f' | Adding credits for user: {username}')
try:
if data.username and data.username != username:
- return ResponseModel(status_code=400, error_code='CRD014', error_message='Username in body does not match path').dict()
+ return ResponseModel(
+ status_code=400,
+ error_code='CRD014',
+ error_message='Username in body does not match path',
+ ).dict()
doc = await db_find_one(user_credit_collection, {'username': username})
users_credits = data.users_credits or {}
secured = {}
@@ -229,13 +289,21 @@ class CreditService:
secured[group] = info
payload = {'username': username, 'users_credits': secured}
if doc:
- await db_update_one(user_credit_collection, {'username': username}, {'$set': {'users_credits': secured}})
+ await db_update_one(
+ user_credit_collection,
+ {'username': username},
+ {'$set': {'users_credits': secured}},
+ )
else:
await db_insert_one(user_credit_collection, payload)
return ResponseModel(status_code=200, message='Credits saved successfully').dict()
except PyMongoError as e:
logger.error(request_id + f' | Add credits failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD015', error_message='Database error occurred while saving user credits').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD015',
+ error_message='Database error occurred while saving user credits',
+ ).dict()
@staticmethod
async def get_all_credits(page: int, page_size: int, request_id, search: str = ''):
@@ -247,7 +315,11 @@ class CreditService:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE
+ if 'page_size' in str(e)
+ else Messages.INVALID_PAGING
+ ),
).dict()
cursor = user_credit_collection.find().sort('username', 1)
@@ -271,7 +343,7 @@ class CreditService:
if it.get('_id'):
del it['_id']
uc = it.get('users_credits') or {}
- for g, info in uc.items():
+ for _g, info in uc.items():
if isinstance(info, dict) and 'user_api_key' in info:
dec = decrypt_value(info.get('user_api_key'))
if dec is not None:
@@ -279,7 +351,11 @@ class CreditService:
return ResponseModel(status_code=200, response={'user_credits': items}).dict()
except PyMongoError as e:
logger.error(request_id + f' | Get all credits failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD016', error_message='Database error occurred while retrieving credits').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD016',
+ error_message='Database error occurred while retrieving credits',
+ ).dict()
@staticmethod
async def get_user_credits(username: str, request_id):
@@ -287,11 +363,13 @@ class CreditService:
try:
doc = await db_find_one(user_credit_collection, {'username': username})
if not doc:
- return ResponseModel(status_code=404, error_code='CRD017', error_message='User credits not found').dict()
+ return ResponseModel(
+ status_code=404, error_code='CRD017', error_message='User credits not found'
+ ).dict()
if doc.get('_id'):
del doc['_id']
uc = doc.get('users_credits') or {}
- for g, info in uc.items():
+ for _g, info in uc.items():
if isinstance(info, dict) and 'user_api_key' in info:
dec = decrypt_value(info.get('user_api_key'))
if dec is not None:
@@ -299,7 +377,11 @@ class CreditService:
return ResponseModel(status_code=200, response=doc).dict()
except PyMongoError as e:
logger.error(request_id + f' | Get user credits failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD018', error_message='Database error occurred while retrieving user credits').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD018',
+ error_message='Database error occurred while retrieving user credits',
+ ).dict()
@staticmethod
async def rotate_api_key(username: str, group: str, request_id):
@@ -312,31 +394,45 @@ class CreditService:
# But maybe we can create the credit entry if it's missing.
# Let's error if not found for now, or create empty.
doc = {'username': username, 'users_credits': {}}
-
+
users_credits = doc.get('users_credits') or {}
group_credits = users_credits.get(group) or {}
-
+
# Generate new key
new_key = secrets.token_urlsafe(32)
encrypted_key = encrypt_value(new_key)
-
+
# Update group credits
# Preserve other fields in group_credits (like available_credits)
if isinstance(group_credits, dict):
group_credits['user_api_key'] = encrypted_key
else:
# Should be a dict, but if it was somehow not, reset it
- group_credits = {'user_api_key': encrypted_key, 'available_credits': 0, 'tier_name': 'default'}
+ group_credits = {
+ 'user_api_key': encrypted_key,
+ 'available_credits': 0,
+ 'tier_name': 'default',
+ }
users_credits[group] = group_credits
-
+
if doc.get('_id'):
- await db_update_one(user_credit_collection, {'username': username}, {'$set': {'users_credits': users_credits}})
+ await db_update_one(
+ user_credit_collection,
+ {'username': username},
+ {'$set': {'users_credits': users_credits}},
+ )
else:
- await db_insert_one(user_credit_collection, {'username': username, 'users_credits': users_credits})
-
+ await db_insert_one(
+ user_credit_collection, {'username': username, 'users_credits': users_credits}
+ )
+
return ResponseModel(status_code=200, response={'api_key': new_key}).dict()
-
+
except PyMongoError as e:
logger.error(request_id + f' | Rotate key failed with database error: {str(e)}')
- return ResponseModel(status_code=500, error_code='CRD019', error_message='Database error occurred while rotating API key').dict()
+ return ResponseModel(
+ status_code=500,
+ error_code='CRD019',
+ error_message='Database error occurred while rotating API key',
+ ).dict()
diff --git a/backend-services/services/endpoint_service.py b/backend-services/services/endpoint_service.py
index 2b2672e..0f0b746 100644
--- a/backend-services/services/endpoint_service.py
+++ b/backend-services/services/endpoint_service.py
@@ -4,61 +4,75 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-import uuid
import logging
-import os
import string as _string
+import uuid
from pathlib import Path
+from models.create_endpoint_model import CreateEndpointModel
from models.create_endpoint_validation_model import CreateEndpointValidationModel
from models.response_model import ResponseModel
from models.update_endpoint_model import UpdateEndpointModel
from models.update_endpoint_validation_model import UpdateEndpointValidationModel
-from utils.database import endpoint_collection, api_collection, endpoint_validation_collection
-from utils.cache_manager_util import cache_manager
+from utils.database import api_collection, endpoint_collection, endpoint_validation_collection
from utils.doorman_cache_util import doorman_cache
-from models.create_endpoint_model import CreateEndpointModel
logger = logging.getLogger('doorman.gateway')
-class EndpointService:
+class EndpointService:
@staticmethod
async def create_endpoint(data: CreateEndpointModel, request_id):
"""
Create an endpoint for an API.
"""
- logger.info(request_id + ' | Creating endpoint: ' + data.api_name + ' ' + data.api_version + ' ' + data.endpoint_uri)
- cache_key = f'/{data.endpoint_method}/{data.api_name}/{data.api_version}/{data.endpoint_uri}'.replace('//', '/')
- if doorman_cache.get_cache('endpoint_cache', cache_key) or endpoint_collection.find_one({
- 'endpoint_method': data.endpoint_method,
- 'api_name': data.api_name,
- 'api_version': data.api_version,
- 'endpoint_uri': data.endpoint_uri
- }):
+ logger.info(
+ request_id
+ + ' | Creating endpoint: '
+ + data.api_name
+ + ' '
+ + data.api_version
+ + ' '
+ + data.endpoint_uri
+ )
+ cache_key = f'/{data.endpoint_method}/{data.api_name}/{data.api_version}/{data.endpoint_uri}'.replace(
+ '//', '/'
+ )
+ if doorman_cache.get_cache('endpoint_cache', cache_key) or endpoint_collection.find_one(
+ {
+ 'endpoint_method': data.endpoint_method,
+ 'api_name': data.api_name,
+ 'api_version': data.api_version,
+ 'endpoint_uri': data.endpoint_uri,
+ }
+ ):
logger.error(request_id + ' | Endpoint creation failed with code END001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='END001',
- error_message='Endpoint already exists for the requested API name, version and URI'
+ error_message='Endpoint already exists for the requested API name, version and URI',
).dict()
# Resolve API ID from cache using canonical key "/{name}/{version}"
- data.api_id = doorman_cache.get_cache('api_id_cache', f'/{data.api_name}/{data.api_version}')
+ data.api_id = doorman_cache.get_cache(
+ 'api_id_cache', f'/{data.api_name}/{data.api_version}'
+ )
if not data.api_id:
- api = api_collection.find_one({'api_name': data.api_name, 'api_version': data.api_version})
+ api = api_collection.find_one(
+ {'api_name': data.api_name, 'api_version': data.api_version}
+ )
if not api:
logger.error(request_id + ' | Endpoint creation failed with code END002')
return ResponseModel(
status_code=400,
error_code='END002',
- error_message='API does not exist for the requested name and version'
+ error_message='API does not exist for the requested name and version',
).dict()
data.api_id = api.get('api_id')
# Ensure cache uses the same canonical key with leading slash
- doorman_cache.set_cache('api_id_cache', f'/{data.api_name}/{data.api_version}', data.api_id)
+ doorman_cache.set_cache(
+ 'api_id_cache', f'/{data.api_name}/{data.api_version}', data.api_id
+ )
data.endpoint_id = str(uuid.uuid4())
endpoint_dict = data.dict()
insert_result = endpoint_collection.insert_one(endpoint_dict)
@@ -66,21 +80,25 @@ class EndpointService:
logger.error(request_id + ' | Endpoint creation failed with code END003')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='END003',
- error_message='Unable to insert endpoint'
+ error_message='Unable to insert endpoint',
).dict()
endpoint_dict['_id'] = str(insert_result.inserted_id)
doorman_cache.set_cache('endpoint_cache', cache_key, endpoint_dict)
api_endpoints = doorman_cache.get_cache('api_endpoint_cache', data.api_id) or list()
- api_endpoints.append(endpoint_dict.get('endpoint_method') + endpoint_dict.get('endpoint_uri'))
+ api_endpoints.append(
+ endpoint_dict.get('endpoint_method') + endpoint_dict.get('endpoint_uri')
+ )
doorman_cache.set_cache('api_endpoint_cache', data.api_id, api_endpoints)
logger.info(request_id + ' | Endpoint creation successful')
try:
- if data.endpoint_method.upper() == 'POST' and str(data.endpoint_uri).strip().lower() == '/grpc':
+ if (
+ data.endpoint_method.upper() == 'POST'
+ and str(data.endpoint_uri).strip().lower() == '/grpc'
+ ):
from grpc_tools import protoc as _protoc
+
api_name = data.api_name
api_version = data.api_version
module_base = f'{api_name}_{api_version}'.replace('-', '_')
@@ -118,11 +136,19 @@ class EndpointService:
'message DeleteReply { bool ok = 1; }\n'
)
proto_path.write_text(proto_content, encoding='utf-8')
- code = _protoc.main([
- 'protoc', f'--proto_path={str(proto_dir)}', f'--python_out={str(generated_dir)}', f'--grpc_python_out={str(generated_dir)}', str(proto_path)
- ])
+ code = _protoc.main(
+ [
+ 'protoc',
+ f'--proto_path={str(proto_dir)}',
+ f'--python_out={str(generated_dir)}',
+ f'--grpc_python_out={str(generated_dir)}',
+ str(proto_path),
+ ]
+ )
if code != 0:
- logger.warning(f'{request_id} | Pre-gen gRPC stubs returned {code} for {module_base}')
+ logger.warning(
+ f'{request_id} | Pre-gen gRPC stubs returned {code} for {module_base}'
+ )
try:
init_path = generated_dir / '__init__.py'
if not init_path.exists():
@@ -133,76 +159,81 @@ class EndpointService:
logger.debug(f'{request_id} | Skipping pre-gen gRPC stubs: {_e}')
return ResponseModel(
status_code=201,
- response_headers={
- 'request_id': request_id
- },
- message='Endpoint created successfully'
+ response_headers={'request_id': request_id},
+ message='Endpoint created successfully',
).dict()
@staticmethod
- async def update_endpoint(endpoint_method, api_name, api_version, endpoint_uri, data: UpdateEndpointModel, request_id):
- logger.info(request_id + ' | Updating endpoint: ' + api_name + ' ' + api_version + ' ' + endpoint_uri)
+ async def update_endpoint(
+ endpoint_method, api_name, api_version, endpoint_uri, data: UpdateEndpointModel, request_id
+ ):
+ logger.info(
+ request_id
+ + ' | Updating endpoint: '
+ + api_name
+ + ' '
+ + api_version
+ + ' '
+ + endpoint_uri
+ )
cache_key = f'/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}'.replace('//', '/')
endpoint = doorman_cache.get_cache('endpoint_cache', cache_key)
if not endpoint:
- endpoint = endpoint_collection.find_one({
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_uri': endpoint_uri,
- 'endpoint_method': endpoint_method
- })
+ endpoint = endpoint_collection.find_one(
+ {
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_uri': endpoint_uri,
+ 'endpoint_method': endpoint_method,
+ }
+ )
logger.error(request_id + ' | Endpoint update failed with code END008')
if not endpoint:
return ResponseModel(
status_code=400,
error_code='END008',
- error_message='Endpoint does not exist for the requested API name, version and URI'
+ error_message='Endpoint does not exist for the requested API name, version and URI',
).dict()
else:
doorman_cache.delete_cache('endpoint_cache', cache_key)
- if (data.endpoint_method and data.endpoint_method != endpoint.get('endpoint_method')) or (data.api_name and data.api_name != endpoint.get('api_name')) or (data.api_version and data.api_version != endpoint.get('api_version')) or (data.endpoint_uri and data.endpoint_uri != endpoint.get('endpoint_uri')):
+ if (
+ (data.endpoint_method and data.endpoint_method != endpoint.get('endpoint_method'))
+ or (data.api_name and data.api_name != endpoint.get('api_name'))
+ or (data.api_version and data.api_version != endpoint.get('api_version'))
+ or (data.endpoint_uri and data.endpoint_uri != endpoint.get('endpoint_uri'))
+ ):
logger.error(request_id + ' | Endpoint update failed with code END006')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='END006',
- error_message='API method, name, version and URI cannot be updated'
+ error_message='API method, name, version and URI cannot be updated',
).dict()
not_null_data = {k: v for k, v in data.dict().items() if v is not None}
if not_null_data:
- update_result = endpoint_collection.update_one({
+ update_result = endpoint_collection.update_one(
+ {
'api_name': api_name,
'api_version': api_version,
'endpoint_uri': endpoint_uri,
- 'endpoint_method': endpoint_method
+ 'endpoint_method': endpoint_method,
},
- {
- '$set': not_null_data
- }
+ {'$set': not_null_data},
)
if not update_result.acknowledged or update_result.modified_count == 0:
logger.error(request_id + ' | Endpoint update failed with code END003')
return ResponseModel(
- status_code=400,
- error_code='END003',
- error_message='Unable to update endpoint'
+ status_code=400, error_code='END003', error_message='Unable to update endpoint'
).dict()
logger.info(request_id + ' | Endpoint update successful')
- return ResponseModel(
- status_code=200,
- message='Endpoint updated successfully'
- ).dict()
+ return ResponseModel(status_code=200, message='Endpoint updated successfully').dict()
else:
logger.error(request_id + ' | Endpoint update failed with code END007')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='END007',
- error_message='No data to update'
+ error_message='No data to update',
).dict()
@staticmethod
@@ -210,38 +241,42 @@ class EndpointService:
"""
Delete an endpoint for an API.
"""
- logger.info(request_id + ' | Deleting: ' + api_name + ' ' + api_version + ' ' + endpoint_uri)
+ logger.info(
+ request_id + ' | Deleting: ' + api_name + ' ' + api_version + ' ' + endpoint_uri
+ )
cache_key = f'/{endpoint_method}/{api_name}/{api_version}/{endpoint_uri}'.replace('//', '/')
endpoint = doorman_cache.get_cache('endpoint_cache', cache_key)
if not endpoint:
- endpoint = endpoint_collection.find_one({
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_uri': endpoint_uri,
- 'endpoint_method': endpoint_method
- })
+ endpoint = endpoint_collection.find_one(
+ {
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_uri': endpoint_uri,
+ 'endpoint_method': endpoint_method,
+ }
+ )
if not endpoint:
logger.error(request_id + ' | Endpoint deletion failed with code END004')
return ResponseModel(
status_code=400,
error_code='END004',
- error_message='Endpoint does not exist for the requested API name, version and URI'
+ error_message='Endpoint does not exist for the requested API name, version and URI',
).dict()
- delete_result = endpoint_collection.delete_one({
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_uri': endpoint_uri,
- 'endpoint_method': endpoint_method
- })
+ delete_result = endpoint_collection.delete_one(
+ {
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_uri': endpoint_uri,
+ 'endpoint_method': endpoint_method,
+ }
+ )
if not delete_result.acknowledged:
logger.error(request_id + ' | Endpoint deletion failed with code END009')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='END009',
- error_message='Unable to delete endpoint'
+ error_message='Unable to delete endpoint',
).dict()
doorman_cache.delete_cache('endpoint_cache', cache_key)
try:
@@ -253,10 +288,8 @@ class EndpointService:
logger.info(request_id + ' | Endpoint deletion successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='Endpoint deleted successfully'
+ response_headers={'request_id': request_id},
+ message='Endpoint deleted successfully',
).dict()
@staticmethod
@@ -265,30 +298,34 @@ class EndpointService:
Get an endpoint by API name, version and URI.
"""
logger.info(request_id + ' | Getting: ' + api_name + ' ' + api_version + ' ' + endpoint_uri)
- endpoint = doorman_cache.get_cache('endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}')
+ endpoint = doorman_cache.get_cache(
+ 'endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}'
+ )
if not endpoint:
- endpoint = endpoint_collection.find_one({
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_uri': endpoint_uri,
- 'endpoint_method': endpoint_method
- })
+ endpoint = endpoint_collection.find_one(
+ {
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_uri': endpoint_uri,
+ 'endpoint_method': endpoint_method,
+ }
+ )
if not endpoint:
logger.error(request_id + ' | Endpoint retrieval failed with code END004')
return ResponseModel(
status_code=400,
error_code='END004',
- error_message='Endpoint does not exist for the requested API name, version and URI'
+ error_message='Endpoint does not exist for the requested API name, version and URI',
).dict()
- if endpoint.get('_id'): del endpoint['_id']
- doorman_cache.set_cache('endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}', endpoint)
+ if endpoint.get('_id'):
+ del endpoint['_id']
+ doorman_cache.set_cache(
+ 'endpoint_cache', f'{api_name}/{api_version}/{endpoint_uri}', endpoint
+ )
if '_id' in endpoint:
del endpoint['_id']
logger.info(request_id + ' | Endpoint retrieval successful')
- return ResponseModel(
- status_code=200,
- response=endpoint
- ).dict()
+ return ResponseModel(status_code=200, response=endpoint).dict()
@staticmethod
async def get_endpoints_by_name_version(api_name, api_version, request_id):
@@ -296,32 +333,25 @@ class EndpointService:
Get all endpoints by API name and version.
"""
logger.info(request_id + ' | Getting: ' + api_name + ' ' + api_version)
- cursor = endpoint_collection.find({
- 'api_name': api_name,
- 'api_version': api_version
- })
+ cursor = endpoint_collection.find({'api_name': api_name, 'api_version': api_version})
try:
endpoints = list(cursor)
except Exception:
endpoints = await cursor.to_list(length=None)
for endpoint in endpoints:
- if '_id' in endpoint: del endpoint['_id']
+ if '_id' in endpoint:
+ del endpoint['_id']
if not endpoints:
logger.error(request_id + ' | Endpoint retrieval failed with code END005')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='END005',
- error_message='No endpoints found for the requested API name and version'
+ error_message='No endpoints found for the requested API name and version',
).dict()
logger.info(request_id + ' | Endpoint retrieval successful')
- return ResponseModel(
- status_code=200,
- response={'endpoints': endpoints}
- ).dict()
+ return ResponseModel(status_code=200, response={'endpoints': endpoints}).dict()
@staticmethod
async def create_endpoint_validation(data: CreateEndpointValidationModel, request_id):
@@ -332,32 +362,24 @@ class EndpointService:
if not data.endpoint_id:
logger.error(request_id + ' | Endpoint ID is required')
return ResponseModel(
- status_code=400,
- error_code='END013',
- error_message='Endpoint ID is required'
+ status_code=400, error_code='END013', error_message='Endpoint ID is required'
).dict()
if not data.validation_schema:
logger.error(request_id + ' | Validation schema is required')
return ResponseModel(
- status_code=400,
- error_code='END014',
- error_message='Validation schema is required'
+ status_code=400, error_code='END014', error_message='Validation schema is required'
).dict()
if doorman_cache.get_cache('endpoint_validation_cache', data.endpoint_id):
logger.error(request_id + ' | Endpoint validation already exists')
return ResponseModel(
status_code=400,
error_code='END017',
- error_message='Endpoint validation already exists'
+ error_message='Endpoint validation already exists',
).dict()
- if not endpoint_collection.find_one({
- 'endpoint_id': data.endpoint_id
- }):
+ if not endpoint_collection.find_one({'endpoint_id': data.endpoint_id}):
logger.error(request_id + ' | Endpoint does not exist')
return ResponseModel(
- status_code=400,
- error_code='END015',
- error_message='Endpoint does not exist'
+ status_code=400, error_code='END015', error_message='Endpoint does not exist'
).dict()
validation_dict = data.dict()
insert_result = endpoint_validation_collection.insert_one(validation_dict)
@@ -366,13 +388,12 @@ class EndpointService:
return ResponseModel(
status_code=400,
error_code='END016',
- error_message='Unable to create endpoint validation'
+ error_message='Unable to create endpoint validation',
).dict()
logger.info(request_id + ' | Endpoint validation created successfully')
doorman_cache.set_cache('endpoint_validation_cache', f'{data.endpoint_id}', validation_dict)
return ResponseModel(
- status_code=201,
- message='Endpoint validation created successfully'
+ status_code=201, message='Endpoint validation created successfully'
).dict()
@staticmethod
@@ -383,21 +404,18 @@ class EndpointService:
logger.info(request_id + ' | Getting endpoint validation: ' + endpoint_id)
validation = doorman_cache.get_cache('endpoint_validation_cache', endpoint_id)
if not validation:
- validation = endpoint_validation_collection.find_one({
- 'endpoint_id': endpoint_id
- })
+ validation = endpoint_validation_collection.find_one({'endpoint_id': endpoint_id})
if not validation:
- logger.error(request_id + ' | Endpoint validation retrieval failed with code END018')
+ logger.error(
+ request_id + ' | Endpoint validation retrieval failed with code END018'
+ )
return ResponseModel(
status_code=400,
error_code='END018',
- error_message='Endpoint validation does not exist'
+ error_message='Endpoint validation does not exist',
).dict()
logger.info(request_id + ' | Endpoint validation retrieval successful')
- return ResponseModel(
- status_code=200,
- response=validation
- ).dict()
+ return ResponseModel(status_code=200, response=validation).dict()
@staticmethod
async def delete_endpoint_validation(endpoint_id, request_id):
@@ -405,24 +423,23 @@ class EndpointService:
Delete an endpoint validation by endpoint ID.
"""
logger.info(request_id + ' | Deleting endpoint validation: ' + endpoint_id)
- delete_result = endpoint_validation_collection.delete_one({
- 'endpoint_id': endpoint_id
- })
+ delete_result = endpoint_validation_collection.delete_one({'endpoint_id': endpoint_id})
if not delete_result.acknowledged:
logger.error(request_id + ' | Endpoint validation deletion failed with code END019')
return ResponseModel(
status_code=400,
error_code='END019',
- error_message='Unable to delete endpoint validation'
+ error_message='Unable to delete endpoint validation',
).dict()
logger.info(request_id + ' | Endpoint validation deletion successful')
return ResponseModel(
- status_code=200,
- message='Endpoint validation deleted successfully'
+ status_code=200, message='Endpoint validation deleted successfully'
).dict()
@staticmethod
- async def update_endpoint_validation(endpoint_id, data: UpdateEndpointValidationModel, request_id):
+ async def update_endpoint_validation(
+ endpoint_id, data: UpdateEndpointValidationModel, request_id
+ ):
"""
Update an endpoint validation by endpoint ID.
"""
@@ -430,39 +447,32 @@ class EndpointService:
if not data.validation_enabled:
logger.error(request_id + ' | Validation enabled is required')
return ResponseModel(
- status_code=400,
- error_code='END020',
- error_message='Validation enabled is required'
+ status_code=400, error_code='END020', error_message='Validation enabled is required'
).dict()
if not data.validation_schema:
logger.error(request_id + ' | Validation schema is required')
return ResponseModel(
- status_code=400,
- error_code='END021',
- error_message='Validation schema is required'
+ status_code=400, error_code='END021', error_message='Validation schema is required'
).dict()
- if not endpoint_collection.find_one({
- 'endpoint_id': endpoint_id
- }):
+ if not endpoint_collection.find_one({'endpoint_id': endpoint_id}):
logger.error(request_id + ' | Endpoint does not exist')
return ResponseModel(
- status_code=400,
- error_code='END022',
- error_message='Endpoint does not exist'
+ status_code=400, error_code='END022', error_message='Endpoint does not exist'
).dict()
- update_result = endpoint_validation_collection.update_one({
- 'endpoint_id': endpoint_id
- }, {
- '$set': {
- 'validation_enabled': data.validation_enabled,
- 'validation_schema': data.validation_schema
- }
- })
+ update_result = endpoint_validation_collection.update_one(
+ {'endpoint_id': endpoint_id},
+ {
+ '$set': {
+ 'validation_enabled': data.validation_enabled,
+ 'validation_schema': data.validation_schema,
+ }
+ },
+ )
if not update_result.acknowledged:
logger.error(request_id + ' | Endpoint validation update failed with code END023')
return ResponseModel(
status_code=400,
error_code='END023',
- error_message='Unable to update endpoint validation'
+ error_message='Unable to update endpoint validation',
).dict()
logger.info(request_id + ' | Endpoint validation updated successfully')
diff --git a/backend-services/services/gateway_service.py b/backend-services/services/gateway_service.py
index 3fdad6f..b4687d6 100644
--- a/backend-services/services/gateway_service.py
+++ b/backend-services/services/gateway_service.py
@@ -4,56 +4,59 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
+import asyncio
+import importlib
+import json
+import logging
import os
import random
-import json
-import sys
-import xml.etree.ElementTree as ET
-import logging
import re
-import time
-import httpx
-from typing import Dict
-import grpc
-import asyncio
-from google.protobuf.json_format import MessageToDict
-import importlib
import string
+import sys
+import time
+import xml.etree.ElementTree as ET
from pathlib import Path
+import grpc
+import httpx
+from google.protobuf.json_format import MessageToDict
+
try:
from gql import Client as _GqlClient
+
def gql(q):
return q
except Exception:
+
class _GqlClient:
def __init__(self, *args, **kwargs):
pass
+
def gql(q):
return q
+
Client = _GqlClient
from models.response_model import ResponseModel
-from utils import api_util, routing_util
-from utils import credit_util
-from utils.gateway_utils import get_headers
+from utils import api_util, credit_util, routing_util
from utils.doorman_cache_util import doorman_cache
+from utils.gateway_utils import get_headers
+from utils.http_client import CircuitOpenError, request_with_resilience
from utils.validation_util import validation_util
-from utils.http_client import request_with_resilience, CircuitOpenError
logging.getLogger('gql').setLevel(logging.WARNING)
logger = logging.getLogger('doorman.gateway')
-class GatewayService:
+class GatewayService:
timeout = httpx.Timeout(
- connect=float(os.getenv('HTTP_CONNECT_TIMEOUT', 5.0)),
- read=float(os.getenv('HTTP_READ_TIMEOUT', 30.0)),
- write=float(os.getenv('HTTP_WRITE_TIMEOUT', 30.0)),
- pool=float(os.getenv('HTTP_TIMEOUT', 30.0))
- )
+ connect=float(os.getenv('HTTP_CONNECT_TIMEOUT', 5.0)),
+ read=float(os.getenv('HTTP_READ_TIMEOUT', 30.0)),
+ write=float(os.getenv('HTTP_WRITE_TIMEOUT', 30.0)),
+ pool=float(os.getenv('HTTP_TIMEOUT', 30.0)),
+ )
_http_client: httpx.AsyncClient | None = None
@staticmethod
@@ -77,7 +80,9 @@ class GatewayService:
expiry = float(os.getenv('HTTP_KEEPALIVE_EXPIRY', 30.0))
except Exception:
expiry = 30.0
- return httpx.Limits(max_connections=max_conns, max_keepalive_connections=max_keep, keepalive_expiry=expiry)
+ return httpx.Limits(
+ max_connections=max_conns, max_keepalive_connections=max_keep, keepalive_expiry=expiry
+ )
@classmethod
def get_http_client(cls) -> httpx.AsyncClient:
@@ -91,13 +96,13 @@ class GatewayService:
cls._http_client = httpx.AsyncClient(
timeout=cls.timeout,
limits=cls._build_limits(),
- http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true')
+ http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true'),
)
return cls._http_client
return httpx.AsyncClient(
timeout=cls.timeout,
limits=cls._build_limits(),
- http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true')
+ http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true'),
)
@classmethod
@@ -111,16 +116,18 @@ class GatewayService:
cls._http_client = None
def error_response(request_id, code, message, status=404):
- logger.error(f'{request_id} | REST gateway failed with code {code}')
- return ResponseModel(
- status_code=status,
- response_headers={'request_id': request_id},
- error_code=code,
- error_message=message
- ).dict()
+ logger.error(f'{request_id} | REST gateway failed with code {code}')
+ return ResponseModel(
+ status_code=status,
+ response_headers={'request_id': request_id},
+ error_code=code,
+ error_message=message,
+ ).dict()
@staticmethod
- def _compute_api_cors_headers(api: dict, origin: str | None, req_method: str | None, req_headers: str | None):
+ def _compute_api_cors_headers(
+ api: dict, origin: str | None, req_method: str | None, req_headers: str | None
+ ):
try:
origin = (origin or '').strip()
req_method = (req_method or '').strip().upper()
@@ -128,13 +135,21 @@ class GatewayService:
_ao = api.get('api_cors_allow_origins', None)
# None => default '*', empty list => disallow all
- allow_origins = (['*'] if _ao is None else list(_ao))
+ allow_origins = ['*'] if _ao is None else list(_ao)
_am = api.get('api_cors_allow_methods', None)
- allow_methods = [m.strip().upper() for m in (_am if _am is not None else ['GET','POST','PUT','DELETE','PATCH','HEAD','OPTIONS']) if m]
+ allow_methods = [
+ m.strip().upper()
+ for m in (
+ _am
+ if _am is not None
+ else ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS']
+ )
+ if m
+ ]
if 'OPTIONS' not in allow_methods:
allow_methods.append('OPTIONS')
_ah = api.get('api_cors_allow_headers', None)
- allow_headers = (_ah if _ah is not None else ['*'])
+ allow_headers = _ah if _ah is not None else ['*']
allow_credentials = bool(api.get('api_cors_allow_credentials'))
expose_headers = api.get('api_cors_expose_headers') or []
@@ -158,7 +173,10 @@ class GatewayService:
o_scheme, o_host = 'https', origin
if o_scheme != scheme:
continue
- if o_host.endswith(host_suffix) and o_host.count('.') >= host_suffix.count('.') + 1:
+ if (
+ o_host.endswith(host_suffix)
+ and o_host.count('.') >= host_suffix.count('.') + 1
+ ):
origin_allowed = True
break
except Exception:
@@ -205,7 +223,10 @@ class GatewayService:
ctype = ctype_raw.split(';', 1)[0].strip().lower()
body = getattr(response, 'content', b'')
- if ctype in ('application/json', 'application/graphql+json') or 'application/graphql' in ctype:
+ if (
+ ctype in ('application/json', 'application/graphql+json')
+ or 'application/graphql' in ctype
+ ):
return json.loads(body)
if ctype in ('application/xml', 'text/xml'):
@@ -278,10 +299,7 @@ class GatewayService:
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
logger.warning(f'{request_id} | Credit deduction failed for user {username}')
return GatewayService.error_response(
- request_id,
- 'GTW008',
- 'User does not have any credits',
- status=401
+ request_id, 'GTW008', 'User does not have any credits', status=401
)
return None
@@ -315,8 +333,7 @@ class GatewayService:
if username and not bool(api.get('api_public')):
user_specific_api_key = await credit_util.get_user_api_key(
- api.get('api_credit_group'),
- username
+ api.get('api_credit_group'), username
)
if user_specific_api_key:
headers[header_name] = user_specific_api_key
@@ -349,8 +366,7 @@ class GatewayService:
continue
sanitized_key = ''.join(
- c if c.isalnum() or c in ('-', '_', '.') else '-'
- for c in key
+ c if c.isalnum() or c in ('-', '_', '.') else '-' for c in key
)
if not sanitized_key:
@@ -369,7 +385,7 @@ class GatewayService:
return metadata_list
- _IDENT_ALLOWED = set(string.ascii_letters + string.digits + "_")
+ _IDENT_ALLOWED = set(string.ascii_letters + string.digits + '_')
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
@staticmethod
@@ -389,7 +405,7 @@ class GatewayService:
name = name.strip()
if not name or len(name) > max_len:
return False
- if name[0] not in string.ascii_letters + "_":
+ if name[0] not in string.ascii_letters + '_':
return False
for ch in name:
if ch not in GatewayService._IDENT_ALLOWED:
@@ -410,9 +426,9 @@ class GatewayService:
if not pkg:
return None
pkg = str(pkg).strip()
- if "/" in pkg or "\\" in pkg or ".." in pkg:
+ if '/' in pkg or '\\' in pkg or '..' in pkg:
return None
- parts = pkg.split(".") if "." in pkg else [pkg]
+ parts = pkg.split('.') if '.' in pkg else [pkg]
if any(not GatewayService._is_valid_identifier(p) for p in parts if p is not None):
return None
return pkg
@@ -424,19 +440,24 @@ class GatewayService:
return None
try:
method_fq = str(method_fq).strip()
- if "." not in method_fq:
+ if '.' not in method_fq:
return None
- service, method = method_fq.split(".", 1)
+ service, method = method_fq.split('.', 1)
service = service.strip()
method = method.strip()
- if not (GatewayService._is_valid_identifier(service) and GatewayService._is_valid_identifier(method)):
+ if not (
+ GatewayService._is_valid_identifier(service)
+ and GatewayService._is_valid_identifier(method)
+ ):
return None
return service, method
except Exception:
return None
@staticmethod
- async def rest_gateway(username, request, request_id, start_time, path, url=None, method=None, retry=0):
+ async def rest_gateway(
+ username, request, request_id, start_time, path, url=None, method=None, retry=0
+ ):
"""
External gateway.
"""
@@ -444,7 +465,6 @@ class GatewayService:
current_time = backend_end_time = None
try:
if not url and not method:
-
parts = [p for p in (path or '').split('/') if p]
api_name_version = ''
endpoint_uri = ''
@@ -454,22 +474,38 @@ class GatewayService:
api_key = doorman_cache.get_cache('api_id_cache', api_name_version)
api = await api_util.get_api(api_key, api_name_version)
if not api:
- return GatewayService.error_response(request_id, 'GTW001', 'API does not exist for the requested name and version')
+ return GatewayService.error_response(
+ request_id,
+ 'GTW001',
+ 'API does not exist for the requested name and version',
+ )
if api.get('active') is False:
- return GatewayService.error_response(request_id, 'GTW012', 'API is disabled', status=403)
+ return GatewayService.error_response(
+ request_id, 'GTW012', 'API is disabled', status=403
+ )
endpoints = await api_util.get_api_endpoints(api.get('api_id'))
if not endpoints:
- return GatewayService.error_response(request_id, 'GTW002', 'No endpoints found for the requested API')
+ return GatewayService.error_response(
+ request_id, 'GTW002', 'No endpoints found for the requested API'
+ )
regex_pattern = re.compile(r'\{[^/]+\}')
match_method = 'GET' if str(request.method).upper() == 'HEAD' else request.method
composite = match_method + '/' + endpoint_uri
- if not any(re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints):
+ if not any(
+ re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints
+ ):
logger.error(f'{endpoints} | REST gateway failed with code GTW003')
- return GatewayService.error_response(request_id, 'GTW003', 'Endpoint does not exist for the requested API')
+ return GatewayService.error_response(
+ request_id, 'GTW003', 'Endpoint does not exist for the requested API'
+ )
client_key = request.headers.get('client-key')
- server = await routing_util.pick_upstream_server(api, request.method, endpoint_uri, client_key)
+ server = await routing_util.pick_upstream_server(
+ api, request.method, endpoint_uri, client_key
+ )
if not server:
- return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured')
+ return GatewayService.error_response(
+ request_id, 'GTW001', 'No upstream servers configured'
+ )
logger.info(f'{request_id} | REST gateway to: {server}')
url = server.rstrip('/') + '/' + endpoint_uri.lstrip('/')
method = request.method.upper()
@@ -477,7 +513,9 @@ class GatewayService:
if api.get('api_credits_enabled') and username and not bool(api.get('api_public')):
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
- return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
+ return GatewayService.error_response(
+ request_id, 'GTW008', 'User does not have any credits', status=401
+ )
else:
try:
parts = [p for p in (path or '').split('/') if p]
@@ -501,12 +539,16 @@ class GatewayService:
headers['X-User-Email'] = str(username)
headers['X-Doorman-User'] = str(username)
if api and api.get('api_credits_enabled'):
- ai_token_headers = await credit_util.get_credit_api_header(api.get('api_credit_group'))
+ ai_token_headers = await credit_util.get_credit_api_header(
+ api.get('api_credit_group')
+ )
if ai_token_headers:
headers[ai_token_headers[0]] = ai_token_headers[1]
if username and not bool(api.get('api_public')):
- user_specific_api_key = await credit_util.get_user_api_key(api.get('api_credit_group'), username)
+ user_specific_api_key = await credit_util.get_user_api_key(
+ api.get('api_credit_group'), username
+ )
if user_specific_api_key:
headers[ai_token_headers[0]] = user_specific_api_key
content_type = request.headers.get('Content-Type', '').upper()
@@ -516,11 +558,17 @@ class GatewayService:
swap_from = api.get('api_authorization_field_swap')
source_val = None
if swap_from:
- for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()):
+ for key_variant in (
+ swap_from,
+ str(swap_from).lower(),
+ str(swap_from).title(),
+ ):
if key_variant in headers:
source_val = headers.get(key_variant)
break
- orig_auth = request.headers.get('Authorization') or request.headers.get('authorization')
+ orig_auth = request.headers.get('Authorization') or request.headers.get(
+ 'authorization'
+ )
if source_val is not None and str(source_val).strip() != '':
headers['Authorization'] = source_val
elif orig_auth is not None and str(orig_auth).strip() != '':
@@ -530,7 +578,11 @@ class GatewayService:
try:
lookup_method = 'GET' if str(method).upper() == 'HEAD' else method
- endpoint_doc = await api_util.get_endpoint(api, lookup_method, '/' + endpoint_uri.lstrip('/')) if api else None
+ endpoint_doc = (
+ await api_util.get_endpoint(api, lookup_method, '/' + endpoint_uri.lstrip('/'))
+ if api
+ else None
+ )
endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None
if endpoint_id:
if 'JSON' in content_type:
@@ -546,24 +598,36 @@ class GatewayService:
try:
if method == 'GET':
http_response = await request_with_resilience(
- client, 'GET', url,
+ client,
+ 'GET',
+ url,
api_key=api.get('api_path') if api else (api_name_version or '/api/rest'),
- headers=headers, params=query_params,
+ headers=headers,
+ params=query_params,
retries=retry,
api_config=api,
)
elif method == 'HEAD':
http_response = await request_with_resilience(
- client, 'HEAD', url,
+ client,
+ 'HEAD',
+ url,
api_key=api.get('api_path') if api else (api_name_version or '/api/rest'),
- headers=headers, params=query_params,
+ headers=headers,
+ params=query_params,
retries=retry,
api_config=api,
)
elif method in ('POST', 'PUT', 'DELETE', 'PATCH'):
- cl_header = request.headers.get('content-length') or request.headers.get('Content-Length')
+ cl_header = request.headers.get('content-length') or request.headers.get(
+ 'Content-Length'
+ )
try:
- content_length = int(cl_header) if cl_header is not None and str(cl_header).strip() != '' else 0
+ content_length = (
+ int(cl_header)
+ if cl_header is not None and str(cl_header).strip() != ''
+ else 0
+ )
except Exception:
content_length = 0
@@ -571,31 +635,50 @@ class GatewayService:
if 'JSON' in content_type:
body = await request.json()
http_response = await request_with_resilience(
- client, method, url,
- api_key=api.get('api_path') if api else (api_name_version or '/api/rest'),
- headers=headers, params=query_params, json=body,
+ client,
+ method,
+ url,
+ api_key=api.get('api_path')
+ if api
+ else (api_name_version or '/api/rest'),
+ headers=headers,
+ params=query_params,
+ json=body,
retries=retry,
api_config=api,
)
else:
body = await request.body()
http_response = await request_with_resilience(
- client, method, url,
- api_key=api.get('api_path') if api else (api_name_version or '/api/rest'),
- headers=headers, params=query_params, content=body,
+ client,
+ method,
+ url,
+ api_key=api.get('api_path')
+ if api
+ else (api_name_version or '/api/rest'),
+ headers=headers,
+ params=query_params,
+ content=body,
retries=retry,
api_config=api,
)
else:
http_response = await request_with_resilience(
- client, method, url,
- api_key=api.get('api_path') if api else (api_name_version or '/api/rest'),
- headers=headers, params=query_params,
+ client,
+ method,
+ url,
+ api_key=api.get('api_path')
+ if api
+ else (api_name_version or '/api/rest'),
+ headers=headers,
+ params=query_params,
retries=retry,
api_config=api,
)
else:
- return GatewayService.error_response(request_id, 'GTW004', 'Method not supported', status=405)
+ return GatewayService.error_response(
+ request_id, 'GTW004', 'Method not supported', status=405
+ )
finally:
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() == 'false':
try:
@@ -615,13 +698,15 @@ class GatewayService:
status_code=500,
response_headers={'request_id': request_id},
error_code='GTW006',
- error_message='Malformed JSON from upstream'
+ error_message='Malformed JSON from upstream',
).dict()
else:
response_content = http_response.text
backend_end_time = time.time() * 1000
if http_response.status_code == 404:
- return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
+ return GatewayService.error_response(
+ request_id, 'GTW005', 'Endpoint does not exist in backend service'
+ )
logger.info(f'{request_id} | REST gateway status code: {http_response.status_code}')
response_headers = {'request_id': request_id}
allowed_lower = {h.lower() for h in (allowed_headers or [])}
@@ -645,29 +730,33 @@ class GatewayService:
return ResponseModel(
status_code=http_response.status_code,
response_headers=response_headers,
- response=response_content
+ response=response_content,
).dict()
except CircuitOpenError:
return ResponseModel(
status_code=503,
response_headers={'request_id': request_id},
error_code='GTW999',
- error_message='Upstream circuit open'
+ error_message='Upstream circuit open',
).dict()
except httpx.TimeoutException:
try:
- metrics_store.record_upstream_timeout('rest:' + (api.get('api_path') if api else (api_name_version or '/api/rest')))
+ metrics_store.record_upstream_timeout(
+ 'rest:' + (api.get('api_path') if api else (api_name_version or '/api/rest'))
+ )
except Exception:
pass
return ResponseModel(
status_code=504,
response_headers={'request_id': request_id},
error_code='GTW010',
- error_message='Gateway timeout'
+ error_message='Gateway timeout',
).dict()
except Exception:
logger.error(f'{request_id} | REST gateway failed with code GTW006', exc_info=True)
- return GatewayService.error_response(request_id, 'GTW006', 'Internal server error', status=500)
+ return GatewayService.error_response(
+ request_id, 'GTW006', 'Internal server error', status=500
+ )
finally:
if current_time:
logger.info(f'{request_id} | Gateway time {current_time - start_time}ms')
@@ -683,7 +772,6 @@ class GatewayService:
current_time = backend_end_time = None
try:
if not url:
-
parts = [p for p in (path or '').split('/') if p]
api_name_version = ''
endpoint_uri = ''
@@ -693,27 +781,45 @@ class GatewayService:
api_key = doorman_cache.get_cache('api_id_cache', api_name_version)
api = await api_util.get_api(api_key, api_name_version)
if not api:
- return GatewayService.error_response(request_id, 'GTW001', 'API does not exist for the requested name and version')
+ return GatewayService.error_response(
+ request_id,
+ 'GTW001',
+ 'API does not exist for the requested name and version',
+ )
if api.get('active') is False:
- return GatewayService.error_response(request_id, 'GTW012', 'API is disabled', status=403)
+ return GatewayService.error_response(
+ request_id, 'GTW012', 'API is disabled', status=403
+ )
endpoints = await api_util.get_api_endpoints(api.get('api_id'))
logger.info(f'{request_id} | SOAP gateway endpoints: {endpoints}')
if not endpoints:
- return GatewayService.error_response(request_id, 'GTW002', 'No endpoints found for the requested API')
+ return GatewayService.error_response(
+ request_id, 'GTW002', 'No endpoints found for the requested API'
+ )
regex_pattern = re.compile(r'\{[^/]+\}')
composite = 'POST/' + endpoint_uri
- if not any(re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints):
- return GatewayService.error_response(request_id, 'GTW003', 'Endpoint does not exist for the requested API')
+ if not any(
+ re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints
+ ):
+ return GatewayService.error_response(
+ request_id, 'GTW003', 'Endpoint does not exist for the requested API'
+ )
client_key = request.headers.get('client-key')
- server = await routing_util.pick_upstream_server(api, 'POST', endpoint_uri, client_key)
+ server = await routing_util.pick_upstream_server(
+ api, 'POST', endpoint_uri, client_key
+ )
if not server:
- return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured')
+ return GatewayService.error_response(
+ request_id, 'GTW001', 'No upstream servers configured'
+ )
url = server.rstrip('/') + '/' + endpoint_uri.lstrip('/')
logger.info(f'{request_id} | SOAP gateway to: {url}')
retry = api.get('api_allowed_retry_count') or 0
if api.get('api_credits_enabled') and username and not bool(api.get('api_public')):
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
- return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
+ return GatewayService.error_response(
+ request_id, 'GTW008', 'User does not have any credits', status=401
+ )
else:
try:
parts = [p for p in (path or '').split('/') if p]
@@ -748,11 +854,17 @@ class GatewayService:
swap_from = api.get('api_authorization_field_swap')
source_val = None
if swap_from:
- for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()):
+ for key_variant in (
+ swap_from,
+ str(swap_from).lower(),
+ str(swap_from).title(),
+ ):
if key_variant in headers:
source_val = headers.get(key_variant)
break
- orig_auth = request.headers.get('Authorization') or request.headers.get('authorization')
+ orig_auth = request.headers.get('Authorization') or request.headers.get(
+ 'authorization'
+ )
if source_val is not None and str(source_val).strip() != '':
headers['Authorization'] = source_val
elif orig_auth is not None and str(orig_auth).strip() != '':
@@ -761,7 +873,11 @@ class GatewayService:
pass
try:
- endpoint_doc = await api_util.get_endpoint(api, 'POST', '/' + endpoint_uri.lstrip('/')) if api else None
+ endpoint_doc = (
+ await api_util.get_endpoint(api, 'POST', '/' + endpoint_uri.lstrip('/'))
+ if api
+ else None
+ )
endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None
if endpoint_id:
await validation_util.validate_soap_request(endpoint_id, envelope)
@@ -771,9 +887,13 @@ class GatewayService:
client = GatewayService.get_http_client()
try:
http_response = await request_with_resilience(
- client, 'POST', url,
+ client,
+ 'POST',
+ url,
api_key=api.get('api_path') if api else (api_name_version or '/api/soap'),
- headers=headers, params=query_params, content=envelope,
+ headers=headers,
+ params=query_params,
+ content=envelope,
retries=retry,
api_config=api,
)
@@ -787,7 +907,9 @@ class GatewayService:
logger.info(f'{request_id} | SOAP gateway response: {response_content}')
backend_end_time = time.time() * 1000
if http_response.status_code == 404:
- return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
+ return GatewayService.error_response(
+ request_id, 'GTW005', 'Endpoint does not exist in backend service'
+ )
logger.info(f'{request_id} | SOAP gateway status code: {http_response.status_code}')
response_headers = {'request_id': request_id}
allowed_lower = {h.lower() for h in (allowed_headers or [])}
@@ -811,29 +933,33 @@ class GatewayService:
return ResponseModel(
status_code=http_response.status_code,
response_headers=response_headers,
- response=response_content
+ response=response_content,
).dict()
except CircuitOpenError:
return ResponseModel(
status_code=503,
response_headers={'request_id': request_id},
error_code='GTW999',
- error_message='Upstream circuit open'
+ error_message='Upstream circuit open',
).dict()
except httpx.TimeoutException:
try:
- metrics_store.record_upstream_timeout('soap:' + (api.get('api_path') if api else '/api/soap'))
+ metrics_store.record_upstream_timeout(
+ 'soap:' + (api.get('api_path') if api else '/api/soap')
+ )
except Exception:
pass
return ResponseModel(
status_code=504,
response_headers={'request_id': request_id},
error_code='GTW010',
- error_message='Gateway timeout'
+ error_message='Gateway timeout',
).dict()
except Exception:
logger.error(f'{request_id} | SOAP gateway failed with code GTW006')
- return GatewayService.error_response(request_id, 'GTW006', 'Internal server error', status=500)
+ return GatewayService.error_response(
+ request_id, 'GTW006', 'Internal server error', status=500
+ )
finally:
if current_time:
logger.info(f'{request_id} | Gateway time {current_time - start_time}ms')
@@ -853,15 +979,21 @@ class GatewayService:
if not api:
api = await api_util.get_api(None, api_path)
if not api:
- logger.error(f'{request_id} | API not found: {api_path}')
- return GatewayService.error_response(request_id, 'GTW001', f'API does not exist: {api_path}')
+ logger.error(f'{request_id} | API not found: {api_path}')
+ return GatewayService.error_response(
+ request_id, 'GTW001', f'API does not exist: {api_path}'
+ )
if api.get('active') is False:
- return GatewayService.error_response(request_id, 'GTW012', 'API is disabled', status=403)
+ return GatewayService.error_response(
+ request_id, 'GTW012', 'API is disabled', status=403
+ )
doorman_cache.set_cache('api_cache', api_path, api)
retry = api.get('api_allowed_retry_count') or 0
if api.get('api_credits_enabled') and username and not bool(api.get('api_public')):
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
- return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
+ return GatewayService.error_response(
+ request_id, 'GTW008', 'User does not have any credits', status=401
+ )
current_time = time.time() * 1000
allowed_headers = api.get('api_allowed_headers') or []
headers = await get_headers(request, allowed_headers)
@@ -869,11 +1001,15 @@ class GatewayService:
headers['Content-Type'] = 'application/json'
headers['Accept'] = 'application/json'
if api.get('api_credits_enabled'):
- ai_token_headers = await credit_util.get_credit_api_header(api.get('api_credit_group'))
+ ai_token_headers = await credit_util.get_credit_api_header(
+ api.get('api_credit_group')
+ )
if ai_token_headers:
headers[ai_token_headers[0]] = ai_token_headers[1]
if username and not bool(api.get('api_public')):
- user_specific_api_key = await credit_util.get_user_api_key(api.get('api_credit_group'), username)
+ user_specific_api_key = await credit_util.get_user_api_key(
+ api.get('api_credit_group'), username
+ )
if user_specific_api_key:
headers[ai_token_headers[0]] = user_specific_api_key
if api.get('api_authorization_field_swap'):
@@ -881,11 +1017,17 @@ class GatewayService:
swap_from = api.get('api_authorization_field_swap')
source_val = None
if swap_from:
- for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()):
+ for key_variant in (
+ swap_from,
+ str(swap_from).lower(),
+ str(swap_from).title(),
+ ):
if key_variant in headers:
source_val = headers.get(key_variant)
break
- orig_auth = request.headers.get('Authorization') or request.headers.get('authorization')
+ orig_auth = request.headers.get('Authorization') or request.headers.get(
+ 'authorization'
+ )
if source_val is not None and str(source_val).strip() != '':
headers['Authorization'] = source_val
elif orig_auth is not None and str(orig_auth).strip() != '':
@@ -910,19 +1052,27 @@ class GatewayService:
async with Client(transport=None, fetch_schema_from_transport=False) as session:
result = await session.execute(gql(query), variable_values=variables)
except Exception as _e:
- logger.debug(f'{request_id} | GraphQL Client execution failed; falling back to HTTP: {_e}')
+ logger.debug(
+ f'{request_id} | GraphQL Client execution failed; falling back to HTTP: {_e}'
+ )
if result is None:
client_key = request.headers.get('client-key')
- server = await routing_util.pick_upstream_server(api, 'POST', '/graphql', client_key)
+ server = await routing_util.pick_upstream_server(
+ api, 'POST', '/graphql', client_key
+ )
if not server:
logger.error(f'{request_id} | No upstream servers configured for {api_path}')
- return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured')
+ return GatewayService.error_response(
+ request_id, 'GTW001', 'No upstream servers configured'
+ )
url = server.rstrip('/') + '/graphql'
client = GatewayService.get_http_client()
try:
http_resp = await request_with_resilience(
- client, 'POST', url,
+ client,
+ 'POST',
+ url,
api_key=api_path,
headers=headers,
json={'query': query, 'variables': variables},
@@ -931,24 +1081,29 @@ class GatewayService:
)
except AttributeError:
http_resp = await client.post(
- url,
- json={'query': query, 'variables': variables},
- headers=headers,
+ url, json={'query': query, 'variables': variables}, headers=headers
)
try:
data = http_resp.json()
except Exception as je:
data = {
- 'errors': [{
- 'message': f'Invalid JSON from upstream: {str(je)}',
- 'extensions': {'code': 'BAD_RESPONSE'}
- }]}
+ 'errors': [
+ {
+ 'message': f'Invalid JSON from upstream: {str(je)}',
+ 'extensions': {'code': 'BAD_RESPONSE'},
+ }
+ ]
+ }
status = getattr(http_resp, 'status_code', 200)
if status != 200 and 'errors' not in data:
- data = {'errors': [{
- 'message': data.get('message') or f'HTTP {status}',
- 'extensions': {'code': f'HTTP_{status}'}
- }]}
+ data = {
+ 'errors': [
+ {
+ 'message': data.get('message') or f'HTTP {status}',
+ 'extensions': {'code': f'HTTP_{status}'},
+ }
+ ]
+ }
result = data
backend_end_time = time.time() * 1000
@@ -972,24 +1127,28 @@ class GatewayService:
response_headers['X-Backend-Time'] = str(int(backend_end_time - current_time))
except Exception:
pass
- return ResponseModel(status_code=200, response_headers=response_headers, response=result).dict()
+ return ResponseModel(
+ status_code=200, response_headers=response_headers, response=result
+ ).dict()
except CircuitOpenError:
return ResponseModel(
status_code=503,
response_headers={'request_id': request_id},
error_code='GTW999',
- error_message='Upstream circuit open'
+ error_message='Upstream circuit open',
).dict()
except httpx.TimeoutException:
try:
- metrics_store.record_upstream_timeout('graphql:' + (api.get('api_path') if api else '/api/graphql'))
+ metrics_store.record_upstream_timeout(
+ 'graphql:' + (api.get('api_path') if api else '/api/graphql')
+ )
except Exception:
pass
return ResponseModel(
status_code=504,
response_headers={'request_id': request_id},
error_code='GTW010',
- error_message='Gateway timeout'
+ error_message='Gateway timeout',
).dict()
except Exception as e:
logger.error(f'{request_id} | GraphQL gateway failed with code GTW006: {str(e)}')
@@ -1002,7 +1161,9 @@ class GatewayService:
logger.info(f'{request_id} | Backend time {backend_end_time - current_time}ms')
@staticmethod
- async def grpc_gateway(username, request, request_id, start_time, path, api_name=None, url=None, retry=0):
+ async def grpc_gateway(
+ username, request, request_id, start_time, path, api_name=None, url=None, retry=0
+ ):
logger.info(f'{request_id} | gRPC gateway processing request')
current_time = backend_end_time = None
try:
@@ -1011,7 +1172,9 @@ class GatewayService:
path_parts = path.strip('/').split('/')
if len(path_parts) < 1:
logger.error(f'{request_id} | Invalid API path format: {path}')
- return GatewayService.error_response(request_id, 'GTW001', 'Invalid API path format', status=404)
+ return GatewayService.error_response(
+ request_id, 'GTW001', 'Invalid API path format', status=404
+ )
api_name = path_parts[-1]
api_version = request.headers.get('X-API-Version', 'v1')
api_path = f'{api_name}/{api_version}'
@@ -1021,19 +1184,33 @@ class GatewayService:
body = await request.json()
if not isinstance(body, dict):
logger.error(f'{request_id} | Invalid request body format')
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid request body format', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', 'Invalid request body format', status=400
+ )
except json.JSONDecodeError:
logger.error(f'{request_id} | Invalid JSON in request body')
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid JSON in request body', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', 'Invalid JSON in request body', status=400
+ )
parsed = GatewayService._parse_and_validate_method(body.get('method'))
if not parsed:
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW011',
+ 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.',
+ status=400,
+ )
_service_name_preview, _method_name_preview = parsed
pkg_override_raw = (body.get('package') or '').strip()
if pkg_override_raw:
if GatewayService._validate_package_name(pkg_override_raw) is None:
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC package. Use letters, digits, underscore only.', status=400)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW011',
+ 'Invalid gRPC package. Use letters, digits, underscore only.',
+ status=400,
+ )
api = doorman_cache.get_cache('api_cache', api_path)
if not api:
@@ -1043,25 +1220,35 @@ class GatewayService:
endpoint_doc = await api_util.get_endpoint(api, 'POST', '/grpc')
endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None
if endpoint_id:
- await validation_util.validate_grpc_request(endpoint_id, body.get('message'))
+ await validation_util.validate_grpc_request(
+ endpoint_id, body.get('message')
+ )
except Exception as e:
- return GatewayService.error_response(request_id, 'GTW011', str(e), status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', str(e), status=400
+ )
api_pkg_raw = None
try:
api_pkg_raw = (api.get('api_grpc_package') or '').strip() if api else None
except Exception:
api_pkg_raw = None
pkg_override = (body.get('package') or '').strip() or None
- api_pkg = GatewayService._validate_package_name(api_pkg_raw) if api_pkg_raw else None
- pkg_override_valid = GatewayService._validate_package_name(pkg_override) if pkg_override else None
+ api_pkg = (
+ GatewayService._validate_package_name(api_pkg_raw) if api_pkg_raw else None
+ )
+ pkg_override_valid = (
+ GatewayService._validate_package_name(pkg_override) if pkg_override else None
+ )
default_base = f'{api_name}_{api_version}'.replace('-', '_')
if not GatewayService._is_valid_identifier(default_base):
- default_base = ''.join(ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base)
- module_base = (api_pkg or pkg_override_valid or default_base)
+ default_base = ''.join(
+ ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base
+ )
+ module_base = api_pkg or pkg_override_valid or default_base
try:
logger.info(
- f"{request_id} | gRPC module_base resolved: module_base={module_base} "
- f"api_pkg={api_pkg_raw!r} pkg_override={pkg_override_raw!r} default_base={default_base}"
+ f'{request_id} | gRPC module_base resolved: module_base={module_base} '
+ f'api_pkg={api_pkg_raw!r} pkg_override={pkg_override_raw!r} default_base={default_base}'
)
except Exception:
pass
@@ -1084,7 +1271,7 @@ class GatewayService:
request_id, 'GTW013', 'gRPC service not allowed', status=403
)
if allowed_methods and isinstance(allowed_methods, list):
- method_fq = f"{service_name}.{method_name}"
+ method_fq = f'{service_name}.{method_name}'
if method_fq not in allowed_methods:
return GatewayService.error_response(
request_id, 'GTW013', 'gRPC method not allowed', status=403
@@ -1100,9 +1287,16 @@ class GatewayService:
proto_path = (proto_dir / proto_rel.with_suffix('.proto')).resolve()
# Validate resolved path stays within project bounds
if not GatewayService._validate_under_base(project_root, proto_path):
- return GatewayService.error_response(request_id, 'GTW012', 'Invalid path for proto resolution', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW012', 'Invalid path for proto resolution', status=400
+ )
if not GatewayService._validate_under_base(proto_dir, proto_path):
- return GatewayService.error_response(request_id, 'GTW012', 'Proto path must be within proto directory', status=400)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW012',
+ 'Proto path must be within proto directory',
+ status=400,
+ )
generated_dir = project_root / 'generated'
gen_dir_str = str(generated_dir)
@@ -1112,7 +1306,9 @@ class GatewayService:
if gen_dir_str not in sys.path:
sys.path.insert(0, gen_dir_str)
try:
- logger.info(f"{request_id} | sys.path updated for gRPC import. project_root={proj_root_str}, generated_dir={gen_dir_str}")
+ logger.info(
+ f'{request_id} | sys.path updated for gRPC import. project_root={proj_root_str}, generated_dir={gen_dir_str}'
+ )
except Exception:
pass
@@ -1129,11 +1325,17 @@ class GatewayService:
gen_pb2_grpc_name = f'generated.{module_base}_pb2_grpc'
pb2 = importlib.import_module(gen_pb2_name)
pb2_grpc = importlib.import_module(gen_pb2_grpc_name)
- logger.info(f"{request_id} | Successfully imported gRPC modules: {pb2.__name__} and {pb2_grpc.__name__}")
+ logger.info(
+ f'{request_id} | Successfully imported gRPC modules: {pb2.__name__} and {pb2_grpc.__name__}'
+ )
except ModuleNotFoundError as mnf_exc:
- logger.warning(f"{request_id} | gRPC modules not found, will attempt proto generation: {str(mnf_exc)}")
+ logger.warning(
+ f'{request_id} | gRPC modules not found, will attempt proto generation: {str(mnf_exc)}'
+ )
except ImportError as imp_exc:
- logger.error(f"{request_id} | ImportError loading gRPC modules (likely broken import in generated file): {str(imp_exc)}")
+ logger.error(
+ f'{request_id} | ImportError loading gRPC modules (likely broken import in generated file): {str(imp_exc)}'
+ )
mod_pb2 = f'{module_base}_pb2'
mod_pb2_grpc = f'{module_base}_pb2_grpc'
if mod_pb2 in sys.modules:
@@ -1144,15 +1346,17 @@ class GatewayService:
request_id,
'GTW012',
f'Failed to import gRPC modules. Proto files may need regeneration. Error: {str(imp_exc)[:100]}',
- status=404
+ status=404,
)
except Exception as import_exc:
- logger.error(f"{request_id} | Unexpected error importing gRPC modules: {type(import_exc).__name__}: {str(import_exc)}")
+ logger.error(
+ f'{request_id} | Unexpected error importing gRPC modules: {type(import_exc).__name__}: {str(import_exc)}'
+ )
return GatewayService.error_response(
request_id,
'GTW012',
f'Unexpected error importing gRPC modules: {type(import_exc).__name__}',
- status=500
+ status=500,
)
if pb2 is None or pb2_grpc is None:
@@ -1161,7 +1365,9 @@ class GatewayService:
try:
pb2_check_path = (generated_dir / (module_base + '_pb2.py')).resolve()
if GatewayService._validate_under_base(generated_dir, pb2_check_path):
- logger.info(f"{request_id} | gRPC generated check: proto_path={proto_path} exists={proto_path.exists()} generated_dir={generated_dir} pb2={module_base}_pb2.py={pb2_check_path.exists()}")
+ logger.info(
+ f'{request_id} | gRPC generated check: proto_path={proto_path} exists={proto_path.exists()} generated_dir={generated_dir} pb2={module_base}_pb2.py={pb2_check_path.exists()}'
+ )
except Exception:
pass
method_fq = body.get('method', '')
@@ -1196,44 +1402,71 @@ class GatewayService:
generated_dir.mkdir(exist_ok=True)
try:
from grpc_tools import protoc as _protoc
- code = _protoc.main([
- 'protoc', f'--proto_path={str(proto_dir)}', f'--python_out={str(generated_dir)}', f'--grpc_python_out={str(generated_dir)}', str(proto_path)
- ])
+
+ code = _protoc.main(
+ [
+ 'protoc',
+ f'--proto_path={str(proto_dir)}',
+ f'--python_out={str(generated_dir)}',
+ f'--grpc_python_out={str(generated_dir)}',
+ str(proto_path),
+ ]
+ )
if code != 0:
raise RuntimeError(f'protoc returned {code}')
init_path = generated_dir / '__init__.py'
if not init_path.exists():
- init_path.write_text('"""Generated gRPC code."""\n', encoding='utf-8')
+ init_path.write_text(
+ '"""Generated gRPC code."""\n', encoding='utf-8'
+ )
except Exception as ge:
logger.error(f'{request_id} | On-demand proto generation failed: {ge}')
if os.getenv('DOORMAN_TEST_MODE', '').lower() == 'true':
pb2 = type('PB2', (), {})
pb2_grpc = type('SVC', (), {})
else:
- return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW012',
+ f'Proto file not found for API: {api_path}',
+ status=404,
+ )
except Exception as ge:
- logger.error(f'{request_id} | Proto file not found and generation skipped: {ge}')
+ logger.error(
+ f'{request_id} | Proto file not found and generation skipped: {ge}'
+ )
if os.getenv('DOORMAN_TEST_MODE', '').lower() != 'true':
- return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW012',
+ f'Proto file not found for API: {api_path}',
+ status=404,
+ )
api = doorman_cache.get_cache('api_cache', api_path)
if not api:
api = await api_util.get_api(None, api_path)
if not api:
logger.error(f'{request_id} | API not found: {api_path}')
- return GatewayService.error_response(request_id, 'GTW001', f'API does not exist: {api_path}', status=404)
+ return GatewayService.error_response(
+ request_id, 'GTW001', f'API does not exist: {api_path}', status=404
+ )
doorman_cache.set_cache('api_cache', api_path, api)
client_key = request.headers.get('client-key')
server = await routing_util.pick_upstream_server(api, 'POST', '/grpc', client_key)
if not server:
logger.error(f'{request_id} | No upstream servers configured for {api_path}')
- return GatewayService.error_response(request_id, 'GTW001', 'No upstream servers configured', status=404)
+ return GatewayService.error_response(
+ request_id, 'GTW001', 'No upstream servers configured', status=404
+ )
url = server.rstrip('/')
if url.startswith('grpc://'):
url = url[7:]
retry = api.get('api_allowed_retry_count') or 0
if api.get('api_credits_enabled') and username and not bool(api.get('api_public')):
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
- return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
+ return GatewayService.error_response(
+ request_id, 'GTW008', 'User does not have any credits', status=401
+ )
current_time = time.time() * 1000
try:
if not url:
@@ -1245,24 +1478,41 @@ class GatewayService:
api_name = path_parts[-1] if path_parts else None
if api_name:
api_path = f'{api_name}/{api_version}'
- api = doorman_cache.get_cache('api_cache', api_path) or await api_util.get_api(None, api_path)
+ api = doorman_cache.get_cache(
+ 'api_cache', api_path
+ ) or await api_util.get_api(None, api_path)
try:
- api_pkg_raw = (api.get('api_grpc_package') or '').strip() if api else None
+ api_pkg_raw = (
+ (api.get('api_grpc_package') or '').strip() if api else None
+ )
except Exception:
api_pkg_raw = None
pkg_override = (body.get('package') or '').strip() or None
- api_pkg = GatewayService._validate_package_name(api_pkg_raw) if api_pkg_raw else None
- pkg_override_valid = GatewayService._validate_package_name(pkg_override) if pkg_override else None
+ api_pkg = (
+ GatewayService._validate_package_name(api_pkg_raw)
+ if api_pkg_raw
+ else None
+ )
+ pkg_override_valid = (
+ GatewayService._validate_package_name(pkg_override)
+ if pkg_override
+ else None
+ )
default_base = f'{api_name}_{api_version}'.replace('-', '_')
if not GatewayService._is_valid_identifier(default_base):
- default_base = ''.join(ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base)
- module_base = (api_pkg or pkg_override_valid or default_base)
+ default_base = ''.join(
+ ch if ch in GatewayService._IDENT_ALLOWED else '_'
+ for ch in default_base
+ )
+ module_base = api_pkg or pkg_override_valid or default_base
try:
allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None
allowed_svcs = api.get('api_grpc_allowed_services') if api else None
allowed_methods = api.get('api_grpc_allowed_methods') if api else None
- prev_parsed = GatewayService._parse_and_validate_method(body.get('method'))
+ prev_parsed = GatewayService._parse_and_validate_method(
+ body.get('method')
+ )
if prev_parsed:
svc_name, mth_name = prev_parsed
else:
@@ -1278,8 +1528,13 @@ class GatewayService:
return GatewayService.error_response(
request_id, 'GTW013', 'gRPC service not allowed', status=403
)
- if svc_name and mth_name and allowed_methods and isinstance(allowed_methods, list):
- if f"{svc_name}.{mth_name}" not in allowed_methods:
+ if (
+ svc_name
+ and mth_name
+ and allowed_methods
+ and isinstance(allowed_methods, list)
+ ):
+ if f'{svc_name}.{mth_name}' not in allowed_methods:
return GatewayService.error_response(
request_id, 'GTW013', 'gRPC method not allowed', status=403
)
@@ -1296,35 +1551,69 @@ class GatewayService:
body = await request.json()
if not isinstance(body, dict):
logger.error(f'{request_id} | Invalid request body format')
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid request body format', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', 'Invalid request body format', status=400
+ )
except json.JSONDecodeError:
logger.error(f'{request_id} | Invalid JSON in request body')
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid JSON in request body', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', 'Invalid JSON in request body', status=400
+ )
if 'method' not in body:
logger.error(f'{request_id} | Missing method in request body')
- return GatewayService.error_response(request_id, 'GTW011', 'Missing method in request body', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', 'Missing method in request body', status=400
+ )
if 'message' not in body:
logger.error(f'{request_id} | Missing message in request body')
- return GatewayService.error_response(request_id, 'GTW011', 'Missing message in request body', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW011', 'Missing message in request body', status=400
+ )
parsed_method = GatewayService._parse_and_validate_method(body.get('method'))
if not parsed_method:
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW011',
+ 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.',
+ status=400,
+ )
pkg_override = (body.get('package') or '').strip() or None
if pkg_override and GatewayService._validate_package_name(pkg_override) is None:
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC package. Use letters, digits, underscore only.', status=400)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW011',
+ 'Invalid gRPC package. Use letters, digits, underscore only.',
+ status=400,
+ )
try:
svc_name, mth_name = parsed_method
allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None
allowed_svcs = api.get('api_grpc_allowed_services') if api else None
allowed_methods = api.get('api_grpc_allowed_methods') if api else None
- if allowed_pkgs and isinstance(allowed_pkgs, list) and module_base not in allowed_pkgs:
- return GatewayService.error_response(request_id, 'GTW013', 'gRPC package not allowed', status=403)
+ if (
+ allowed_pkgs
+ and isinstance(allowed_pkgs, list)
+ and module_base not in allowed_pkgs
+ ):
+ return GatewayService.error_response(
+ request_id, 'GTW013', 'gRPC package not allowed', status=403
+ )
if allowed_svcs and isinstance(allowed_svcs, list) and svc_name not in allowed_svcs:
- return GatewayService.error_response(request_id, 'GTW013', 'gRPC service not allowed', status=403)
- if allowed_methods and isinstance(allowed_methods, list) and f"{svc_name}.{mth_name}" not in allowed_methods:
- return GatewayService.error_response(request_id, 'GTW013', 'gRPC method not allowed', status=403)
+ return GatewayService.error_response(
+ request_id, 'GTW013', 'gRPC service not allowed', status=403
+ )
+ if (
+ allowed_methods
+ and isinstance(allowed_methods, list)
+ and f'{svc_name}.{mth_name}' not in allowed_methods
+ ):
+ return GatewayService.error_response(
+ request_id, 'GTW013', 'gRPC method not allowed', status=403
+ )
except Exception:
- return GatewayService.error_response(request_id, 'GTW013', 'gRPC target not allowed', status=403)
+ return GatewayService.error_response(
+ request_id, 'GTW013', 'gRPC target not allowed', status=403
+ )
proto_rel = Path(module_base.replace('.', '/'))
proto_filename = f'{proto_rel.name}.proto'
@@ -1335,11 +1624,11 @@ class GatewayService:
await validation_util.validate_grpc_request(endpoint_id, body.get('message'))
except Exception as e:
return GatewayService.error_response(request_id, 'GTW011', str(e), status=400)
- proto_path = (GatewayService._PROJECT_ROOT / 'proto' / proto_rel.with_suffix('.proto'))
+ proto_path = GatewayService._PROJECT_ROOT / 'proto' / proto_rel.with_suffix('.proto')
use_imported = False
try:
if 'pb2' in locals() and 'pb2_grpc' in locals():
- use_imported = (pb2 is not None and pb2_grpc is not None)
+ use_imported = pb2 is not None and pb2_grpc is not None
except Exception:
use_imported = False
module_name = module_base
@@ -1351,7 +1640,9 @@ class GatewayService:
if gen_dir_str not in sys.path:
sys.path.insert(0, gen_dir_str)
try:
- logger.info(f"{request_id} | sys.path prepared for import: project_root={proj_root_str}, generated_dir={gen_dir_str}")
+ logger.info(
+ f'{request_id} | sys.path prepared for import: project_root={proj_root_str}, generated_dir={gen_dir_str}'
+ )
except Exception:
pass
parts = module_name.split('.') if '.' in module_name else [module_name]
@@ -1361,7 +1652,7 @@ class GatewayService:
if use_imported:
pb2_module = pb2
service_module = pb2_grpc
- logger.info(f"{request_id} | Using imported gRPC modules for {module_name}")
+ logger.info(f'{request_id} | Using imported gRPC modules for {module_name}')
else:
if not proto_path.exists():
if os.getenv('DOORMAN_TEST_MODE', '').lower() == 'true':
@@ -1371,89 +1662,159 @@ class GatewayService:
use_imported = True
except Exception:
logger.error(f'{request_id} | Proto file not found: {str(proto_path)}')
- return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW012',
+ f'Proto file not found for API: {api_path}',
+ status=404,
+ )
if not use_imported:
- pb2_path = (package_dir / f"{parts[-1]}_pb2.py").resolve()
- pb2_grpc_path = (package_dir / f"{parts[-1]}_pb2_grpc.py").resolve()
+ pb2_path = (package_dir / f'{parts[-1]}_pb2.py').resolve()
+ pb2_grpc_path = (package_dir / f'{parts[-1]}_pb2_grpc.py').resolve()
# Validate paths before checking existence
- if not GatewayService._validate_under_base(generated_dir, pb2_path) or not GatewayService._validate_under_base(generated_dir, pb2_grpc_path):
- logger.error(f"{request_id} | Invalid path for generated modules: pb2={pb2_path} pb2_grpc={pb2_grpc_path}")
- return GatewayService.error_response(request_id, 'GTW012', 'Invalid generated module path', status=400)
+ if not GatewayService._validate_under_base(
+ generated_dir, pb2_path
+ ) or not GatewayService._validate_under_base(generated_dir, pb2_grpc_path):
+ logger.error(
+ f'{request_id} | Invalid path for generated modules: pb2={pb2_path} pb2_grpc={pb2_grpc_path}'
+ )
+ return GatewayService.error_response(
+ request_id, 'GTW012', 'Invalid generated module path', status=400
+ )
if not (pb2_path.is_file() and pb2_grpc_path.is_file()):
- logger.error(f"{request_id} | Generated modules not found for '{module_name}' pb2={pb2_path} exists={pb2_path.is_file()} pb2_grpc={pb2_grpc_path} exists={pb2_grpc_path.is_file()}")
+ logger.error(
+ f"{request_id} | Generated modules not found for '{module_name}' pb2={pb2_path} exists={pb2_path.is_file()} pb2_grpc={pb2_grpc_path} exists={pb2_grpc_path.is_file()}"
+ )
if isinstance(url, str) and url.startswith(('http://', 'https://')):
try:
client = GatewayService.get_http_client()
http_url = url.rstrip('/') + '/grpc'
- http_response = await client.post(http_url, json=body, headers=headers)
+ http_response = await client.post(
+ http_url, json=body, headers=headers
+ )
finally:
- if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() != 'true':
+ if (
+ os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower()
+ != 'true'
+ ):
try:
await client.aclose()
except Exception:
pass
if http_response.status_code == 404:
- return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
+ return GatewayService.error_response(
+ request_id,
+ 'GTW005',
+ 'Endpoint does not exist in backend service',
+ )
response_headers = {'request_id': request_id}
try:
if current_time and start_time:
- response_headers['X-Gateway-Time'] = str(int(current_time - start_time))
+ response_headers['X-Gateway-Time'] = str(
+ int(current_time - start_time)
+ )
except Exception:
pass
return ResponseModel(
status_code=http_response.status_code,
response_headers=response_headers,
- response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text)
+ response=(
+ http_response.json()
+ if http_response.headers.get('Content-Type', '').startswith(
+ 'application/json'
+ )
+ else http_response.text
+ ),
).dict()
- return GatewayService.error_response(request_id, 'GTW012', f'Generated gRPC modules not found for package: {module_name}', status=404)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW012',
+ f'Generated gRPC modules not found for package: {module_name}',
+ status=404,
+ )
if not use_imported:
try:
if GatewayService._validate_package_name(module_name) is None:
- return GatewayService.error_response(request_id, 'GTW012', 'Invalid gRPC module name', status=400)
+ return GatewayService.error_response(
+ request_id, 'GTW012', 'Invalid gRPC module name', status=400
+ )
import_name_pb2 = f'{module_name}_pb2'
import_name_grpc = f'{module_name}_pb2_grpc'
- logger.info(f"{request_id} | Importing generated modules: {import_name_pb2} and {import_name_grpc}")
+ logger.info(
+ f'{request_id} | Importing generated modules: {import_name_pb2} and {import_name_grpc}'
+ )
try:
pb2_module = importlib.import_module(import_name_pb2)
service_module = importlib.import_module(import_name_grpc)
except ModuleNotFoundError:
alt_pb2 = f'generated.{module_name}_pb2'
alt_grpc = f'generated.{module_name}_pb2_grpc'
- logger.info(f"{request_id} | Retrying import via generated package: {alt_pb2} and {alt_grpc}")
+ logger.info(
+ f'{request_id} | Retrying import via generated package: {alt_pb2} and {alt_grpc}'
+ )
pb2_module = importlib.import_module(alt_pb2)
service_module = importlib.import_module(alt_grpc)
except ImportError as e:
- logger.error(f'{request_id} | Failed to import gRPC module: {str(e)}', exc_info=True)
+ logger.error(
+ f'{request_id} | Failed to import gRPC module: {str(e)}', exc_info=True
+ )
if isinstance(url, str) and url.startswith(('http://', 'https://')):
try:
client = GatewayService.get_http_client()
http_url = url.rstrip('/') + '/grpc'
- http_response = await client.post(http_url, json=body, headers=headers)
+ http_response = await client.post(
+ http_url, json=body, headers=headers
+ )
finally:
- if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() != 'true':
+ if (
+ os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower()
+ != 'true'
+ ):
try:
await client.aclose()
except Exception:
pass
if http_response.status_code == 404:
- return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
+ return GatewayService.error_response(
+ request_id,
+ 'GTW005',
+ 'Endpoint does not exist in backend service',
+ )
response_headers = {'request_id': request_id}
try:
if current_time and start_time:
- response_headers['X-Gateway-Time'] = str(int(current_time - start_time))
+ response_headers['X-Gateway-Time'] = str(
+ int(current_time - start_time)
+ )
except Exception:
pass
return ResponseModel(
status_code=http_response.status_code,
response_headers=response_headers,
- response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text)
+ response=(
+ http_response.json()
+ if http_response.headers.get('Content-Type', '').startswith(
+ 'application/json'
+ )
+ else http_response.text
+ ),
).dict()
- return GatewayService.error_response(request_id, 'GTW012', f'Failed to import gRPC module: {str(e)}', status=404)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW012',
+ f'Failed to import gRPC module: {str(e)}',
+ status=404,
+ )
parsed = GatewayService._parse_and_validate_method(body.get('method'))
if not parsed:
- return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400)
+ return GatewayService.error_response(
+ request_id,
+ 'GTW011',
+ 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.',
+ status=400,
+ )
service_name, method_name = parsed
- if isinstance(url, str) and url.startswith(("http://", "https://")):
+ if isinstance(url, str) and url.startswith(('http://', 'https://')):
try:
client = GatewayService.get_http_client()
http_url = url.rstrip('/') + '/grpc'
@@ -1465,7 +1826,9 @@ class GatewayService:
except Exception:
pass
if http_response.status_code == 404:
- return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
+ return GatewayService.error_response(
+ request_id, 'GTW005', 'Endpoint does not exist in backend service'
+ )
response_headers = {'request_id': request_id}
try:
if current_time and start_time:
@@ -1475,10 +1838,16 @@ class GatewayService:
return ResponseModel(
status_code=http_response.status_code,
response_headers=response_headers,
- response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text)
+ response=(
+ http_response.json()
+ if http_response.headers.get('Content-Type', '').startswith(
+ 'application/json'
+ )
+ else http_response.text
+ ),
).dict()
- logger.info(f"{request_id} | Connecting to gRPC upstream: {url}")
+ logger.info(f'{request_id} | Connecting to gRPC upstream: {url}')
channel = grpc.aio.insecure_channel(url)
try:
await asyncio.wait_for(channel.channel_ready(), timeout=2.0)
@@ -1488,48 +1857,60 @@ class GatewayService:
reply_class_name = f'{method_name}Reply'
try:
- logger.info(f"{request_id} | Resolving message types: {request_class_name} and {reply_class_name} from pb2_module={getattr(pb2_module, '__name__', 'unknown')}")
+ logger.info(
+ f'{request_id} | Resolving message types: {request_class_name} and {reply_class_name} from pb2_module={getattr(pb2_module, "__name__", "unknown")}'
+ )
if pb2_module is None:
- logger.error(f'{request_id} | pb2_module is None - cannot resolve message types')
+ logger.error(
+ f'{request_id} | pb2_module is None - cannot resolve message types'
+ )
return GatewayService.error_response(
request_id,
'GTW012',
'Internal error: protobuf module not loaded',
- status=500
+ status=500,
)
try:
request_class = getattr(pb2_module, request_class_name)
reply_class = getattr(pb2_module, reply_class_name)
except AttributeError as attr_err:
- logger.error(f'{request_id} | Message types not found in pb2_module: {str(attr_err)}')
+ logger.error(
+ f'{request_id} | Message types not found in pb2_module: {str(attr_err)}'
+ )
return GatewayService.error_response(
request_id,
'GTW006',
f'Message types {request_class_name}/{reply_class_name} not found in protobuf module',
- status=500
+ status=500,
)
try:
request_message = request_class()
- logger.info(f'{request_id} | Successfully created request message of type {request_class_name}')
+ logger.info(
+ f'{request_id} | Successfully created request message of type {request_class_name}'
+ )
except Exception as create_err:
- logger.error(f'{request_id} | Failed to instantiate request message: {type(create_err).__name__}: {str(create_err)}')
+ logger.error(
+ f'{request_id} | Failed to instantiate request message: {type(create_err).__name__}: {str(create_err)}'
+ )
return GatewayService.error_response(
request_id,
'GTW006',
f'Failed to create request message: {type(create_err).__name__}',
- status=500
+ status=500,
)
except Exception as e:
- logger.error(f'{request_id} | Unexpected error in message type resolution: {type(e).__name__}: {str(e)}')
+ logger.error(
+ f'{request_id} | Unexpected error in message type resolution: {type(e).__name__}: {str(e)}'
+ )
return GatewayService.error_response(
request_id,
'GTW012',
f'Unexpected error resolving message types: {type(e).__name__}',
- status=500
+ status=500,
)
for key, value in body['message'].items():
try:
@@ -1553,12 +1934,16 @@ class GatewayService:
except Exception:
base_ms, max_ms = 100, 1000
- stream_mode = str((body.get('stream') or body.get('streaming') or '')).lower()
+ stream_mode = str(body.get('stream') or body.get('streaming') or '').lower()
idempotent_override = body.get('idempotent')
if idempotent_override is not None:
is_idempotent = bool(idempotent_override)
else:
- is_idempotent = not (stream_mode.startswith('client') or stream_mode.startswith('bidi') or stream_mode.startswith('bi'))
+ is_idempotent = not (
+ stream_mode.startswith('client')
+ or stream_mode.startswith('bidi')
+ or stream_mode.startswith('bi')
+ )
retryable = {
grpc.StatusCode.UNAVAILABLE,
@@ -1572,19 +1957,26 @@ class GatewayService:
final_code_name = 'OK'
got_response = False
try:
- logger.info(f"{request_id} | gRPC entering attempts={attempts} stream_mode={stream_mode or 'unary'} method={service_name}.{method_name}")
+ logger.info(
+ f'{request_id} | gRPC entering attempts={attempts} stream_mode={stream_mode or "unary"} method={service_name}.{method_name}'
+ )
except Exception:
pass
for attempt in range(attempts):
try:
full_method = f'/{module_base}.{service_name}/{method_name}'
try:
- logger.info(f"{request_id} | gRPC attempt={attempt+1}/{attempts} calling {full_method}")
+ logger.info(
+ f'{request_id} | gRPC attempt={attempt + 1}/{attempts} calling {full_method}'
+ )
except Exception:
pass
req_ser = getattr(request_message, 'SerializeToString', None)
if not callable(req_ser):
- req_ser = (lambda _m: b'')
+
+ def req_ser(_m):
+ return b''
+
metadata_list = GatewayService._sanitize_grpc_metadata(headers or {})
if stream_mode.startswith('server'):
call = channel.unary_stream(
@@ -1604,7 +1996,15 @@ class GatewayService:
items.append(d)
if len(items) >= max_items:
break
- response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})()
+ response = type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type('D', (), {'fields': []})(),
+ 'ok': True,
+ '_items': items,
+ },
+ )()
got_response = True
elif stream_mode.startswith('client'):
try:
@@ -1615,6 +2015,7 @@ class GatewayService:
)
except AttributeError:
stream = None
+
async def _gen_client():
msgs = body.get('messages') or []
if not msgs:
@@ -1634,6 +2035,7 @@ class GatewayService:
yield msg
except Exception:
yield request_message
+
if stream is not None:
try:
response = await stream(_gen_client(), metadata=metadata_list)
@@ -1660,6 +2062,7 @@ class GatewayService:
)
except AttributeError:
bidi = None
+
async def _gen_bidi():
msgs = body.get('messages') or []
if not msgs:
@@ -1679,6 +2082,7 @@ class GatewayService:
yield msg
except Exception:
yield request_message
+
items = []
max_items = int(body.get('max_items') or 50)
if bidi is not None:
@@ -1704,7 +2108,15 @@ class GatewayService:
items.append(d)
if len(items) >= max_items:
break
- response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})()
+ response = type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type('D', (), {'fields': []})(),
+ 'ok': True,
+ '_items': items,
+ },
+ )()
got_response = True
else:
unary = channel.unary_unary(
@@ -1719,7 +2131,9 @@ class GatewayService:
got_response = True
last_exc = None
try:
- logger.info(f"{request_id} | gRPC unary success; stream_mode={stream_mode or 'unary'}")
+ logger.info(
+ f'{request_id} | gRPC unary success; stream_mode={stream_mode or "unary"}'
+ )
except Exception:
pass
break
@@ -1728,13 +2142,13 @@ class GatewayService:
try:
code = getattr(e2, 'code', lambda: None)()
cname = str(getattr(code, 'name', '') or 'UNKNOWN')
- logger.info(f"{request_id} | gRPC primary call raised: {cname}")
+ logger.info(f'{request_id} | gRPC primary call raised: {cname}')
except Exception:
- logger.info(f"{request_id} | gRPC primary call raised non-grpc exception")
+ logger.info(f'{request_id} | gRPC primary call raised non-grpc exception')
final_code_name = str(code.name) if getattr(code, 'name', None) else 'ERROR'
if attempt < attempts - 1 and is_idempotent and code in retryable:
retries_made += 1
- delay = min(max_ms, base_ms * (2 ** attempt)) / 1000.0
+ delay = min(max_ms, base_ms * (2**attempt)) / 1000.0
jitter_factor = 1.0 + (random.random() * jitter - (jitter / 2.0))
await asyncio.sleep(max(0.01, delay * jitter_factor))
continue
@@ -1742,7 +2156,10 @@ class GatewayService:
alt_method = f'/{service_name}/{method_name}'
req_ser = getattr(request_message, 'SerializeToString', None)
if not callable(req_ser):
- req_ser = (lambda _m: b'')
+
+ def req_ser(_m):
+ return b''
+
if stream_mode.startswith('server'):
call2 = channel.unary_stream(
alt_method,
@@ -1761,7 +2178,15 @@ class GatewayService:
items.append(d)
if len(items) >= max_items:
break
- response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})()
+ response = type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type('D', (), {'fields': []})(),
+ 'ok': True,
+ '_items': items,
+ },
+ )()
got_response = True
elif stream_mode.startswith('client'):
try:
@@ -1772,6 +2197,7 @@ class GatewayService:
)
except AttributeError:
stream2 = None
+
async def _gen_client_alt():
msgs = body.get('messages') or []
if not msgs:
@@ -1791,9 +2217,12 @@ class GatewayService:
yield msg
except Exception:
yield request_message
+
if stream2 is not None:
try:
- response = await stream2(_gen_client_alt(), metadata=metadata_list)
+ response = await stream2(
+ _gen_client_alt(), metadata=metadata_list
+ )
except TypeError:
response = await stream2(_gen_client_alt())
got_response = True
@@ -1817,6 +2246,7 @@ class GatewayService:
)
except AttributeError:
bidi2 = None
+
async def _gen_bidi_alt():
msgs = body.get('messages') or []
if not msgs:
@@ -1836,6 +2266,7 @@ class GatewayService:
yield msg
except Exception:
yield request_message
+
items = []
max_items = int(body.get('max_items') or 50)
if bidi2 is not None:
@@ -1861,7 +2292,15 @@ class GatewayService:
items.append(d)
if len(items) >= max_items:
break
- response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})()
+ response = type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type('D', (), {'fields': []})(),
+ 'ok': True,
+ '_items': items,
+ },
+ )()
got_response = True
else:
unary2 = channel.unary_unary(
@@ -1881,13 +2320,15 @@ class GatewayService:
try:
code3 = getattr(e3, 'code', lambda: None)()
cname3 = str(getattr(code3, 'name', '') or 'UNKNOWN')
- logger.info(f"{request_id} | gRPC alt call raised: {cname3}")
+ logger.info(f'{request_id} | gRPC alt call raised: {cname3}')
except Exception:
- logger.info(f"{request_id} | gRPC alt call raised non-grpc exception")
- final_code_name = str(code3.name) if getattr(code3, 'name', None) else 'ERROR'
+ logger.info(f'{request_id} | gRPC alt call raised non-grpc exception')
+ final_code_name = (
+ str(code3.name) if getattr(code3, 'name', None) else 'ERROR'
+ )
if attempt < attempts - 1 and is_idempotent and code3 in retryable:
retries_made += 1
- delay = min(max_ms, base_ms * (2 ** attempt)) / 1000.0
+ delay = min(max_ms, base_ms * (2**attempt)) / 1000.0
jitter_factor = 1.0 + (random.random() * jitter - (jitter / 2.0))
await asyncio.sleep(max(0.01, delay * jitter_factor))
continue
@@ -1900,11 +2341,13 @@ class GatewayService:
code_obj = getattr(last_exc, 'code', lambda: None)()
if code_obj and hasattr(code_obj, 'name'):
code_name = str(code_obj.name).upper()
- logger.info(f"{request_id} | gRPC call failed with status: {code_name}")
+ logger.info(f'{request_id} | gRPC call failed with status: {code_name}')
else:
- logger.warning(f"{request_id} | gRPC exception has no valid status code")
+ logger.warning(f'{request_id} | gRPC exception has no valid status code')
except Exception as code_extract_err:
- logger.warning(f"{request_id} | Failed to extract gRPC status code: {str(code_extract_err)}")
+ logger.warning(
+ f'{request_id} | Failed to extract gRPC status code: {str(code_extract_err)}'
+ )
status_map = {
'OK': 200,
@@ -1941,33 +2384,41 @@ class GatewayService:
details = f'gRPC error: {code_name}'
logger.error(
- f"{request_id} | gRPC call failed after {retries_made} retries. "
- f"Status: {code_name}, HTTP: {http_status}, Details: {details[:100]}"
+ f'{request_id} | gRPC call failed after {retries_made} retries. '
+ f'Status: {code_name}, HTTP: {http_status}, Details: {details[:100]}'
)
response_headers = {
'request_id': request_id,
'X-Retry-Count': str(retries_made),
'X-GRPC-Status': code_name,
- 'X-GRPC-Code': str(code_obj.value[0]) if code_obj and hasattr(code_obj, 'value') else 'unknown'
+ 'X-GRPC-Code': str(code_obj.value[0])
+ if code_obj and hasattr(code_obj, 'value')
+ else 'unknown',
}
return ResponseModel(
status_code=http_status,
response_headers=response_headers,
error_code='GTW006',
- error_message=str(details)[:255]
+ error_message=str(details)[:255],
).dict()
if not got_response and last_exc is None:
try:
- logger.error(f"{request_id} | gRPC loop ended with no response and no exception; returning 500 UNKNOWN")
+ logger.error(
+ f'{request_id} | gRPC loop ended with no response and no exception; returning 500 UNKNOWN'
+ )
except Exception:
pass
return ResponseModel(
status_code=500,
- response_headers={'request_id': request_id, 'X-Retry-Count': str(retries_made), 'X-Retry-Final': 'UNKNOWN'},
+ response_headers={
+ 'request_id': request_id,
+ 'X-Retry-Count': str(retries_made),
+ 'X-Retry-Final': 'UNKNOWN',
+ },
error_code='GTW006',
- error_message='gRPC call failed'
+ error_message='gRPC call failed',
).dict()
response_dict = {}
@@ -1981,7 +2432,11 @@ class GatewayService:
else:
response_dict[field.name] = value
backend_end_time = time.time() * 1000
- response_headers = {'request_id': request_id, 'X-Retry-Count': str(retries_made), 'X-Retry-Final': final_code_name}
+ response_headers = {
+ 'request_id': request_id,
+ 'X-Retry-Count': str(retries_made),
+ 'X-Retry-Final': final_code_name,
+ }
try:
if current_time and start_time:
response_headers['X-Gateway-Time'] = str(int(current_time - start_time))
@@ -1990,20 +2445,20 @@ class GatewayService:
except Exception:
pass
try:
- logger.info(f"{request_id} | gRPC return 200 with items={bool(response_dict.get('items'))}")
+ logger.info(
+ f'{request_id} | gRPC return 200 with items={bool(response_dict.get("items"))}'
+ )
except Exception:
pass
return ResponseModel(
- status_code=200,
- response_headers=response_headers,
- response=response_dict
+ status_code=200, response_headers=response_headers, response=response_dict
).dict()
except httpx.TimeoutException:
return ResponseModel(
status_code=504,
response_headers={'request_id': request_id},
error_code='GTW010',
- error_message='Gateway timeout'
+ error_message='Gateway timeout',
).dict()
except Exception as e:
code_name = 'UNKNOWN'
@@ -2057,10 +2512,10 @@ class GatewayService:
response_headers={
'request_id': request_id,
'X-Error-Type': type(e).__name__,
- 'X-GRPC-Status': code_name
+ 'X-GRPC-Status': code_name,
},
error_code='GTW006',
- error_message=details[:255]
+ error_message=details[:255],
).dict()
finally:
if current_time:
@@ -2068,7 +2523,9 @@ class GatewayService:
if backend_end_time and current_time:
logger.info(f'{request_id} | Backend time {backend_end_time - current_time}ms')
- async def _make_graphql_request(self, url: str, query: str, headers: Dict[str, str] = None) -> Dict:
+ async def _make_graphql_request(
+ self, url: str, query: str, headers: dict[str, str] | None = None
+ ) -> dict:
try:
if headers is None:
headers = {}
@@ -2079,13 +2536,22 @@ class GatewayService:
if 'errors' in data:
return data
if r.status_code != 200:
- return {'errors': [{'message': f'HTTP {r.status_code}: {data.get("message", "Unknown error")}', 'extensions': {'code': 'HTTP_ERROR'}}]}
+ return {
+ 'errors': [
+ {
+ 'message': f'HTTP {r.status_code}: {data.get("message", "Unknown error")}',
+ 'extensions': {'code': 'HTTP_ERROR'},
+ }
+ ]
+ }
return data
except Exception as e:
logger.error(f'Error making GraphQL request: {str(e)}')
return {
- 'errors': [{
- 'message': f'Error making GraphQL request: {str(e)}',
- 'extensions': {'code': 'REQUEST_ERROR'}
- }]
+ 'errors': [
+ {
+ 'message': f'Error making GraphQL request: {str(e)}',
+ 'extensions': {'code': 'REQUEST_ERROR'},
+ }
+ ]
}
diff --git a/backend-services/services/group_service.py b/backend-services/services/group_service.py
index f33c2e8..ce108f9 100644
--- a/backend-services/services/group_service.py
+++ b/backend-services/services/group_service.py
@@ -4,22 +4,22 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from pymongo.errors import DuplicateKeyError
import logging
+from pymongo.errors import DuplicateKeyError
+
+from models.create_group_model import CreateGroupModel
from models.response_model import ResponseModel
from models.update_group_model import UpdateGroupModel
-from utils.database import group_collection
-from utils.cache_manager_util import cache_manager
-from utils.doorman_cache_util import doorman_cache
-from models.create_group_model import CreateGroupModel
-from utils.paging_util import validate_page_params
from utils.constants import ErrorCodes, Messages
+from utils.database import group_collection
+from utils.doorman_cache_util import doorman_cache
+from utils.paging_util import validate_page_params
logger = logging.getLogger('doorman.gateway')
-class GroupService:
+class GroupService:
@staticmethod
async def create_group(data: CreateGroupModel, request_id):
"""
@@ -29,11 +29,9 @@ class GroupService:
if doorman_cache.get_cache('group_cache', data.group_name):
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='GRP001',
- error_message='Group already exists'
+ error_message='Group already exists',
).dict()
group_dict = data.dict()
try:
@@ -41,26 +39,19 @@ class GroupService:
if not insert_result.acknowledged:
logger.error(request_id + ' | Group creation failed with code GRP002')
return ResponseModel(
- status_code=400,
- error_code='GRP002',
- error_message='Unable to insert group'
+ status_code=400, error_code='GRP002', error_message='Unable to insert group'
).dict()
group_dict['_id'] = str(insert_result.inserted_id)
doorman_cache.set_cache('group_cache', data.group_name, group_dict)
logger.info(request_id + ' | Group creation successful')
- return ResponseModel(
- status_code=201,
- message='Group created successfully'
- ).dict()
- except DuplicateKeyError as e:
+ return ResponseModel(status_code=201, message='Group created successfully').dict()
+ except DuplicateKeyError:
logger.error(request_id + ' | Group creation failed with code GRP001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='GRP001',
- error_message='Group already exists'
+ error_message='Group already exists',
).dict()
@staticmethod
@@ -72,53 +63,39 @@ class GroupService:
if data.group_name and data.group_name != group_name:
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='GRP004',
- error_message='Group name cannot be updated'
+ error_message='Group name cannot be updated',
).dict()
group = doorman_cache.get_cache('group_cache', group_name)
if not group:
- group = group_collection.find_one({
- 'group_name': group_name
- })
+ group = group_collection.find_one({'group_name': group_name})
if not group:
logger.error(request_id + ' | Group update failed with code GRP003')
return ResponseModel(
- status_code=400,
- error_code='GRP003',
- error_message='Group does not exist'
+ status_code=400, error_code='GRP003', error_message='Group does not exist'
).dict()
else:
doorman_cache.delete_cache('group_cache', group_name)
not_null_data = {k: v for k, v in data.dict().items() if v is not None}
if not_null_data:
update_result = group_collection.update_one(
- {'group_name': group_name},
- {'$set': not_null_data}
+ {'group_name': group_name}, {'$set': not_null_data}
)
if not update_result.acknowledged or update_result.modified_count == 0:
logger.error(request_id + ' | Group update failed with code GRP002')
return ResponseModel(
- status_code=400,
- error_code='GRP005',
- error_message='Unable to update group'
+ status_code=400, error_code='GRP005', error_message='Unable to update group'
).dict()
logger.info(request_id + ' | Group updated successful')
- return ResponseModel(
- status_code=200,
- message='Group updated successfully'
- ).dict()
+ return ResponseModel(status_code=200, message='Group updated successfully').dict()
else:
logger.error(request_id + ' | Group update failed with code GRP006')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='GRP006',
- error_message='No data to update'
+ error_message='No data to update',
).dict()
@staticmethod
@@ -129,35 +106,27 @@ class GroupService:
logger.info(request_id + ' | Deleting group: ' + group_name)
group = doorman_cache.get_cache('group_cache', group_name)
if not group:
- group = group_collection.find_one({
- 'group_name': group_name
- })
+ group = group_collection.find_one({'group_name': group_name})
if not group:
logger.error(request_id + ' | Group deletion failed with code GRP003')
return ResponseModel(
- status_code=400,
- error_code='GRP003',
- error_message='Group does not exist'
+ status_code=400, error_code='GRP003', error_message='Group does not exist'
).dict()
delete_result = group_collection.delete_one({'group_name': group_name})
if not delete_result.acknowledged:
logger.error(request_id + ' | Group deletion failed with code GRP002')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='GRP007',
- error_message='Unable to delete group'
+ error_message='Unable to delete group',
).dict()
doorman_cache.delete_cache('group_cache', group_name)
logger.info(request_id + ' | Group deletion successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='Group deleted successfully'
+ response_headers={'request_id': request_id},
+ message='Group deleted successfully',
).dict()
@staticmethod
@@ -165,7 +134,9 @@ class GroupService:
"""
Check if a group exists.
"""
- if doorman_cache.get_cache('group_cache', data.get('group_name')) or group_collection.find_one({'group_name': data.get('group_name')}):
+ if doorman_cache.get_cache(
+ 'group_cache', data.get('group_name')
+ ) or group_collection.find_one({'group_name': data.get('group_name')}):
return True
return False
@@ -174,25 +145,27 @@ class GroupService:
"""
Get all groups.
"""
- logger.info(request_id + ' | Getting groups: Page=' + str(page) + ' Page Size=' + str(page_size))
+ logger.info(
+ request_id + ' | Getting groups: Page=' + str(page) + ' Page Size=' + str(page_size)
+ )
try:
page, page_size = validate_page_params(page, page_size)
except Exception as e:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING
+ ),
).dict()
skip = (page - 1) * page_size
cursor = group_collection.find().sort('group_name', 1).skip(skip).limit(page_size)
groups = cursor.to_list(length=None)
for group in groups:
- if group.get('_id'): del group['_id']
+ if group.get('_id'):
+ del group['_id']
logger.info(request_id + ' | Groups retrieval successful')
- return ResponseModel(
- status_code=200,
- response={'groups': groups}
- ).dict()
+ return ResponseModel(status_code=200, response={'groups': groups}).dict()
@staticmethod
async def get_group(group_name, request_id):
@@ -206,15 +179,12 @@ class GroupService:
if not group:
logger.error(request_id + ' | Group retrieval failed with code GRP003')
return ResponseModel(
- status_code=404,
- error_code='GRP003',
- error_message='Group does not exist'
+ status_code=404, error_code='GRP003', error_message='Group does not exist'
).dict()
- if group.get('_id'): del group['_id']
+ if group.get('_id'):
+ del group['_id']
doorman_cache.set_cache('group_cache', group_name, group)
- if group.get('_id'): del group['_id']
+ if group.get('_id'):
+ del group['_id']
logger.info(request_id + ' | Group retrieval successful')
- return ResponseModel(
- status_code=200,
- response=group
- ).dict()
+ return ResponseModel(status_code=200, response=group).dict()
diff --git a/backend-services/services/logging_service.py b/backend-services/services/logging_service.py
index f1cea77..6df754a 100644
--- a/backend-services/services/logging_service.py
+++ b/backend-services/services/logging_service.py
@@ -4,17 +4,19 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
+import glob
+import json
import logging
import os
-import json
-import glob
-from datetime import datetime, timedelta
-from typing import List, Dict, Any, Optional
-from fastapi import HTTPException
import re
+from datetime import datetime
+from typing import Any
+
+from fastapi import HTTPException
logger = logging.getLogger('doorman.logging')
+
class LoggingService:
def __init__(self):
env_dir = os.getenv('LOGS_DIR')
@@ -24,43 +26,40 @@ class LoggingService:
base_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.normpath(os.path.join(base_dir, '..'))
candidate = os.path.join(backend_root, 'platform-logs')
- self.log_directory = candidate if os.path.isdir(candidate) else os.path.join(backend_root, 'logs')
+ self.log_directory = (
+ candidate if os.path.isdir(candidate) else os.path.join(backend_root, 'logs')
+ )
self.log_file_patterns = ['doorman.log*', 'doorman-trail.log*']
self.max_logs_per_request = 1000
async def get_logs(
self,
- start_date: Optional[str] = None,
- end_date: Optional[str] = None,
- start_time: Optional[str] = None,
- end_time: Optional[str] = None,
- user: Optional[str] = None,
- endpoint: Optional[str] = None,
- request_id: Optional[str] = None,
- method: Optional[str] = None,
- ip_address: Optional[str] = None,
- min_response_time: Optional[str] = None,
- max_response_time: Optional[str] = None,
- level: Optional[str] = None,
+ start_date: str | None = None,
+ end_date: str | None = None,
+ start_time: str | None = None,
+ end_time: str | None = None,
+ user: str | None = None,
+ endpoint: str | None = None,
+ request_id: str | None = None,
+ method: str | None = None,
+ ip_address: str | None = None,
+ min_response_time: str | None = None,
+ max_response_time: str | None = None,
+ level: str | None = None,
limit: int = 100,
offset: int = 0,
- request_id_param: str = None
- ) -> Dict[str, Any]:
+ request_id_param: str | None = None,
+ ) -> dict[str, Any]:
"""
Retrieve and filter logs based on various criteria
"""
try:
-
log_files: list[str] = []
for pat in self.log_file_patterns:
log_files.extend(glob.glob(os.path.join(self.log_directory, pat)))
if not log_files:
- return {
- 'logs': [],
- 'total': 0,
- 'has_more': False
- }
+ return {'logs': [], 'total': 0, 'has_more': False}
logs = []
total_count = 0
@@ -72,23 +71,26 @@ class LoggingService:
continue
try:
- with open(log_file, 'r', encoding='utf-8') as file:
+ with open(log_file, encoding='utf-8') as file:
for line in file:
log_entry = self._parse_log_line(line)
- if log_entry and self._matches_filters(log_entry, {
- 'start_date': start_date,
- 'end_date': end_date,
- 'start_time': start_time,
- 'end_time': end_time,
- 'user': user,
- 'endpoint': endpoint,
- 'request_id': request_id,
- 'method': method,
- 'ip_address': ip_address,
- 'min_response_time': min_response_time,
- 'max_response_time': max_response_time,
- 'level': level
- }):
+ if log_entry and self._matches_filters(
+ log_entry,
+ {
+ 'start_date': start_date,
+ 'end_date': end_date,
+ 'start_time': start_time,
+ 'end_time': end_time,
+ 'user': user,
+ 'endpoint': endpoint,
+ 'request_id': request_id,
+ 'method': method,
+ 'ip_address': ip_address,
+ 'min_response_time': min_response_time,
+ 'max_response_time': max_response_time,
+ 'level': level,
+ },
+ ):
total_count += 1
if len(logs) < limit and total_count > offset:
logs.append(log_entry)
@@ -103,17 +105,13 @@ class LoggingService:
logger.warning(f'Error reading log file {log_file}: {str(e)}')
continue
- return {
- 'logs': logs,
- 'total': total_count,
- 'has_more': total_count > offset + limit
- }
+ return {'logs': logs, 'total': total_count, 'has_more': total_count > offset + limit}
except Exception as e:
logger.error(f'Error retrieving logs: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail='Failed to retrieve logs')
- def get_available_log_files(self) -> List[str]:
+ def get_available_log_files(self) -> list[str]:
"""
Get list of available log files for debugging
"""
@@ -123,12 +121,11 @@ class LoggingService:
log_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
return log_files
- async def get_log_statistics(self, request_id: str = None) -> Dict[str, Any]:
+ async def get_log_statistics(self, request_id: str | None = None) -> dict[str, Any]:
"""
Get log statistics for dashboard
"""
try:
-
log_files: list[str] = []
for pat in self.log_file_patterns:
log_files.extend(glob.glob(os.path.join(self.log_directory, pat)))
@@ -143,7 +140,7 @@ class LoggingService:
'avg_response_time': 0,
'top_apis': [],
'top_users': [],
- 'top_endpoints': []
+ 'top_endpoints': [],
}
stats = {
@@ -155,7 +152,7 @@ class LoggingService:
'response_times': [],
'apis': {},
'users': {},
- 'endpoints': {}
+ 'endpoints': {},
}
for log_file in log_files:
@@ -163,7 +160,7 @@ class LoggingService:
continue
try:
- with open(log_file, 'r', encoding='utf-8') as file:
+ with open(log_file, encoding='utf-8') as file:
for line in file:
log_entry = self._parse_log_line(line)
if log_entry:
@@ -181,28 +178,42 @@ class LoggingService:
if log_entry.get('response_time'):
try:
- stats['response_times'].append(float(log_entry['response_time']))
+ stats['response_times'].append(
+ float(log_entry['response_time'])
+ )
except (ValueError, TypeError):
pass
if log_entry.get('api'):
- stats['apis'][log_entry['api']] = stats['apis'].get(log_entry['api'], 0) + 1
+ stats['apis'][log_entry['api']] = (
+ stats['apis'].get(log_entry['api'], 0) + 1
+ )
if log_entry.get('user'):
- stats['users'][log_entry['user']] = stats['users'].get(log_entry['user'], 0) + 1
+ stats['users'][log_entry['user']] = (
+ stats['users'].get(log_entry['user'], 0) + 1
+ )
if log_entry.get('endpoint'):
- stats['endpoints'][log_entry['endpoint']] = stats['endpoints'].get(log_entry['endpoint'], 0) + 1
+ stats['endpoints'][log_entry['endpoint']] = (
+ stats['endpoints'].get(log_entry['endpoint'], 0) + 1
+ )
except Exception as e:
logger.warning(f'Error reading log file {log_file} for statistics: {str(e)}')
continue
- avg_response_time = sum(stats['response_times']) / len(stats['response_times']) if stats['response_times'] else 0
+ avg_response_time = (
+ sum(stats['response_times']) / len(stats['response_times'])
+ if stats['response_times']
+ else 0
+ )
top_apis = sorted(stats['apis'].items(), key=lambda x: x[1], reverse=True)[:10]
top_users = sorted(stats['users'].items(), key=lambda x: x[1], reverse=True)[:10]
- top_endpoints = sorted(stats['endpoints'].items(), key=lambda x: x[1], reverse=True)[:10]
+ top_endpoints = sorted(stats['endpoints'].items(), key=lambda x: x[1], reverse=True)[
+ :10
+ ]
return {
'total_logs': stats['total_logs'],
@@ -213,7 +224,9 @@ class LoggingService:
'avg_response_time': round(avg_response_time, 2),
'top_apis': [{'name': api, 'count': count} for api, count in top_apis],
'top_users': [{'name': user, 'count': count} for user, count in top_users],
- 'top_endpoints': [{'name': endpoint, 'count': count} for endpoint, count in top_endpoints]
+ 'top_endpoints': [
+ {'name': endpoint, 'count': count} for endpoint, count in top_endpoints
+ ],
}
except Exception as e:
@@ -223,21 +236,17 @@ class LoggingService:
async def export_logs(
self,
format: str = 'json',
- start_date: Optional[str] = None,
- end_date: Optional[str] = None,
- filters: Dict[str, Any] = None,
- request_id: str = None
- ) -> Dict[str, Any]:
+ start_date: str | None = None,
+ end_date: str | None = None,
+ filters: dict[str, Any] | None = None,
+ request_id: str | None = None,
+ ) -> dict[str, Any]:
"""
Export logs in various formats
"""
try:
-
logs_data = await self.get_logs(
- start_date=start_date,
- end_date=end_date,
- limit=10000,
- **filters if filters else {}
+ start_date=start_date, end_date=end_date, limit=10000, **filters if filters else {}
)
logs = logs_data['logs']
@@ -246,24 +255,24 @@ class LoggingService:
return {
'format': 'json',
'data': json.dumps(logs, indent=2, default=str),
- 'filename': f"logs_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+ 'filename': f'logs_export_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json',
}
elif format.lower() == 'csv':
if not logs:
return {
'format': 'csv',
'data': 'timestamp,level,message,source,user,api,endpoint,method,status_code,response_time,ip_address,protocol,request_id,group,role\n',
- 'filename': f"logs_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
+ 'filename': f'logs_export_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv',
}
csv_data = 'timestamp,level,message,source,user,api,endpoint,method,status_code,response_time,ip_address,protocol,request_id,group,role\n'
for log in logs:
- csv_data += f"{log.get('timestamp', '')},{log.get('level', '')},{log.get('message', '').replace(',', ';')},{log.get('source', '')},{log.get('user', '')},{log.get('api', '')},{log.get('endpoint', '')},{log.get('method', '')},{log.get('status_code', '')},{log.get('response_time', '')},{log.get('ip_address', '')},{log.get('protocol', '')},{log.get('request_id', '')},{log.get('group', '')},{log.get('role', '')}\n"
+ csv_data += f'{log.get("timestamp", "")},{log.get("level", "")},{log.get("message", "").replace(",", ";")},{log.get("source", "")},{log.get("user", "")},{log.get("api", "")},{log.get("endpoint", "")},{log.get("method", "")},{log.get("status_code", "")},{log.get("response_time", "")},{log.get("ip_address", "")},{log.get("protocol", "")},{log.get("request_id", "")},{log.get("group", "")},{log.get("role", "")}\n'
return {
'format': 'csv',
'data': csv_data,
- 'filename': f"logs_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
+ 'filename': f'logs_export_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv',
}
else:
raise HTTPException(status_code=400, detail='Unsupported export format')
@@ -272,7 +281,7 @@ class LoggingService:
logger.error(f'Error exporting logs: {str(e)}')
raise HTTPException(status_code=500, detail='Failed to export logs')
- def _parse_log_line(self, line: str) -> Optional[Dict[str, Any]]:
+ def _parse_log_line(self, line: str) -> dict[str, Any] | None:
"""
Parse a log line and extract structured data
Format: timestamp - logger_name - level - request_id | message
@@ -292,7 +301,9 @@ class LoggingService:
level = rec.get('level', '')
structured = self._extract_structured_data(message)
return {
- 'timestamp': timestamp if isinstance(timestamp, str) else timestamp.isoformat(),
+ 'timestamp': timestamp
+ if isinstance(timestamp, str)
+ else timestamp.isoformat(),
'level': level,
'message': message,
'source': name,
@@ -330,7 +341,7 @@ class LoggingService:
logger.debug(f'Failed to parse log line: {str(e)}', exc_info=True)
return None
- def _extract_structured_data(self, message: str) -> Dict[str, Any]:
+ def _extract_structured_data(self, message: str) -> dict[str, Any]:
"""
Extract structured data from log message
"""
@@ -375,12 +386,11 @@ class LoggingService:
return data
- def _matches_filters(self, log_entry: Dict[str, Any], filters: Dict[str, Any]) -> bool:
+ def _matches_filters(self, log_entry: dict[str, Any], filters: dict[str, Any]) -> bool:
"""
Check if log entry matches all applied filters
"""
try:
-
timestamp_str = log_entry.get('timestamp', '')
if not timestamp_str:
return False
@@ -429,14 +439,20 @@ class LoggingService:
log_value = str(log_entry.get(field, '')).lower()
if not log_value:
- logger.debug(f"Filter '{field}' = '{filter_value}' but log has no value for '{field}'")
+ logger.debug(
+ f"Filter '{field}' = '{filter_value}' but log has no value for '{field}'"
+ )
return False
if filter_value not in log_value:
- logger.debug(f"Filter '{field}' = '{filter_value}' not found in log value '{log_value}'")
+ logger.debug(
+ f"Filter '{field}' = '{filter_value}' not found in log value '{log_value}'"
+ )
return False
else:
- logger.debug(f"Filter '{field}' = '{filter_value}' matches log value '{log_value}'")
+ logger.debug(
+ f"Filter '{field}' = '{filter_value}' matches log value '{log_value}'"
+ )
if filters.get('status_code') and filters['status_code'].strip():
try:
@@ -470,7 +486,7 @@ class LoggingService:
applied_filters = [f'{k}={v}' for k, v in filters.items() if v and str(v).strip()]
if applied_filters:
- logger.debug(f"Log entry passed all filters: {', '.join(applied_filters)}")
+ logger.debug(f'Log entry passed all filters: {", ".join(applied_filters)}')
return True
except Exception as e:
diff --git a/backend-services/services/rate_limit_rule_service.py b/backend-services/services/rate_limit_rule_service.py
index c8fe8a0..c15a3f2 100644
--- a/backend-services/services/rate_limit_rule_service.py
+++ b/backend-services/services/rate_limit_rule_service.py
@@ -6,8 +6,9 @@ Handles rule CRUD, validation, and application.
"""
import logging
-from typing import Optional, List, Dict, Any, TYPE_CHECKING
from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
try:
# Use a type-only import to avoid a hard runtime dependency during tests
if TYPE_CHECKING:
@@ -17,7 +18,7 @@ try:
except Exception: # Defensive: never fail import due to typing
AsyncIOMotorDatabase = Any # type: ignore
-from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow
+from models.rate_limit_models import RateLimitRule, RuleType
logger = logging.getLogger(__name__)
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
class RateLimitRuleService:
"""
Service for managing rate limit rules
-
+
Features:
- CRUD operations for rules
- Rule validation
@@ -33,248 +34,242 @@ class RateLimitRuleService:
- Bulk operations
- Rule testing
"""
-
+
def __init__(self, db: AsyncIOMotorDatabase):
"""
Initialize rate limit rule service
-
+
Args:
db: MongoDB database instance
"""
self.db = db
self.rules_collection = db.rate_limit_rules
-
+
# ========================================================================
# RULE CRUD OPERATIONS
# ========================================================================
-
+
async def create_rule(self, rule: RateLimitRule) -> RateLimitRule:
"""
Create a new rate limit rule
-
+
Args:
rule: RateLimitRule object to create
-
+
Returns:
Created rule
-
+
Raises:
ValueError: If rule with same ID already exists
"""
# Check if rule already exists
existing = await self.rules_collection.find_one({'rule_id': rule.rule_id})
if existing:
- raise ValueError(f"Rule with ID {rule.rule_id} already exists")
-
+ raise ValueError(f'Rule with ID {rule.rule_id} already exists')
+
# Set timestamps
rule.created_at = datetime.now()
rule.updated_at = datetime.now()
-
+
# Insert into database
await self.rules_collection.insert_one(rule.to_dict())
-
- logger.info(f"Created rate limit rule: {rule.rule_id}")
+
+ logger.info(f'Created rate limit rule: {rule.rule_id}')
return rule
-
- async def get_rule(self, rule_id: str) -> Optional[RateLimitRule]:
+
+ async def get_rule(self, rule_id: str) -> RateLimitRule | None:
"""
Get rule by ID
-
+
Args:
rule_id: Rule identifier
-
+
Returns:
RateLimitRule object or None if not found
"""
rule_data = await self.rules_collection.find_one({'rule_id': rule_id})
-
+
if rule_data:
return RateLimitRule.from_dict(rule_data)
-
+
return None
-
+
async def list_rules(
self,
- rule_type: Optional[RuleType] = None,
+ rule_type: RuleType | None = None,
enabled_only: bool = False,
skip: int = 0,
- limit: int = 100
- ) -> List[RateLimitRule]:
+ limit: int = 100,
+ ) -> list[RateLimitRule]:
"""
List rate limit rules
-
+
Args:
rule_type: Filter by rule type
enabled_only: Only return enabled rules
skip: Number of records to skip
limit: Maximum number of records to return
-
+
Returns:
List of rules
"""
query = {}
-
+
if rule_type:
query['rule_type'] = rule_type.value
-
+
if enabled_only:
query['enabled'] = True
-
+
# Sort by priority (highest first)
cursor = self.rules_collection.find(query).sort('priority', -1).skip(skip).limit(limit)
rules = []
-
+
async for rule_data in cursor:
rules.append(RateLimitRule.from_dict(rule_data))
-
+
return rules
-
- async def update_rule(self, rule_id: str, updates: Dict[str, Any]) -> Optional[RateLimitRule]:
+
+ async def update_rule(self, rule_id: str, updates: dict[str, Any]) -> RateLimitRule | None:
"""
Update rule
-
+
Args:
rule_id: Rule identifier
updates: Dictionary of fields to update
-
+
Returns:
Updated rule or None if not found
"""
# Add updated timestamp
updates['updated_at'] = datetime.now().isoformat()
-
+
result = await self.rules_collection.find_one_and_update(
- {'rule_id': rule_id},
- {'$set': updates},
- return_document=True
+ {'rule_id': rule_id}, {'$set': updates}, return_document=True
)
-
+
if result:
- logger.info(f"Updated rate limit rule: {rule_id}")
+ logger.info(f'Updated rate limit rule: {rule_id}')
return RateLimitRule.from_dict(result)
-
+
return None
-
+
async def delete_rule(self, rule_id: str) -> bool:
"""
Delete rule
-
+
Args:
rule_id: Rule identifier
-
+
Returns:
True if deleted, False if not found
"""
result = await self.rules_collection.delete_one({'rule_id': rule_id})
-
+
if result.deleted_count > 0:
- logger.info(f"Deleted rate limit rule: {rule_id}")
+ logger.info(f'Deleted rate limit rule: {rule_id}')
return True
-
+
return False
-
- async def enable_rule(self, rule_id: str) -> Optional[RateLimitRule]:
+
+ async def enable_rule(self, rule_id: str) -> RateLimitRule | None:
"""
Enable a rule
-
+
Args:
rule_id: Rule identifier
-
+
Returns:
Updated rule or None if not found
"""
return await self.update_rule(rule_id, {'enabled': True})
-
- async def disable_rule(self, rule_id: str) -> Optional[RateLimitRule]:
+
+ async def disable_rule(self, rule_id: str) -> RateLimitRule | None:
"""
Disable a rule
-
+
Args:
rule_id: Rule identifier
-
+
Returns:
Updated rule or None if not found
"""
return await self.update_rule(rule_id, {'enabled': False})
-
+
# ========================================================================
# RULE QUERIES
# ========================================================================
-
+
async def get_applicable_rules(
self,
- user_id: Optional[str] = None,
- api_name: Optional[str] = None,
- endpoint_uri: Optional[str] = None,
- ip_address: Optional[str] = None
- ) -> List[RateLimitRule]:
+ user_id: str | None = None,
+ api_name: str | None = None,
+ endpoint_uri: str | None = None,
+ ip_address: str | None = None,
+ ) -> list[RateLimitRule]:
"""
Get all applicable rules for a request
-
+
Args:
user_id: User identifier
api_name: API name
endpoint_uri: Endpoint URI
ip_address: IP address
-
+
Returns:
List of applicable rules sorted by priority
"""
query = {'enabled': True}
-
+
# Build OR query for applicable rules
or_conditions = []
-
+
# Global rules always apply
or_conditions.append({'rule_type': RuleType.GLOBAL.value})
-
+
# Per-user rules
if user_id:
- or_conditions.append({
- 'rule_type': RuleType.PER_USER.value,
- 'target_identifier': user_id
- })
-
+ or_conditions.append(
+ {'rule_type': RuleType.PER_USER.value, 'target_identifier': user_id}
+ )
+
# Per-API rules
if api_name:
- or_conditions.append({
- 'rule_type': RuleType.PER_API.value,
- 'target_identifier': api_name
- })
-
+ or_conditions.append(
+ {'rule_type': RuleType.PER_API.value, 'target_identifier': api_name}
+ )
+
# Per-endpoint rules
if endpoint_uri:
- or_conditions.append({
- 'rule_type': RuleType.PER_ENDPOINT.value,
- 'target_identifier': endpoint_uri
- })
-
+ or_conditions.append(
+ {'rule_type': RuleType.PER_ENDPOINT.value, 'target_identifier': endpoint_uri}
+ )
+
# Per-IP rules
if ip_address:
- or_conditions.append({
- 'rule_type': RuleType.PER_IP.value,
- 'target_identifier': ip_address
- })
-
+ or_conditions.append(
+ {'rule_type': RuleType.PER_IP.value, 'target_identifier': ip_address}
+ )
+
if or_conditions:
query['$or'] = or_conditions
-
+
# Get rules sorted by priority
cursor = self.rules_collection.find(query).sort('priority', -1)
rules = []
-
+
async for rule_data in cursor:
rules.append(RateLimitRule.from_dict(rule_data))
-
+
return rules
-
- async def search_rules(self, search_term: str) -> List[RateLimitRule]:
+
+ async def search_rules(self, search_term: str) -> list[RateLimitRule]:
"""
Search rules by ID, description, or target identifier
-
+
Args:
search_term: Search term
-
+
Returns:
List of matching rules
"""
@@ -282,142 +277,140 @@ class RateLimitRuleService:
'$or': [
{'rule_id': {'$regex': search_term, '$options': 'i'}},
{'description': {'$regex': search_term, '$options': 'i'}},
- {'target_identifier': {'$regex': search_term, '$options': 'i'}}
+ {'target_identifier': {'$regex': search_term, '$options': 'i'}},
]
}
-
+
cursor = self.rules_collection.find(query).sort('priority', -1)
rules = []
-
+
async for rule_data in cursor:
rules.append(RateLimitRule.from_dict(rule_data))
-
+
return rules
-
+
# ========================================================================
# BULK OPERATIONS
# ========================================================================
-
- async def bulk_create_rules(self, rules: List[RateLimitRule]) -> int:
+
+ async def bulk_create_rules(self, rules: list[RateLimitRule]) -> int:
"""
Create multiple rules at once
-
+
Args:
rules: List of rules to create
-
+
Returns:
Number of rules created
"""
if not rules:
return 0
-
+
# Set timestamps
now = datetime.now()
for rule in rules:
rule.created_at = now
rule.updated_at = now
-
+
# Insert all rules
rule_dicts = [rule.to_dict() for rule in rules]
result = await self.rules_collection.insert_many(rule_dicts)
-
+
count = len(result.inserted_ids)
- logger.info(f"Bulk created {count} rate limit rules")
+ logger.info(f'Bulk created {count} rate limit rules')
return count
-
- async def bulk_delete_rules(self, rule_ids: List[str]) -> int:
+
+ async def bulk_delete_rules(self, rule_ids: list[str]) -> int:
"""
Delete multiple rules at once
-
+
Args:
rule_ids: List of rule IDs to delete
-
+
Returns:
Number of rules deleted
"""
if not rule_ids:
return 0
-
- result = await self.rules_collection.delete_many({
- 'rule_id': {'$in': rule_ids}
- })
-
+
+ result = await self.rules_collection.delete_many({'rule_id': {'$in': rule_ids}})
+
count = result.deleted_count
- logger.info(f"Bulk deleted {count} rate limit rules")
+ logger.info(f'Bulk deleted {count} rate limit rules')
return count
-
- async def bulk_enable_rules(self, rule_ids: List[str]) -> int:
+
+ async def bulk_enable_rules(self, rule_ids: list[str]) -> int:
"""
Enable multiple rules at once
-
+
Args:
rule_ids: List of rule IDs to enable
-
+
Returns:
Number of rules enabled
"""
if not rule_ids:
return 0
-
+
result = await self.rules_collection.update_many(
{'rule_id': {'$in': rule_ids}},
- {'$set': {'enabled': True, 'updated_at': datetime.now().isoformat()}}
+ {'$set': {'enabled': True, 'updated_at': datetime.now().isoformat()}},
)
-
+
count = result.modified_count
- logger.info(f"Bulk enabled {count} rate limit rules")
+ logger.info(f'Bulk enabled {count} rate limit rules')
return count
-
- async def bulk_disable_rules(self, rule_ids: List[str]) -> int:
+
+ async def bulk_disable_rules(self, rule_ids: list[str]) -> int:
"""
Disable multiple rules at once
-
+
Args:
rule_ids: List of rule IDs to disable
-
+
Returns:
Number of rules disabled
"""
if not rule_ids:
return 0
-
+
result = await self.rules_collection.update_many(
{'rule_id': {'$in': rule_ids}},
- {'$set': {'enabled': False, 'updated_at': datetime.now().isoformat()}}
+ {'$set': {'enabled': False, 'updated_at': datetime.now().isoformat()}},
)
-
+
count = result.modified_count
- logger.info(f"Bulk disabled {count} rate limit rules")
+ logger.info(f'Bulk disabled {count} rate limit rules')
return count
-
+
# ========================================================================
# RULE DUPLICATION
# ========================================================================
-
+
async def duplicate_rule(self, rule_id: str, new_rule_id: str) -> RateLimitRule:
"""
Duplicate an existing rule
-
+
Args:
rule_id: Source rule ID
new_rule_id: New rule ID
-
+
Returns:
Duplicated rule
-
+
Raises:
ValueError: If source rule not found or new ID already exists
"""
# Get source rule
source_rule = await self.get_rule(rule_id)
if not source_rule:
- raise ValueError(f"Source rule {rule_id} not found")
-
+ raise ValueError(f'Source rule {rule_id} not found')
+
# Check if new ID already exists
existing = await self.rules_collection.find_one({'rule_id': new_rule_id})
if existing:
- raise ValueError(f"Rule with ID {new_rule_id} already exists")
-
+ raise ValueError(f'Rule with ID {new_rule_id} already exists')
+
# Create new rule with same properties
new_rule = RateLimitRule(
rule_id=new_rule_id,
@@ -428,90 +421,93 @@ class RateLimitRuleService:
burst_allowance=source_rule.burst_allowance,
priority=source_rule.priority,
enabled=source_rule.enabled,
- description=f"Copy of {source_rule.rule_id}"
+ description=f'Copy of {source_rule.rule_id}',
)
-
+
return await self.create_rule(new_rule)
-
+
# ========================================================================
# RULE STATISTICS
# ========================================================================
-
- async def get_rule_statistics(self) -> Dict[str, Any]:
+
+ async def get_rule_statistics(self) -> dict[str, Any]:
"""
Get statistics about rate limit rules
-
+
Returns:
Dictionary with rule statistics
"""
total_rules = await self.rules_collection.count_documents({})
enabled_rules = await self.rules_collection.count_documents({'enabled': True})
disabled_rules = total_rules - enabled_rules
-
+
# Count by type
type_counts = {}
for rule_type in RuleType:
- count = await self.rules_collection.count_documents({
- 'rule_type': rule_type.value
- })
+ count = await self.rules_collection.count_documents({'rule_type': rule_type.value})
type_counts[rule_type.value] = count
-
+
return {
'total_rules': total_rules,
'enabled_rules': enabled_rules,
'disabled_rules': disabled_rules,
- 'rules_by_type': type_counts
+ 'rules_by_type': type_counts,
}
-
+
# ========================================================================
# RULE VALIDATION
# ========================================================================
-
- def validate_rule(self, rule: RateLimitRule) -> List[str]:
+
+ def validate_rule(self, rule: RateLimitRule) -> list[str]:
"""
Validate a rule
-
+
Args:
rule: Rule to validate
-
+
Returns:
List of validation errors (empty if valid)
"""
errors = []
-
+
# Check limit is positive
if rule.limit <= 0:
- errors.append("Limit must be greater than 0")
-
+ errors.append('Limit must be greater than 0')
+
# Check burst allowance is non-negative
if rule.burst_allowance < 0:
- errors.append("Burst allowance cannot be negative")
-
+ errors.append('Burst allowance cannot be negative')
+
# Check target identifier for specific rule types
- if rule.rule_type in [RuleType.PER_USER, RuleType.PER_API, RuleType.PER_ENDPOINT, RuleType.PER_IP]:
+ if rule.rule_type in [
+ RuleType.PER_USER,
+ RuleType.PER_API,
+ RuleType.PER_ENDPOINT,
+ RuleType.PER_IP,
+ ]:
if not rule.target_identifier:
- errors.append(f"Target identifier required for {rule.rule_type.value} rules")
-
+ errors.append(f'Target identifier required for {rule.rule_type.value} rules')
+
return errors
# Global rule service instance
-_rule_service: Optional[RateLimitRuleService] = None
+_rule_service: RateLimitRuleService | None = None
def get_rate_limit_rule_service(db: AsyncIOMotorDatabase) -> RateLimitRuleService:
"""
Get or create global rule service instance
-
+
Args:
db: MongoDB database instance
-
+
Returns:
RateLimitRuleService instance
"""
global _rule_service
-
+
if _rule_service is None:
_rule_service = RateLimitRuleService(db)
-
+
return _rule_service
diff --git a/backend-services/services/role_service.py b/backend-services/services/role_service.py
index 124d895..2ab4908 100644
--- a/backend-services/services/role_service.py
+++ b/backend-services/services/role_service.py
@@ -4,23 +4,23 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from pymongo.errors import DuplicateKeyError
import logging
+from pymongo.errors import DuplicateKeyError
+
+from models.create_role_model import CreateRoleModel
from models.response_model import ResponseModel
from models.update_role_model import UpdateRoleModel
-from utils.database_async import role_collection
-from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list
-from utils.cache_manager_util import cache_manager
-from utils.doorman_cache_util import doorman_cache
-from models.create_role_model import CreateRoleModel
-from utils.paging_util import validate_page_params
+from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one
from utils.constants import ErrorCodes, Messages
+from utils.database_async import role_collection
+from utils.doorman_cache_util import doorman_cache
+from utils.paging_util import validate_page_params
logger = logging.getLogger('doorman.gateway')
-class RoleService:
+class RoleService:
@staticmethod
async def create_role(data: CreateRoleModel, request_id):
"""
@@ -31,11 +31,9 @@ class RoleService:
logger.error(request_id + ' | Role creation failed with code ROLE001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='ROLE001',
- error_message='Role already exists'
+ error_message='Role already exists',
).dict()
role_dict = data.dict()
try:
@@ -43,26 +41,19 @@ class RoleService:
if not insert_result.acknowledged:
logger.error(request_id + ' | Role creation failed with code ROLE002')
return ResponseModel(
- status_code=400,
- error_code='ROLE002',
- error_message='Unable to insert role'
+ status_code=400, error_code='ROLE002', error_message='Unable to insert role'
).dict()
role_dict['_id'] = str(insert_result.inserted_id)
doorman_cache.set_cache('role_cache', data.role_name, role_dict)
logger.info(request_id + ' | Role creation successful')
- return ResponseModel(
- status_code=201,
- message='Role created successfully'
- ).dict()
- except DuplicateKeyError as e:
+ return ResponseModel(status_code=201, message='Role created successfully').dict()
+ except DuplicateKeyError:
logger.error(request_id + ' | Role creation failed with code ROLE001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='ROLE001',
- error_message='Role already exists'
+ error_message='Role already exists',
).dict()
@staticmethod
@@ -75,34 +66,29 @@ class RoleService:
logger.error(request_id + ' | Role update failed with code ROLE005')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='ROLE005',
- error_message='Role name cannot be changed'
+ error_message='Role name cannot be changed',
).dict()
role = doorman_cache.get_cache('role_cache', role_name)
if not role:
- role = await db_find_one(role_collection, {
- 'role_name': role_name
- })
+ role = await db_find_one(role_collection, {'role_name': role_name})
if not role:
logger.error(request_id + ' | Role update failed with code ROLE004')
return ResponseModel(
- status_code=400,
- error_code='ROLE004',
- error_message='Role does not exist'
+ status_code=400, error_code='ROLE004', error_message='Role does not exist'
).dict()
else:
doorman_cache.delete_cache('role_cache', role_name)
not_null_data = {k: v for k, v in data.dict().items() if v is not None}
if not_null_data:
try:
- update_result = await db_update_one(role_collection, {'role_name': role_name}, {'$set': not_null_data})
+ update_result = await db_update_one(
+ role_collection, {'role_name': role_name}, {'$set': not_null_data}
+ )
if update_result.modified_count > 0:
doorman_cache.delete_cache('role_cache', role_name)
if not update_result.acknowledged or update_result.modified_count == 0:
-
current = await db_find_one(role_collection, {'role_name': role_name}) or {}
is_applied = all(current.get(k) == v for k, v in not_null_data.items())
if not is_applied:
@@ -110,31 +96,30 @@ class RoleService:
return ResponseModel(
status_code=400,
error_code='ROLE006',
- error_message='Unable to update role'
+ error_message='Unable to update role',
).dict()
except Exception as e:
doorman_cache.delete_cache('role_cache', role_name)
- logger.error(request_id + ' | Role update failed with exception: ' + str(e), exc_info=True)
+ logger.error(
+ request_id + ' | Role update failed with exception: ' + str(e), exc_info=True
+ )
raise
updated_role = await db_find_one(role_collection, {'role_name': role_name}) or {}
- if updated_role.get('_id'): del updated_role['_id']
+ if updated_role.get('_id'):
+ del updated_role['_id']
doorman_cache.set_cache('role_cache', role_name, updated_role)
logger.info(request_id + ' | Role update successful')
return ResponseModel(
- status_code=200,
- response=updated_role,
- message='Role updated successfully'
+ status_code=200, response=updated_role, message='Role updated successfully'
).dict()
else:
logger.error(request_id + ' | Role update failed with code ROLE007')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='ROLE007',
- error_message='No data to update'
+ error_message='No data to update',
).dict()
@staticmethod
@@ -149,9 +134,7 @@ class RoleService:
if not role:
logger.error(request_id + ' | Role deletion failed with code ROLE004')
return ResponseModel(
- status_code=400,
- error_code='ROLE004',
- error_message='Role does not exist'
+ status_code=400, error_code='ROLE004', error_message='Role does not exist'
).dict()
else:
doorman_cache.delete_cache('role_cache', role_name)
@@ -160,19 +143,15 @@ class RoleService:
logger.error(request_id + ' | Role deletion failed with code ROLE008')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='ROLE008',
- error_message='Unable to delete role'
+ error_message='Unable to delete role',
).dict()
logger.info(request_id + ' | Role Deletion Successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='Role deleted successfully'
+ response_headers={'request_id': request_id},
+ message='Role deleted successfully',
).dict()
@staticmethod
@@ -180,7 +159,9 @@ class RoleService:
"""
Check if a role exists.
"""
- if doorman_cache.get_cache('role_cache', data.get('role_name')) or await db_find_one(role_collection, {'role_name': data.get('role_name')}):
+ if doorman_cache.get_cache('role_cache', data.get('role_name')) or await db_find_one(
+ role_collection, {'role_name': data.get('role_name')}
+ ):
return True
return False
@@ -189,26 +170,28 @@ class RoleService:
"""
Get all roles.
"""
- logger.info(request_id + ' | Getting roles: Page=' + str(page) + ' Page Size=' + str(page_size))
+ logger.info(
+ request_id + ' | Getting roles: Page=' + str(page) + ' Page Size=' + str(page_size)
+ )
try:
page, page_size = validate_page_params(page, page_size)
except Exception as e:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING
+ ),
).dict()
skip = (page - 1) * page_size
roles_all = await db_find_list(role_collection, {})
roles_all.sort(key=lambda r: r.get('role_name'))
- roles = roles_all[skip: skip + page_size]
+ roles = roles_all[skip : skip + page_size]
for role in roles:
- if role.get('_id'): del role['_id']
+ if role.get('_id'):
+ del role['_id']
logger.info(request_id + ' | Roles retrieval successful')
- return ResponseModel(
- status_code=200,
- response={'roles': roles}
- ).dict()
+ return ResponseModel(status_code=200, response={'roles': roles}).dict()
@staticmethod
async def get_role(role_name, request_id):
@@ -222,15 +205,12 @@ class RoleService:
if not role:
logger.error(request_id + ' | Role retrieval failed with code ROLE004')
return ResponseModel(
- status_code=404,
- error_code='ROLE004',
- error_message='Role does not exist'
+ status_code=404, error_code='ROLE004', error_message='Role does not exist'
).dict()
- if role.get('_id'): del role['_id']
+ if role.get('_id'):
+ del role['_id']
doorman_cache.set_cache('role_cache', role_name, role)
- if role.get('_id'): del role['_id']
+ if role.get('_id'):
+ del role['_id']
logger.info(request_id + ' | Role retrieval successful')
- return ResponseModel(
- status_code=200,
- response=role
- ).dict()
+ return ResponseModel(status_code=200, response=role).dict()
diff --git a/backend-services/services/routing_service.py b/backend-services/services/routing_service.py
index 1619154..ffa571f 100644
--- a/backend-services/services/routing_service.py
+++ b/backend-services/services/routing_service.py
@@ -4,22 +4,23 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from pymongo.errors import DuplicateKeyError
-import uuid
import logging
+import uuid
+
+from pymongo.errors import DuplicateKeyError
-from models.response_model import ResponseModel
from models.create_routing_model import CreateRoutingModel
+from models.response_model import ResponseModel
from models.update_routing_model import UpdateRoutingModel
+from utils.constants import ErrorCodes, Messages
from utils.database import routing_collection
from utils.doorman_cache_util import doorman_cache
from utils.paging_util import validate_page_params
-from utils.constants import ErrorCodes, Messages
logger = logging.getLogger('doorman.gateway')
-class RoutingService:
+class RoutingService:
@staticmethod
async def create_routing(data: CreateRoutingModel, request_id):
"""
@@ -31,11 +32,9 @@ class RoutingService:
logger.error(request_id + ' | Routing creation failed with code RTG001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='RTG001',
- error_message='Routing already exists'
+ error_message='Routing already exists',
).dict()
routing_dict = data.dict()
try:
@@ -43,26 +42,21 @@ class RoutingService:
if not insert_result.acknowledged:
logger.error(request_id + ' | Routing creation failed with code RTG002')
return ResponseModel(
- status_code=400,
- error_code='RTG002',
- error_message='Unable to insert routing'
+ status_code=400, error_code='RTG002', error_message='Unable to insert routing'
).dict()
routing_dict['_id'] = str(insert_result.inserted_id)
doorman_cache.set_cache('client_routing_cache', data.client_key, routing_dict)
logger.info(request_id + ' | Routing creation successful')
return ResponseModel(
- status_code=201,
- message='Routing created successfully with key: ' + data.client_key,
+ status_code=201, message='Routing created successfully with key: ' + data.client_key
).dict()
- except DuplicateKeyError as e:
+ except DuplicateKeyError:
logger.error(request_id + ' | Routing creation failed with code RTG001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='RTG001',
- error_message='Routing already exists'
+ error_message='Routing already exists',
).dict()
@staticmethod
@@ -75,50 +69,39 @@ class RoutingService:
logger.error(request_id + ' | Role update failed with code ROLE005')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='RTG005',
- error_message='Routing key cannot be changed'
+ error_message='Routing key cannot be changed',
).dict()
routing = doorman_cache.get_cache('client_routing_cache', client_key)
if not routing:
- routing = routing_collection.find_one({
- 'client_key': client_key
- })
+ routing = routing_collection.find_one({'client_key': client_key})
if not routing:
logger.error(request_id + ' | Routing update failed with code RTG004')
return ResponseModel(
- status_code=400,
- error_code='RTG004',
- error_message='Routing does not exist'
+ status_code=400, error_code='RTG004', error_message='Routing does not exist'
).dict()
else:
doorman_cache.delete_cache('client_routing_cache', client_key)
not_null_data = {k: v for k, v in data.dict().items() if v is not None}
if not_null_data:
- update_result = routing_collection.update_one({'client_key': client_key}, {'$set': not_null_data})
+ update_result = routing_collection.update_one(
+ {'client_key': client_key}, {'$set': not_null_data}
+ )
if not update_result.acknowledged or update_result.modified_count == 0:
logger.error(request_id + ' | Routing update failed with code RTG006')
return ResponseModel(
- status_code=400,
- error_code='RTG006',
- error_message='Unable to update routing'
+ status_code=400, error_code='RTG006', error_message='Unable to update routing'
).dict()
logger.info(request_id + ' | Routing update successful')
- return ResponseModel(
- status_code=200,
- message='Routing updated successfully'
- ).dict()
+ return ResponseModel(status_code=200, message='Routing updated successfully').dict()
else:
logger.error(request_id + ' | Routing update failed with code RTG007')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='RTG007',
- error_message='No data to update'
+ error_message='No data to update',
).dict()
@staticmethod
@@ -129,15 +112,11 @@ class RoutingService:
logger.info(request_id + ' | Deleting: ' + client_key)
routing = doorman_cache.get_cache('client_routing_cache', client_key)
if not routing:
- routing = routing_collection.find_one({
- 'client_key': client_key
- })
+ routing = routing_collection.find_one({'client_key': client_key})
if not routing:
logger.error(request_id + ' | Routing deletion failed with code RTG004')
return ResponseModel(
- status_code=400,
- error_code='RTG004',
- error_message='Routing does not exist'
+ status_code=400, error_code='RTG004', error_message='Routing does not exist'
).dict()
else:
doorman_cache.delete_cache('client_routing_cache', client_key)
@@ -146,19 +125,15 @@ class RoutingService:
logger.error(request_id + ' | Routing deletion failed with code RTG008')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='RTG008',
- error_message='Unable to delete routing'
+ error_message='Unable to delete routing',
).dict()
logger.info(request_id + ' | Routing deletion successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='Routing deleted successfully'
+ response_headers={'request_id': request_id},
+ message='Routing deleted successfully',
).dict()
@staticmethod
@@ -169,44 +144,40 @@ class RoutingService:
logger.info(request_id + ' | Getting: ' + client_key)
routing = doorman_cache.get_cache('client_routing_cache', client_key)
if not routing:
- routing = routing_collection.find_one({
- 'client_key': client_key
- })
+ routing = routing_collection.find_one({'client_key': client_key})
if not routing:
logger.error(request_id + ' | Routing retrieval failed with code RTG004')
return ResponseModel(
- status_code=400,
- error_code='RTG004',
- error_message='Routing does not exist'
+ status_code=400, error_code='RTG004', error_message='Routing does not exist'
).dict()
logger.info(request_id + ' | Routing retrieval successful')
- if routing.get('_id'): del routing['_id']
- return ResponseModel(
- status_code=200,
- response=routing
- ).dict()
+ if routing.get('_id'):
+ del routing['_id']
+ return ResponseModel(status_code=200, response=routing).dict()
@staticmethod
async def get_routings(page=1, page_size=10, request_id=None):
"""
Get all routings.
"""
- logger.info(request_id + ' | Getting routings: Page=' + str(page) + ' Page Size=' + str(page_size))
+ logger.info(
+ request_id + ' | Getting routings: Page=' + str(page) + ' Page Size=' + str(page_size)
+ )
try:
page, page_size = validate_page_params(page, page_size)
except Exception as e:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING
+ ),
).dict()
skip = (page - 1) * page_size
cursor = routing_collection.find().sort('client_key', 1).skip(skip).limit(page_size)
routings = cursor.to_list(length=None)
for route in routings:
- if route.get('_id'): del route['_id']
+ if route.get('_id'):
+ del route['_id']
logger.info(request_id + ' | Routing retrieval successful')
- return ResponseModel(
- status_code=200,
- response={'routings': routings}
- ).dict()
+ return ResponseModel(status_code=200, response={'routings': routings}).dict()
diff --git a/backend-services/services/subscription_service.py b/backend-services/services/subscription_service.py
index 59cf735..cbdcbd5 100644
--- a/backend-services/services/subscription_service.py
+++ b/backend-services/services/subscription_service.py
@@ -7,15 +7,14 @@ See https://github.com/pypeople-dev/doorman for more information
import logging
from models.response_model import ResponseModel
-from utils.database import subscriptions_collection, api_collection
-from utils.cache_manager_util import cache_manager
-from utils.doorman_cache_util import doorman_cache
from models.subscribe_model import SubscribeModel
+from utils.database import api_collection, subscriptions_collection
+from utils.doorman_cache_util import doorman_cache
logger = logging.getLogger('doorman.gateway')
-class SubscriptionService:
+class SubscriptionService:
@staticmethod
async def api_exists(api_name, api_version):
"""
@@ -40,12 +39,15 @@ class SubscriptionService:
if not api.get('api_id'):
# Ensure api_id exists for consistent caching
import uuid as _uuid
+
api['api_id'] = api.get('api_id') or str(_uuid.uuid4())
except Exception:
pass
doorman_cache.set_cache('api_cache', path_key, api)
try:
- doorman_cache.set_cache('api_id_cache', f'/{api_name}/{api_version}', api.get('api_id'))
+ doorman_cache.set_cache(
+ 'api_id_cache', f'/{api_name}/{api_version}', api.get('api_id')
+ )
except Exception:
pass
if api and '_id' in api:
@@ -60,16 +62,16 @@ class SubscriptionService:
logger.info(f'{request_id} | Getting subscriptions for: {username}')
subscriptions = doorman_cache.get_cache('user_subscription_cache', username)
- if not subscriptions or not isinstance(subscriptions, dict) or not subscriptions.get('apis'):
-
+ if (
+ not subscriptions
+ or not isinstance(subscriptions, dict)
+ or not subscriptions.get('apis')
+ ):
subscriptions = subscriptions_collection.find_one({'username': username})
if not subscriptions:
logger.info(f'{request_id} | No subscriptions found; returning empty list')
- return ResponseModel(
- status_code=200,
- response={'apis': []}
- ).dict()
+ return ResponseModel(status_code=200, response={'apis': []}).dict()
if subscriptions.get('_id'):
del subscriptions['_id']
@@ -78,8 +80,7 @@ class SubscriptionService:
apis = subscriptions.get('apis', []) if isinstance(subscriptions, dict) else []
logger.info(f'{request_id} | Subscriptions retrieved successfully')
return ResponseModel(
- status_code=200,
- response={'apis': apis, 'subscriptions': {'apis': apis}}
+ status_code=200, response={'apis': apis, 'subscriptions': {'apis': apis}}
).dict()
@staticmethod
@@ -87,17 +88,17 @@ class SubscriptionService:
"""
Subscribe to an API.
"""
- logger.info(f'{request_id} | Subscribing {data.username} to API: {data.api_name}/{data.api_version}')
+ logger.info(
+ f'{request_id} | Subscribing {data.username} to API: {data.api_name}/{data.api_version}'
+ )
api = await SubscriptionService.api_exists(data.api_name, data.api_version)
if not api:
logger.error(f'{request_id} | Subscription failed with code SUB003')
return ResponseModel(
status_code=404,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='SUB003',
- error_message='API does not exist for the requested name and version'
+ error_message='API does not exist for the requested name and version',
).dict()
doorman_cache.delete_cache('user_subscription_cache', data.username)
@@ -107,23 +108,24 @@ class SubscriptionService:
if user_subscriptions is None:
user_subscriptions = {
'username': data.username,
- 'apis': [f'{data.api_name}/{data.api_version}']
+ 'apis': [f'{data.api_name}/{data.api_version}'],
}
subscriptions_collection.insert_one(user_subscriptions)
- elif 'apis' in user_subscriptions and f'{data.api_name}/{data.api_version}' in user_subscriptions['apis']:
+ elif (
+ 'apis' in user_subscriptions
+ and f'{data.api_name}/{data.api_version}' in user_subscriptions['apis']
+ ):
logger.error(f'{request_id} | Subscription failed with code SUB004')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='SUB004',
- error_message='User is already subscribed to the API'
+ error_message='User is already subscribed to the API',
).dict()
else:
subscriptions_collection.update_one(
{'username': data.username},
- {'$push': {'apis': f'{data.api_name}/{data.api_version}'}}
+ {'$push': {'apis': f'{data.api_name}/{data.api_version}'}},
)
user_subscriptions = subscriptions_collection.find_one({'username': data.username})
@@ -133,8 +135,7 @@ class SubscriptionService:
doorman_cache.set_cache('user_subscription_cache', data.username, user_subscriptions)
logger.info(f'{request_id} | Subscription successful')
return ResponseModel(
- status_code=200,
- response={'message': 'Successfully subscribed to the API'}
+ status_code=200, response={'message': 'Successfully subscribed to the API'}
).dict()
@staticmethod
@@ -142,36 +143,36 @@ class SubscriptionService:
"""
Unsubscribe from an API.
"""
- logger.info(f'{request_id} | Unsubscribing {data.username} from API: {data.api_name}/{data.api_version}')
+ logger.info(
+ f'{request_id} | Unsubscribing {data.username} from API: {data.api_name}/{data.api_version}'
+ )
api = await SubscriptionService.api_exists(data.api_name, data.api_version)
if not api:
return ResponseModel(
status_code=404,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='SUB005',
- error_message='API does not exist for the requested name and version'
+ error_message='API does not exist for the requested name and version',
).dict()
doorman_cache.delete_cache('user_subscription_cache', data.username)
user_subscriptions = subscriptions_collection.find_one({'username': data.username})
if user_subscriptions and '_id' in user_subscriptions:
del user_subscriptions['_id']
- if not user_subscriptions or f'{data.api_name}/{data.api_version}' not in user_subscriptions.get('apis', []):
+ if (
+ not user_subscriptions
+ or f'{data.api_name}/{data.api_version}' not in user_subscriptions.get('apis', [])
+ ):
logger.error(f'{request_id} | Unsubscription failed with code SUB006')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='SUB006',
- error_message='User is not subscribed to the API'
+ error_message='User is not subscribed to the API',
).dict()
user_subscriptions['apis'].remove(f'{data.api_name}/{data.api_version}')
subscriptions_collection.update_one(
- {'username': data.username},
- {'$set': {'apis': user_subscriptions.get('apis', [])}}
+ {'username': data.username}, {'$set': {'apis': user_subscriptions.get('apis', [])}}
)
user_subscriptions = subscriptions_collection.find_one({'username': data.username})
@@ -181,6 +182,5 @@ class SubscriptionService:
doorman_cache.set_cache('user_subscription_cache', data.username, user_subscriptions)
logger.info(f'{request_id} | Unsubscription successful')
return ResponseModel(
- status_code=200,
- response={'message': 'Successfully unsubscribed from the API'}
+ status_code=200, response={'message': 'Successfully unsubscribed from the API'}
).dict()
diff --git a/backend-services/services/tier_service.py b/backend-services/services/tier_service.py
index 608a9cc..387e2db 100644
--- a/backend-services/services/tier_service.py
+++ b/backend-services/services/tier_service.py
@@ -6,15 +6,10 @@ Handles tier CRUD, user assignments, upgrades, downgrades, and transitions.
"""
import logging
-from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
+from typing import Any
-from models.rate_limit_models import (
- Tier,
- TierLimits,
- TierName,
- UserTierAssignment
-)
+from models.rate_limit_models import Tier, TierLimits, TierName, UserTierAssignment
logger = logging.getLogger(__name__)
@@ -22,7 +17,7 @@ logger = logging.getLogger(__name__)
class TierService:
"""
Service for managing tiers and user assignments
-
+
Features:
- CRUD operations for tiers
- User-to-tier assignments
@@ -31,204 +26,203 @@ class TierService:
- Temporary tier assignments
- Tier comparison and selection
"""
-
+
def __init__(self, db):
"""
Initialize tier service
-
+
Args:
db: MongoDB database instance (sync) or InMemoryDB
"""
self.db = db
self.tiers_collection = db.tiers
self.assignments_collection = db.user_tier_assignments
-
+
# ========================================================================
# TIER CRUD OPERATIONS
# ========================================================================
-
+
async def create_tier(self, tier: Tier) -> Tier:
"""
Create a new tier
-
+
Args:
tier: Tier object to create
-
+
Returns:
Created tier
-
+
Raises:
ValueError: If tier with same ID already exists
"""
# Check if tier already exists
existing = await self.tiers_collection.find_one({'tier_id': tier.tier_id})
if existing:
- raise ValueError(f"Tier with ID {tier.tier_id} already exists")
-
+ raise ValueError(f'Tier with ID {tier.tier_id} already exists')
+
# Set timestamps
tier.created_at = datetime.now()
tier.updated_at = datetime.now()
-
+
# Insert into database
await self.tiers_collection.insert_one(tier.to_dict())
-
- logger.info(f"Created tier: {tier.tier_id}")
+
+ logger.info(f'Created tier: {tier.tier_id}')
return tier
-
- async def get_tier(self, tier_id: str) -> Optional[Tier]:
+
+ async def get_tier(self, tier_id: str) -> Tier | None:
"""
Get tier by ID
-
+
Args:
tier_id: Tier identifier
-
+
Returns:
Tier object or None if not found
"""
tier_data = await self.tiers_collection.find_one({'tier_id': tier_id})
-
+
if tier_data:
return Tier.from_dict(tier_data)
-
+
return None
-
- async def get_tier_by_name(self, name: TierName) -> Optional[Tier]:
+
+ async def get_tier_by_name(self, name: TierName) -> Tier | None:
"""
Get tier by name
-
+
Args:
name: Tier name enum
-
+
Returns:
Tier object or None if not found
"""
tier_data = await self.tiers_collection.find_one({'name': name.value})
-
+
if tier_data:
return Tier.from_dict(tier_data)
-
+
return None
-
+
async def list_tiers(
self,
enabled_only: bool = False,
- search_term: Optional[str] = None,
+ search_term: str | None = None,
skip: int = 0,
- limit: int = 100
- ) -> List[Tier]:
+ limit: int = 100,
+ ) -> list[Tier]:
"""
List all tiers
-
+
Args:
enabled_only: Only return enabled tiers
skip: Number of records to skip
limit: Maximum number of records to return
-
+
Returns:
List of tiers
"""
query = {}
if enabled_only:
query['enabled'] = True
-
+
cursor = self.tiers_collection.find(query).skip(skip).limit(limit)
- tiers: List[Tier] = []
-
+ tiers: list[Tier] = []
+
async for tier_data in cursor:
tiers.append(Tier.from_dict(tier_data))
-
+
if search_term:
term = search_term.lower()
tiers = [
- tier for tier in tiers
+ tier
+ for tier in tiers
if term in (tier.name or '').lower()
or term in (tier.display_name or '').lower()
or term in (tier.description or '').lower()
]
-
+
return tiers
-
- async def update_tier(self, tier_id: str, updates: Dict[str, Any]) -> Optional[Tier]:
+
+ async def update_tier(self, tier_id: str, updates: dict[str, Any]) -> Tier | None:
"""
Update tier
-
+
Args:
tier_id: Tier identifier
updates: Dictionary of fields to update
-
+
Returns:
Updated tier or None if not found
"""
# Add updated timestamp
updates['updated_at'] = datetime.now().isoformat()
-
+
result = await self.tiers_collection.find_one_and_update(
- {'tier_id': tier_id},
- {'$set': updates},
- return_document=True
+ {'tier_id': tier_id}, {'$set': updates}, return_document=True
)
-
+
if result:
- logger.info(f"Updated tier: {tier_id}")
+ logger.info(f'Updated tier: {tier_id}')
return Tier.from_dict(result)
-
+
return None
-
+
async def delete_tier(self, tier_id: str) -> bool:
"""
Delete tier
-
+
Args:
tier_id: Tier identifier
-
+
Returns:
True if deleted, False if not found
"""
# Check if any users are assigned to this tier
user_count = await self.assignments_collection.count_documents({'tier_id': tier_id})
-
+
if user_count > 0:
- raise ValueError(f"Cannot delete tier {tier_id}: {user_count} users are assigned to it")
-
+ raise ValueError(f'Cannot delete tier {tier_id}: {user_count} users are assigned to it')
+
result = await self.tiers_collection.delete_one({'tier_id': tier_id})
-
+
if result.deleted_count > 0:
- logger.info(f"Deleted tier: {tier_id}")
+ logger.info(f'Deleted tier: {tier_id}')
return True
-
+
return False
-
- async def get_default_tier(self) -> Optional[Tier]:
+
+ async def get_default_tier(self) -> Tier | None:
"""
Get the default tier
-
+
Returns:
Default tier or None if not set
"""
tier_data = await self.tiers_collection.find_one({'is_default': True})
-
+
if tier_data:
return Tier.from_dict(tier_data)
-
+
return None
-
+
# ========================================================================
# USER TIER ASSIGNMENTS
# ========================================================================
-
+
async def assign_user_to_tier(
self,
user_id: str,
tier_id: str,
- assigned_by: Optional[str] = None,
- effective_from: Optional[datetime] = None,
- effective_until: Optional[datetime] = None,
- override_limits: Optional[TierLimits] = None,
- notes: Optional[str] = None
+ assigned_by: str | None = None,
+ effective_from: datetime | None = None,
+ effective_until: datetime | None = None,
+ override_limits: TierLimits | None = None,
+ notes: str | None = None,
) -> UserTierAssignment:
"""
Assign user to a tier
-
+
Args:
user_id: User identifier
tier_id: Tier identifier
@@ -237,21 +231,21 @@ class TierService:
effective_until: When assignment expires
override_limits: Custom limits for this user
notes: Assignment notes
-
+
Returns:
UserTierAssignment object
-
+
Raises:
ValueError: If tier doesn't exist
"""
# Verify tier exists
tier = await self.get_tier(tier_id)
if not tier:
- raise ValueError(f"Tier {tier_id} not found")
-
+ raise ValueError(f'Tier {tier_id} not found')
+
# Check if user already has an assignment
existing = await self.assignments_collection.find_one({'user_id': user_id})
-
+
assignment = UserTierAssignment(
user_id=user_id,
tier_id=tier_id,
@@ -260,257 +254,249 @@ class TierService:
effective_until=effective_until,
assigned_at=datetime.now(),
assigned_by=assigned_by,
- notes=notes
+ notes=notes,
)
-
+
if existing:
# Update existing assignment
await self.assignments_collection.replace_one(
- {'user_id': user_id},
- assignment.to_dict()
+ {'user_id': user_id}, assignment.to_dict()
)
- logger.info(f"Updated tier assignment for user {user_id} to {tier_id}")
+ logger.info(f'Updated tier assignment for user {user_id} to {tier_id}')
else:
# Create new assignment
await self.assignments_collection.insert_one(assignment.to_dict())
- logger.info(f"Assigned user {user_id} to tier {tier_id}")
-
+ logger.info(f'Assigned user {user_id} to tier {tier_id}')
+
return assignment
-
- async def get_user_assignment(self, user_id: str) -> Optional[UserTierAssignment]:
+
+ async def get_user_assignment(self, user_id: str) -> UserTierAssignment | None:
"""
Get user's tier assignment
-
+
Args:
user_id: User identifier
-
+
Returns:
UserTierAssignment or None if not assigned
"""
assignment_data = await self.assignments_collection.find_one({'user_id': user_id})
-
+
if assignment_data:
return UserTierAssignment(**assignment_data)
-
+
return None
-
- async def get_user_tier(self, user_id: str) -> Optional[Tier]:
+
+ async def get_user_tier(self, user_id: str) -> Tier | None:
"""
Get the tier for a user
-
+
Considers effective dates and returns appropriate tier.
-
+
Args:
user_id: User identifier
-
+
Returns:
Tier object or default tier if no assignment
"""
assignment = await self.get_user_assignment(user_id)
-
+
if assignment:
# Check if assignment is currently effective
now = datetime.now()
-
+
if assignment.effective_from and now < assignment.effective_from:
# Assignment not yet effective
return await self.get_default_tier()
-
+
if assignment.effective_until and now > assignment.effective_until:
# Assignment expired
return await self.get_default_tier()
-
+
# Get the assigned tier
return await self.get_tier(assignment.tier_id)
-
+
# No assignment, return default tier
return await self.get_default_tier()
-
- async def get_user_limits(self, user_id: str) -> Optional[TierLimits]:
+
+ async def get_user_limits(self, user_id: str) -> TierLimits | None:
"""
Get effective limits for a user
-
+
Considers tier limits and user-specific overrides.
-
+
Args:
user_id: User identifier
-
+
Returns:
TierLimits object or None
"""
assignment = await self.get_user_assignment(user_id)
-
+
if assignment and assignment.override_limits:
# User has custom limits
return assignment.override_limits
-
+
# Get tier limits
tier = await self.get_user_tier(user_id)
if tier:
return tier.limits
-
+
return None
-
+
async def remove_user_assignment(self, user_id: str) -> bool:
"""
Remove user's tier assignment
-
+
Args:
user_id: User identifier
-
+
Returns:
True if removed, False if no assignment
"""
result = await self.assignments_collection.delete_one({'user_id': user_id})
-
+
if result.deleted_count > 0:
- logger.info(f"Removed tier assignment for user {user_id}")
+ logger.info(f'Removed tier assignment for user {user_id}')
return True
-
+
return False
-
+
async def list_users_in_tier(
- self,
- tier_id: str,
- skip: int = 0,
- limit: int = 100
- ) -> List[UserTierAssignment]:
+ self, tier_id: str, skip: int = 0, limit: int = 100
+ ) -> list[UserTierAssignment]:
"""
List all users assigned to a tier
-
+
Args:
tier_id: Tier identifier
skip: Number of records to skip
limit: Maximum number of records to return
-
+
Returns:
List of user assignments
"""
- cursor = self.assignments_collection.find(
- {'tier_id': tier_id}
- ).skip(skip).limit(limit)
-
+ cursor = self.assignments_collection.find({'tier_id': tier_id}).skip(skip).limit(limit)
+
assignments = []
async for assignment_data in cursor:
assignments.append(UserTierAssignment(**assignment_data))
-
+
return assignments
-
+
# ========================================================================
# TIER UPGRADES & DOWNGRADES
# ========================================================================
-
+
async def upgrade_user_tier(
self,
user_id: str,
new_tier_id: str,
immediate: bool = True,
- scheduled_date: Optional[datetime] = None,
- assigned_by: Optional[str] = None
+ scheduled_date: datetime | None = None,
+ assigned_by: str | None = None,
) -> UserTierAssignment:
"""
Upgrade user to a higher tier
-
+
Args:
user_id: User identifier
new_tier_id: New tier identifier
immediate: Apply immediately or schedule
scheduled_date: When to apply (if not immediate)
assigned_by: Who initiated the upgrade
-
+
Returns:
Updated UserTierAssignment
"""
# Get current and new tiers
current_tier = await self.get_user_tier(user_id)
new_tier = await self.get_tier(new_tier_id)
-
+
if not new_tier:
- raise ValueError(f"Tier {new_tier_id} not found")
-
+ raise ValueError(f'Tier {new_tier_id} not found')
+
# Determine effective date
effective_from = datetime.now() if immediate else scheduled_date
-
+
# Create assignment
assignment = await self.assign_user_to_tier(
user_id=user_id,
tier_id=new_tier_id,
assigned_by=assigned_by,
effective_from=effective_from,
- notes=f"Upgraded from {current_tier.tier_id if current_tier else 'default'}"
+ notes=f'Upgraded from {current_tier.tier_id if current_tier else "default"}',
)
-
- logger.info(f"Upgraded user {user_id} to tier {new_tier_id}")
+
+ logger.info(f'Upgraded user {user_id} to tier {new_tier_id}')
return assignment
-
+
async def downgrade_user_tier(
self,
user_id: str,
new_tier_id: str,
grace_period_days: int = 0,
- assigned_by: Optional[str] = None
+ assigned_by: str | None = None,
) -> UserTierAssignment:
"""
Downgrade user to a lower tier
-
+
Args:
user_id: User identifier
new_tier_id: New tier identifier
grace_period_days: Days before downgrade takes effect
assigned_by: Who initiated the downgrade
-
+
Returns:
Updated UserTierAssignment
"""
# Get current and new tiers
current_tier = await self.get_user_tier(user_id)
new_tier = await self.get_tier(new_tier_id)
-
+
if not new_tier:
- raise ValueError(f"Tier {new_tier_id} not found")
-
+ raise ValueError(f'Tier {new_tier_id} not found')
+
# Calculate effective date with grace period
effective_from = datetime.now() + timedelta(days=grace_period_days)
-
+
# Create assignment
assignment = await self.assign_user_to_tier(
user_id=user_id,
tier_id=new_tier_id,
assigned_by=assigned_by,
effective_from=effective_from,
- notes=f"Downgraded from {current_tier.tier_id if current_tier else 'default'} with {grace_period_days} day grace period"
+ notes=f'Downgraded from {current_tier.tier_id if current_tier else "default"} with {grace_period_days} day grace period',
+ )
+
+ logger.info(
+ f'Scheduled downgrade for user {user_id} to tier {new_tier_id} on {effective_from}'
)
-
- logger.info(f"Scheduled downgrade for user {user_id} to tier {new_tier_id} on {effective_from}")
return assignment
-
+
async def temporary_tier_upgrade(
- self,
- user_id: str,
- temp_tier_id: str,
- duration_days: int,
- assigned_by: Optional[str] = None
+ self, user_id: str, temp_tier_id: str, duration_days: int, assigned_by: str | None = None
) -> UserTierAssignment:
"""
Temporarily upgrade user to a higher tier
-
+
Args:
user_id: User identifier
temp_tier_id: Temporary tier identifier
duration_days: How many days the upgrade lasts
assigned_by: Who initiated the upgrade
-
+
Returns:
UserTierAssignment with expiration
"""
temp_tier = await self.get_tier(temp_tier_id)
if not temp_tier:
- raise ValueError(f"Tier {temp_tier_id} not found")
-
+ raise ValueError(f'Tier {temp_tier_id} not found')
+
# Set expiration
effective_from = datetime.now()
effective_until = effective_from + timedelta(days=duration_days)
-
+
# Create temporary assignment
assignment = await self.assign_user_to_tier(
user_id=user_id,
@@ -518,112 +504,112 @@ class TierService:
assigned_by=assigned_by,
effective_from=effective_from,
effective_until=effective_until,
- notes=f"Temporary upgrade for {duration_days} days"
+ notes=f'Temporary upgrade for {duration_days} days',
+ )
+
+ logger.info(
+ f'Temporary upgrade for user {user_id} to tier {temp_tier_id} until {effective_until}'
)
-
- logger.info(f"Temporary upgrade for user {user_id} to tier {temp_tier_id} until {effective_until}")
return assignment
-
+
# ========================================================================
# TIER COMPARISON & ANALYTICS
# ========================================================================
-
- async def compare_tiers(self, tier_ids: List[str]) -> List[Dict[str, Any]]:
+
+ async def compare_tiers(self, tier_ids: list[str]) -> list[dict[str, Any]]:
"""
Compare multiple tiers side-by-side
-
+
Args:
tier_ids: List of tier identifiers to compare
-
+
Returns:
List of tier comparison data
"""
comparison = []
-
+
for tier_id in tier_ids:
tier = await self.get_tier(tier_id)
if tier:
- comparison.append({
- 'tier_id': tier.tier_id,
- 'name': tier.name.value,
- 'display_name': tier.display_name,
- 'limits': tier.limits.to_dict(),
- 'price_monthly': tier.price_monthly,
- 'price_yearly': tier.price_yearly,
- 'features': tier.features
- })
-
+ comparison.append(
+ {
+ 'tier_id': tier.tier_id,
+ 'name': tier.name.value,
+ 'display_name': tier.display_name,
+ 'limits': tier.limits.to_dict(),
+ 'price_monthly': tier.price_monthly,
+ 'price_yearly': tier.price_yearly,
+ 'features': tier.features,
+ }
+ )
+
return comparison
-
- async def get_tier_statistics(self, tier_id: str) -> Dict[str, Any]:
+
+ async def get_tier_statistics(self, tier_id: str) -> dict[str, Any]:
"""
Get statistics for a tier
-
+
Args:
tier_id: Tier identifier
-
+
Returns:
Dictionary with tier statistics
"""
# Count users in tier
user_count = await self.assignments_collection.count_documents({'tier_id': tier_id})
-
+
# Count active assignments (within effective dates)
now = datetime.now()
- active_count = await self.assignments_collection.count_documents({
- 'tier_id': tier_id,
- '$or': [
- {'effective_from': None},
- {'effective_from': {'$lte': now}}
- ],
- '$or': [
- {'effective_until': None},
- {'effective_until': {'$gte': now}}
- ]
- })
-
+ active_count = await self.assignments_collection.count_documents(
+ {
+                'tier_id': tier_id,
+                '$and': [{'$or': [{'effective_from': None}, {'effective_from': {'$lte': now}}]},
+                         {'$or': [{'effective_until': None}, {'effective_until': {'$gte': now}}]}],
+ }
+ )
+
return {
'tier_id': tier_id,
'total_users': user_count,
'active_users': active_count,
- 'inactive_users': user_count - active_count
+ 'inactive_users': user_count - active_count,
}
-
- async def get_all_tier_statistics(self) -> List[Dict[str, Any]]:
+
+ async def get_all_tier_statistics(self) -> list[dict[str, Any]]:
"""
Get statistics for all tiers
-
+
Returns:
List of tier statistics
"""
tiers = await self.list_tiers()
statistics = []
-
+
for tier in tiers:
stats = await self.get_tier_statistics(tier.tier_id)
stats['tier_name'] = tier.display_name
statistics.append(stats)
-
+
return statistics
# Global tier service instance
-_tier_service: Optional[TierService] = None
+_tier_service: TierService | None = None
def get_tier_service(db) -> TierService:
"""
Get or create global tier service instance
-
+
Args:
db: MongoDB database instance (sync) or InMemoryDB
-
+
Returns:
TierService instance
"""
global _tier_service
-
+
if _tier_service is None:
_tier_service = TierService(db)
-
+
return _tier_service
diff --git a/backend-services/services/user_service.py b/backend-services/services/user_service.py
index 1a5dabe..b3af7dd 100644
--- a/backend-services/services/user_service.py
+++ b/backend-services/services/user_service.py
@@ -4,40 +4,40 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List
-from fastapi import HTTPException
import logging
-import asyncio
+import time
+from fastapi import HTTPException
+
+from models.create_user_model import CreateUserModel
from models.response_model import ResponseModel
from utils import password_util
-from utils.database_async import user_collection, subscriptions_collection, api_collection
-from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list
-from utils.doorman_cache_util import doorman_cache
-from models.create_user_model import CreateUserModel
-from utils.paging_util import validate_page_params
-from utils.constants import ErrorCodes, Messages
-from utils.role_util import platform_role_required_bool
+from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one
from utils.bandwidth_util import get_current_usage
-import time
+from utils.constants import ErrorCodes, Messages
+from utils.database_async import api_collection, subscriptions_collection, user_collection
+from utils.doorman_cache_util import doorman_cache
+from utils.paging_util import validate_page_params
+from utils.role_util import platform_role_required_bool
logger = logging.getLogger('doorman.gateway')
-class UserService:
+class UserService:
@staticmethod
- async def get_user_by_email_with_password_helper(email):
+ async def get_user_by_email_with_password_helper(email: str) -> dict:
"""
Retrieve a user by email.
"""
user = await db_find_one(user_collection, {'email': email})
- if user.get('_id'): del user['_id']
+ if user.get('_id'):
+ del user['_id']
if not user:
raise HTTPException(status_code=404, detail='User not found')
return user
@staticmethod
- async def get_user_by_username_helper(username):
+ async def get_user_by_username_helper(username: str) -> dict:
"""
Retrieve a user by username.
"""
@@ -47,17 +47,19 @@ class UserService:
user = await db_find_one(user_collection, {'username': username})
if not user:
raise HTTPException(status_code=404, detail='User not found')
- if user.get('_id'): del user['_id']
- if user.get('password'): del user['password']
+ if user.get('_id'):
+ del user['_id']
+ if user.get('password'):
+ del user['password']
doorman_cache.set_cache('user_cache', username, user)
if not user:
raise HTTPException(status_code=404, detail='User not found')
return user
- except Exception as e:
+ except Exception:
raise HTTPException(status_code=404, detail='User not found')
@staticmethod
- async def get_user_by_username(username, request_id):
+ async def get_user_by_username(username: str, request_id: str) -> dict:
"""
Retrieve a user by username.
"""
@@ -68,22 +70,20 @@ class UserService:
if not user:
logger.error(f'{request_id} | User retrieval failed with code USR002')
return ResponseModel(
- status_code=404,
- error_code='USR002',
- error_message='User not found'
+ status_code=404, error_code='USR002', error_message='User not found'
).dict()
- if user.get('_id'): del user['_id']
- if user.get('password'): del user['password']
+ if user.get('_id'):
+ del user['_id']
+ if user.get('password'):
+ del user['password']
doorman_cache.set_cache('user_cache', username, user)
if not user:
logger.error(f'{request_id} | User retrieval failed with code USR002')
return ResponseModel(
status_code=404,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR002',
- error_message='User not found'
+ error_message='User not found',
).dict()
try:
limit = user.get('bandwidth_limit_bytes')
@@ -108,13 +108,10 @@ class UserService:
except Exception:
pass
logger.info(f'{request_id} | User retrieval successful')
- return ResponseModel(
- status_code=200,
- response=user
- ).dict()
+ return ResponseModel(status_code=200, response=user).dict()
@staticmethod
- async def get_user_by_email(active_username, email, request_id):
+ async def get_user_by_email(active_username: str, email: str, request_id: str) -> dict:
"""
Retrieve a user by email.
"""
@@ -128,81 +125,72 @@ class UserService:
logger.error(f'{request_id} | User retrieval failed with code USR002')
return ResponseModel(
status_code=404,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR002',
- error_message='User not found'
+ error_message='User not found',
).dict()
logger.info(f'{request_id} | User retrieval successful')
- if not active_username == user.get('username') and not await platform_role_required_bool(active_username, 'manage_users'):
+ if not active_username == user.get('username') and not await platform_role_required_bool(
+ active_username, 'manage_users'
+ ):
logger.error(f'{request_id} | User retrieval failed with code USR008')
return ResponseModel(
- status_code=403,
- error_code='USR008',
- error_message='Unable to retrieve information for user',
- ).dict()
- return ResponseModel(
- status_code=200,
- response=user
- ).dict()
+ status_code=403,
+ error_code='USR008',
+ error_message='Unable to retrieve information for user',
+ ).dict()
+ return ResponseModel(status_code=200, response=user).dict()
@staticmethod
- async def create_user(data: CreateUserModel, request_id):
+ async def create_user(data: CreateUserModel, request_id: str) -> dict:
"""
Create a new user.
"""
logger.info(f'{request_id} | Creating user: {data.username}')
try:
if data.custom_attributes is not None and len(data.custom_attributes.keys()) > 10:
- logger.error(f"{request_id} | User creation failed with code USR016: Too many custom attributes")
+ logger.error(
+ f'{request_id} | User creation failed with code USR016: Too many custom attributes'
+ )
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR016',
- error_message='Maximum 10 custom attributes allowed. Please replace an existing one.'
+ error_message='Maximum 10 custom attributes allowed. Please replace an existing one.',
).dict()
except Exception:
- logger.error(f"{request_id} | User creation failed with code USR016: Invalid custom attributes payload")
+ logger.error(
+ f'{request_id} | User creation failed with code USR016: Invalid custom attributes payload'
+ )
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR016',
- error_message='Maximum 10 custom attributes allowed. Please replace an existing one.'
+ error_message='Maximum 10 custom attributes allowed. Please replace an existing one.',
).dict()
if await db_find_one(user_collection, {'username': data.username}):
logger.error(f'{request_id} | User creation failed with code USR001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR001',
- error_message='Username already exists'
+ error_message='Username already exists',
).dict()
if await db_find_one(user_collection, {'email': data.email}):
logger.error(f'{request_id} | User creation failed with code USR001')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR001',
- error_message='Email already exists'
+ error_message='Email already exists',
).dict()
if not password_util.is_secure_password(data.password):
logger.error(f'{request_id} | User creation failed with code USR005')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR005',
- error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character'
+ error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character',
).dict()
data.password = password_util.hash_password(data.password)
data_dict = data.dict()
@@ -215,14 +203,12 @@ class UserService:
logger.info(f'{request_id} | User creation successful')
return ResponseModel(
status_code=201,
- response_headers={
- 'request_id': request_id
- },
- message='User created successfully'
+ response_headers={'request_id': request_id},
+ message='User created successfully',
).dict()
@staticmethod
- async def check_password_return_user(email, password):
+ async def check_password_return_user(email: str, password: str) -> dict:
"""
Verify password and return user if valid.
"""
@@ -230,7 +216,6 @@ class UserService:
try:
user = await UserService.get_user_by_email_with_password_helper(email)
except Exception:
-
maybe_user = await db_find_one(user_collection, {'username': email})
if maybe_user:
user = maybe_user
@@ -239,11 +224,11 @@ class UserService:
if not password_util.verify_password(password, user.get('password')):
raise HTTPException(status_code=400, detail='Invalid email or password')
return user
- except Exception as e:
+ except Exception:
raise HTTPException(status_code=400, detail='Invalid email or password')
@staticmethod
- async def update_user(username, update_data, request_id):
+ async def update_user(username: str, update_data: dict, request_id: str) -> dict:
"""
Update user information.
"""
@@ -254,58 +239,63 @@ class UserService:
if not user:
logger.error(f'{request_id} | User update failed with code USR002')
return ResponseModel(
- status_code=404,
- error_code='USR002',
- error_message='User not found'
+ status_code=404, error_code='USR002', error_message='User not found'
).dict()
else:
doorman_cache.delete_cache('user_cache', username)
non_null_update_data = {k: v for k, v in update_data.dict().items() if v is not None}
if 'custom_attributes' in non_null_update_data:
try:
- if non_null_update_data['custom_attributes'] is not None and len(non_null_update_data['custom_attributes'].keys()) > 10:
- logger.error(f"{request_id} | User update failed with code USR016: Too many custom attributes")
+ if (
+ non_null_update_data['custom_attributes'] is not None
+ and len(non_null_update_data['custom_attributes'].keys()) > 10
+ ):
+ logger.error(
+ f'{request_id} | User update failed with code USR016: Too many custom attributes'
+ )
return ResponseModel(
status_code=400,
error_code='USR016',
- error_message='Maximum 10 custom attributes allowed. Please replace an existing one.'
+ error_message='Maximum 10 custom attributes allowed. Please replace an existing one.',
).dict()
except Exception:
- logger.error(f"{request_id} | User update failed with code USR016: Invalid custom attributes payload")
+ logger.error(
+ f'{request_id} | User update failed with code USR016: Invalid custom attributes payload'
+ )
return ResponseModel(
status_code=400,
error_code='USR016',
- error_message='Maximum 10 custom attributes allowed. Please replace an existing one.'
+ error_message='Maximum 10 custom attributes allowed. Please replace an existing one.',
).dict()
if non_null_update_data:
try:
- update_result = await db_update_one(user_collection, {'username': username}, {'$set': non_null_update_data})
+ update_result = await db_update_one(
+ user_collection, {'username': username}, {'$set': non_null_update_data}
+ )
if update_result.modified_count > 0:
doorman_cache.delete_cache('user_cache', username)
if not update_result.acknowledged or update_result.modified_count == 0:
logger.error(f'{request_id} | User update failed with code USR003')
return ResponseModel(
- status_code=400,
- error_code='USR004',
- error_message='Unable to update user'
+ status_code=400, error_code='USR004', error_message='Unable to update user'
).dict()
except Exception as e:
doorman_cache.delete_cache('user_cache', username)
- logger.error(f'{request_id} | User update failed with exception: {str(e)}', exc_info=True)
+ logger.error(
+ f'{request_id} | User update failed with exception: {str(e)}', exc_info=True
+ )
raise
if non_null_update_data.get('role'):
await UserService.purge_apis_after_role_change(username, request_id)
logger.info(f'{request_id} | User update successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='User updated successfully'
+ response_headers={'request_id': request_id},
+ message='User updated successfully',
).dict()
@staticmethod
- async def delete_user(username, request_id):
+ async def delete_user(username: str, request_id: str) -> dict:
"""
Delete a user.
"""
@@ -316,34 +306,28 @@ class UserService:
if not user:
logger.error(f'{request_id} | User deletion failed with code USR002')
return ResponseModel(
- status_code=404,
- error_code='USR002',
- error_message='User not found'
+ status_code=404, error_code='USR002', error_message='User not found'
).dict()
delete_result = await db_delete_one(user_collection, {'username': username})
if not delete_result.acknowledged or delete_result.deleted_count == 0:
logger.error(f'{request_id} | User deletion failed with code USR003')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR003',
- error_message='Unable to delete user'
+ error_message='Unable to delete user',
).dict()
doorman_cache.delete_cache('user_cache', username)
doorman_cache.delete_cache('user_subscription_cache', username)
logger.info(f'{request_id} | User deletion successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='User deleted successfully'
+ response_headers={'request_id': request_id},
+ message='User deleted successfully',
).dict()
@staticmethod
- async def update_password(username, update_data, request_id):
+ async def update_password(username: str, update_data: dict, request_id: str) -> dict:
"""
Update user information.
"""
@@ -352,31 +336,32 @@ class UserService:
logger.error(f'{request_id} | User password update failed with code USR005')
return ResponseModel(
status_code=400,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR005',
- error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character'
+ error_message='Password must include at least 16 characters, one uppercase letter, one lowercase letter, one digit, and one special character',
).dict()
hashed_password = password_util.hash_password(update_data.new_password)
try:
- update_result = await db_update_one(user_collection, {'username': username}, {'$set': {'password': hashed_password}})
+ update_result = await db_update_one(
+ user_collection, {'username': username}, {'$set': {'password': hashed_password}}
+ )
if update_result.modified_count > 0:
doorman_cache.delete_cache('user_cache', username)
except Exception as e:
doorman_cache.delete_cache('user_cache', username)
- logger.error(f'{request_id} | User password update failed with exception: {str(e)}', exc_info=True)
+ logger.error(
+ f'{request_id} | User password update failed with exception: {str(e)}',
+ exc_info=True,
+ )
raise
user = await db_find_one(user_collection, {'username': username})
if not user:
logger.error(f'{request_id} | User password update failed with code USR002')
return ResponseModel(
status_code=404,
- response_headers={
- 'request_id': request_id
- },
+ response_headers={'request_id': request_id},
error_code='USR002',
- error_message='User not found'
+ error_message='User not found',
).dict()
if '_id' in user:
del user['_id']
@@ -386,42 +371,52 @@ class UserService:
logger.info(f'{request_id} | User password update successful')
return ResponseModel(
status_code=200,
- response_headers={
- 'request_id': request_id
- },
- message='User updated successfully'
+ response_headers={'request_id': request_id},
+ message='User updated successfully',
).dict()
@staticmethod
- async def purge_apis_after_role_change(username, request_id):
+ async def purge_apis_after_role_change(username: str, request_id: str) -> None:
"""
Remove subscriptions after role change.
"""
logger.info(f'{request_id} | Purging APIs for user: {username}')
- user_subscriptions = doorman_cache.get_cache('user_subscription_cache', username) or await db_find_one(subscriptions_collection, {'username': username})
+ user_subscriptions = doorman_cache.get_cache(
+ 'user_subscription_cache', username
+ ) or await db_find_one(subscriptions_collection, {'username': username})
if user_subscriptions:
for subscription in user_subscriptions.get('apis'):
api_name, api_version = subscription.split('/')
- user = doorman_cache.get_cache('user_cache', username) or await db_find_one(user_collection, {'username': username})
- api = doorman_cache.get_cache('api_cache', f'{api_name}/{api_version}') or await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version})
+ user = doorman_cache.get_cache('user_cache', username) or await db_find_one(
+ user_collection, {'username': username}
+ )
+ api = doorman_cache.get_cache(
+ 'api_cache', f'{api_name}/{api_version}'
+ ) or await db_find_one(
+ api_collection, {'api_name': api_name, 'api_version': api_version}
+ )
if api and api.get('role') and user.get('role') not in api.get('role'):
user_subscriptions['apis'].remove(subscription)
try:
- update_result = await db_update_one(subscriptions_collection,
+ update_result = await db_update_one(
+ subscriptions_collection,
{'username': username},
- {'$set': {'apis': user_subscriptions.get('apis', [])}}
+ {'$set': {'apis': user_subscriptions.get('apis', [])}},
)
if update_result.modified_count > 0:
doorman_cache.delete_cache('user_subscription_cache', username)
doorman_cache.set_cache('user_subscription_cache', username, user_subscriptions)
except Exception as e:
doorman_cache.delete_cache('user_subscription_cache', username)
- logger.error(f'{request_id} | Subscription update failed with exception: {str(e)}', exc_info=True)
+ logger.error(
+ f'{request_id} | Subscription update failed with exception: {str(e)}',
+ exc_info=True,
+ )
raise
logger.info(f'{request_id} | Purge successful')
@staticmethod
- async def get_all_users(page, page_size, request_id):
+ async def get_all_users(page: int, page_size: int, request_id: str) -> dict:
"""
Get all users.
"""
@@ -432,20 +427,21 @@ class UserService:
return ResponseModel(
status_code=400,
error_code=ErrorCodes.PAGE_SIZE,
- error_message=(Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING)
+ error_message=(
+ Messages.PAGE_TOO_LARGE if 'page_size' in str(e) else Messages.INVALID_PAGING
+ ),
).dict()
skip = (page - 1) * page_size
users_all = await db_find_list(user_collection, {})
users_all.sort(key=lambda u: u.get('username'))
- users = users_all[skip: skip + page_size]
+ users = users_all[skip : skip + page_size]
for user in users:
- if user.get('_id'): del user['_id']
- if user.get('password'): del user['password']
+ if user.get('_id'):
+ del user['_id']
+ if user.get('password'):
+ del user['password']
for key, value in user.items():
if isinstance(value, bytes):
user[key] = value.decode('utf-8')
logger.info(f'{request_id} | User retrieval successful')
- return ResponseModel(
- status_code=200,
- response={'users': users}
- ).dict()
+ return ResponseModel(status_code=200, response={'users': users}).dict()
diff --git a/backend-services/services/vault_service.py b/backend-services/services/vault_service.py
index 9abdc11..cc81430 100644
--- a/backend-services/services/vault_service.py
+++ b/backend-services/services/vault_service.py
@@ -4,25 +4,21 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import List, Optional
-from fastapi import HTTPException
import logging
from datetime import datetime
try:
from datetime import UTC
except Exception:
- from datetime import timezone as _timezone
- UTC = _timezone.utc
+ UTC = UTC
-from models.response_model import ResponseModel
from models.create_vault_entry_model import CreateVaultEntryModel
+from models.response_model import ResponseModel
from models.update_vault_entry_model import UpdateVaultEntryModel
-from models.vault_entry_model_response import VaultEntryModelResponse
-from utils.database_async import vault_entries_collection, user_collection
-from utils.async_db import db_find_one, db_insert_one, db_update_one, db_delete_one, db_find_list
-from utils.vault_encryption_util import encrypt_vault_value, is_vault_configured
+from utils.async_db import db_delete_one, db_find_list, db_find_one, db_insert_one, db_update_one
from utils.constants import ErrorCodes, Messages
+from utils.database_async import user_collection, vault_entries_collection
+from utils.vault_encryption_util import encrypt_vault_value, is_vault_configured
logger = logging.getLogger('doorman.gateway')
@@ -31,61 +27,62 @@ class VaultService:
"""Service for managing encrypted vault entries."""
@staticmethod
- async def create_vault_entry(username: str, entry_data: CreateVaultEntryModel, request_id: str) -> dict:
+ async def create_vault_entry(
+ username: str, entry_data: CreateVaultEntryModel, request_id: str
+ ) -> dict:
"""
Create a new vault entry for a user.
-
+
Args:
username: Username of the vault entry owner
entry_data: Vault entry creation data
request_id: Request ID for logging
-
+
Returns:
ResponseModel dict with success or error
"""
- logger.info(f'{request_id} | Creating vault entry: {entry_data.key_name} for user: {username}')
-
+ logger.info(
+ f'{request_id} | Creating vault entry: {entry_data.key_name} for user: {username}'
+ )
+
# Check if VAULT_KEY is configured
if not is_vault_configured():
logger.error(f'{request_id} | VAULT_KEY not configured')
return ResponseModel(
status_code=500,
error_code='VAULT001',
- error_message='Vault encryption is not configured. Set VAULT_KEY in environment variables.'
+ error_message='Vault encryption is not configured. Set VAULT_KEY in environment variables.',
).dict()
-
+
# Get user to retrieve email
user = await db_find_one(user_collection, {'username': username})
if not user:
logger.error(f'{request_id} | User not found: {username}')
return ResponseModel(
- status_code=404,
- error_code='VAULT002',
- error_message='User not found'
+ status_code=404, error_code='VAULT002', error_message='User not found'
).dict()
-
+
email = user.get('email')
if not email:
logger.error(f'{request_id} | User email not found: {username}')
return ResponseModel(
status_code=400,
error_code='VAULT003',
- error_message='User email is required for vault encryption'
+ error_message='User email is required for vault encryption',
).dict()
-
+
# Check if entry already exists
existing = await db_find_one(
- vault_entries_collection,
- {'username': username, 'key_name': entry_data.key_name}
+ vault_entries_collection, {'username': username, 'key_name': entry_data.key_name}
)
if existing:
logger.error(f'{request_id} | Vault entry already exists: {entry_data.key_name}')
return ResponseModel(
status_code=409,
error_code='VAULT004',
- error_message=f'Vault entry with key_name "{entry_data.key_name}" already exists'
+ error_message=f'Vault entry with key_name "{entry_data.key_name}" already exists',
).dict()
-
+
# Encrypt the value
try:
encrypted_value = encrypt_vault_value(entry_data.value, email, username)
@@ -94,9 +91,9 @@ class VaultService:
return ResponseModel(
status_code=500,
error_code='VAULT005',
- error_message='Failed to encrypt vault value'
+ error_message='Failed to encrypt vault value',
).dict()
-
+
# Create vault entry
now = datetime.now(UTC).isoformat()
vault_entry = {
@@ -105,89 +102,83 @@ class VaultService:
'encrypted_value': encrypted_value,
'description': entry_data.description,
'created_at': now,
- 'updated_at': now
+ 'updated_at': now,
}
-
+
try:
result = await db_insert_one(vault_entries_collection, vault_entry)
if result.acknowledged:
- logger.info(f'{request_id} | Vault entry created successfully: {entry_data.key_name}')
+ logger.info(
+ f'{request_id} | Vault entry created successfully: {entry_data.key_name}'
+ )
return ResponseModel(
status_code=201,
message='Vault entry created successfully',
- data={'key_name': entry_data.key_name}
+ data={'key_name': entry_data.key_name},
).dict()
else:
logger.error(f'{request_id} | Failed to create vault entry')
return ResponseModel(
status_code=500,
error_code='VAULT006',
- error_message='Failed to create vault entry'
+ error_message='Failed to create vault entry',
).dict()
except Exception as e:
logger.error(f'{request_id} | Error creating vault entry: {str(e)}')
return ResponseModel(
- status_code=500,
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
+ status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED
).dict()
@staticmethod
async def get_vault_entry(username: str, key_name: str, request_id: str) -> dict:
"""
Get a vault entry by key name. Value is NOT returned.
-
+
Args:
username: Username of the vault entry owner
key_name: Name of the vault key
request_id: Request ID for logging
-
+
Returns:
ResponseModel dict with vault entry (without value) or error
"""
logger.info(f'{request_id} | Getting vault entry: {key_name} for user: {username}')
-
+
entry = await db_find_one(
- vault_entries_collection,
- {'username': username, 'key_name': key_name}
+ vault_entries_collection, {'username': username, 'key_name': key_name}
)
-
+
if not entry:
logger.error(f'{request_id} | Vault entry not found: {key_name}')
return ResponseModel(
- status_code=404,
- error_code='VAULT007',
- error_message='Vault entry not found'
+ status_code=404, error_code='VAULT007', error_message='Vault entry not found'
).dict()
-
+
# Remove sensitive data
if entry.get('_id'):
del entry['_id']
if entry.get('encrypted_value'):
del entry['encrypted_value']
-
- return ResponseModel(
- status_code=200,
- data=entry
- ).dict()
+
+ return ResponseModel(status_code=200, data=entry).dict()
@staticmethod
async def list_vault_entries(username: str, request_id: str) -> dict:
"""
List all vault entries for a user. Values are NOT returned.
-
+
Args:
username: Username of the vault entry owner
request_id: Request ID for logging
-
+
Returns:
ResponseModel dict with list of vault entries (without values)
"""
logger.info(f'{request_id} | Listing vault entries for user: {username}')
-
+
try:
entries = await db_find_list(vault_entries_collection, {'username': username})
-
+
# Remove sensitive data from all entries
clean_entries = []
for entry in entries:
@@ -196,137 +187,120 @@ class VaultService:
if entry.get('encrypted_value'):
del entry['encrypted_value']
clean_entries.append(entry)
-
+
return ResponseModel(
- status_code=200,
- data={'entries': clean_entries, 'count': len(clean_entries)}
+ status_code=200, data={'entries': clean_entries, 'count': len(clean_entries)}
).dict()
except Exception as e:
logger.error(f'{request_id} | Error listing vault entries: {str(e)}')
return ResponseModel(
- status_code=500,
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
+ status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED
).dict()
@staticmethod
- async def update_vault_entry(username: str, key_name: str, update_data: UpdateVaultEntryModel, request_id: str) -> dict:
+ async def update_vault_entry(
+ username: str, key_name: str, update_data: UpdateVaultEntryModel, request_id: str
+ ) -> dict:
"""
Update a vault entry. Only description can be updated, not the value.
-
+
Args:
username: Username of the vault entry owner
key_name: Name of the vault key
update_data: Update data (description only)
request_id: Request ID for logging
-
+
Returns:
ResponseModel dict with success or error
"""
logger.info(f'{request_id} | Updating vault entry: {key_name} for user: {username}')
-
+
# Check if entry exists
entry = await db_find_one(
- vault_entries_collection,
- {'username': username, 'key_name': key_name}
+ vault_entries_collection, {'username': username, 'key_name': key_name}
)
-
+
if not entry:
logger.error(f'{request_id} | Vault entry not found: {key_name}')
return ResponseModel(
- status_code=404,
- error_code='VAULT007',
- error_message='Vault entry not found'
+ status_code=404, error_code='VAULT007', error_message='Vault entry not found'
).dict()
-
+
# Update only description
now = datetime.now(UTC).isoformat()
- update_fields = {
- 'updated_at': now
- }
-
+ update_fields = {'updated_at': now}
+
if update_data.description is not None:
update_fields['description'] = update_data.description
-
+
try:
result = await db_update_one(
vault_entries_collection,
{'username': username, 'key_name': key_name},
- {'$set': update_fields}
+ {'$set': update_fields},
)
-
+
if result.modified_count > 0:
logger.info(f'{request_id} | Vault entry updated successfully: {key_name}')
return ResponseModel(
- status_code=200,
- message='Vault entry updated successfully'
+ status_code=200, message='Vault entry updated successfully'
).dict()
else:
logger.warning(f'{request_id} | No changes made to vault entry: {key_name}')
return ResponseModel(
- status_code=200,
- message='No changes made to vault entry'
+ status_code=200, message='No changes made to vault entry'
).dict()
except Exception as e:
logger.error(f'{request_id} | Error updating vault entry: {str(e)}')
return ResponseModel(
- status_code=500,
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
+ status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED
).dict()
@staticmethod
async def delete_vault_entry(username: str, key_name: str, request_id: str) -> dict:
"""
Delete a vault entry.
-
+
Args:
username: Username of the vault entry owner
key_name: Name of the vault key
request_id: Request ID for logging
-
+
Returns:
ResponseModel dict with success or error
"""
logger.info(f'{request_id} | Deleting vault entry: {key_name} for user: {username}')
-
+
# Check if entry exists
entry = await db_find_one(
- vault_entries_collection,
- {'username': username, 'key_name': key_name}
+ vault_entries_collection, {'username': username, 'key_name': key_name}
)
-
+
if not entry:
logger.error(f'{request_id} | Vault entry not found: {key_name}')
return ResponseModel(
- status_code=404,
- error_code='VAULT007',
- error_message='Vault entry not found'
+ status_code=404, error_code='VAULT007', error_message='Vault entry not found'
).dict()
-
+
try:
result = await db_delete_one(
- vault_entries_collection,
- {'username': username, 'key_name': key_name}
+ vault_entries_collection, {'username': username, 'key_name': key_name}
)
-
+
if result.deleted_count > 0:
logger.info(f'{request_id} | Vault entry deleted successfully: {key_name}')
return ResponseModel(
- status_code=200,
- message='Vault entry deleted successfully'
+ status_code=200, message='Vault entry deleted successfully'
).dict()
else:
logger.error(f'{request_id} | Failed to delete vault entry: {key_name}')
return ResponseModel(
status_code=500,
error_code='VAULT008',
- error_message='Failed to delete vault entry'
+ error_message='Failed to delete vault entry',
).dict()
except Exception as e:
logger.error(f'{request_id} | Error deleting vault entry: {str(e)}')
return ResponseModel(
- status_code=500,
- error_code=ErrorCodes.UNEXPECTED,
- error_message=Messages.UNEXPECTED
+ status_code=500, error_code=ErrorCodes.UNEXPECTED, error_message=Messages.UNEXPECTED
).dict()
diff --git a/backend-services/tests/conftest.py b/backend-services/tests/conftest.py
index 2747bdf..77475ad 100644
--- a/backend-services/tests/conftest.py
+++ b/backend-services/tests/conftest.py
@@ -22,6 +22,7 @@ os.environ.setdefault('DOORMAN_TEST_MODE', 'true')
try:
import sys as _sys
+
if _sys.version_info >= (3, 13):
os.environ.setdefault('DISABLE_PLATFORM_CHUNKED_WRAP', 'true')
os.environ.setdefault('DISABLE_PLATFORM_CORS_ASGI', 'true')
@@ -34,19 +35,23 @@ _PROJECT_ROOT = os.path.abspath(os.path.join(_HERE, os.pardir))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
+import asyncio
+import datetime as _dt
+
+import pytest
import pytest_asyncio
from httpx import AsyncClient
-import pytest
-import asyncio
-from typing import Optional
-import datetime as _dt
try:
from utils.database import database as _db
- _INITIAL_DB_SNAPSHOT: Optional[dict] = _db.db.dump_data() if getattr(_db, 'memory_only', True) else None
+
+ _INITIAL_DB_SNAPSHOT: dict | None = (
+ _db.db.dump_data() if getattr(_db, 'memory_only', True) else None
+ )
except Exception:
_INITIAL_DB_SNAPSHOT = None
+
@pytest_asyncio.fixture(autouse=True)
async def ensure_memory_dump_defaults(monkeypatch, tmp_path):
"""Ensure sane defaults for memory dump/restore tests.
@@ -58,11 +63,15 @@ async def ensure_memory_dump_defaults(monkeypatch, tmp_path):
"""
try:
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
- monkeypatch.setenv('MEM_ENCRYPTION_KEY', os.environ.get('MEM_ENCRYPTION_KEY') or 'test-encryption-key-32-characters-min')
+ monkeypatch.setenv(
+ 'MEM_ENCRYPTION_KEY',
+ os.environ.get('MEM_ENCRYPTION_KEY') or 'test-encryption-key-32-characters-min',
+ )
dump_base = tmp_path / 'mem' / 'memory_dump.bin'
monkeypatch.setenv('MEM_DUMP_PATH', str(dump_base))
try:
import utils.memory_dump_util as md
+
md.DEFAULT_DUMP_PATH = str(dump_base)
except Exception:
pass
@@ -70,20 +79,22 @@ async def ensure_memory_dump_defaults(monkeypatch, tmp_path):
pass
yield
+
@pytest.fixture(autouse=True)
def _log_test_start_end(request):
try:
ts = _dt.datetime.now().strftime('%H:%M:%S.%f')[:-3]
- print(f"=== [{ts}] START {request.node.nodeid}", flush=True)
+ print(f'=== [{ts}] START {request.node.nodeid}', flush=True)
except Exception:
pass
yield
try:
ts = _dt.datetime.now().strftime('%H:%M:%S.%f')[:-3]
- print(f"=== [{ts}] END {request.node.nodeid}", flush=True)
+ print(f'=== [{ts}] END {request.node.nodeid}', flush=True)
except Exception:
pass
+
@pytest.fixture(autouse=True, scope='session')
def _log_env_toggles():
try:
@@ -92,29 +103,43 @@ def _log_env_toggles():
'DISABLE_PLATFORM_CORS_ASGI': os.getenv('DISABLE_PLATFORM_CORS_ASGI'),
'DISABLE_BODY_SIZE_LIMIT': os.getenv('DISABLE_BODY_SIZE_LIMIT'),
'DOORMAN_TEST_MODE': os.getenv('DOORMAN_TEST_MODE'),
- 'PYTHON_VERSION': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
+ 'PYTHON_VERSION': f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}',
}
- print(f"=== ENV TOGGLES: {toggles}", flush=True)
+ print(f'=== ENV TOGGLES: {toggles}', flush=True)
except Exception:
pass
yield
+
@pytest_asyncio.fixture
async def authed_client():
+ # Ensure chaos is disabled so login/cache work reliably across tests
+ try:
+ from utils import chaos_util as _cu
+ _cu.enable('redis', False)
+ _cu.enable('mongo', False)
+ except Exception:
+ pass
from doorman import doorman
+
client = AsyncClient(app=doorman, base_url='http://testserver')
r = await client.post(
'/platform/authorization',
- json={'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD')},
+ json={
+ 'email': os.environ.get('DOORMAN_ADMIN_EMAIL'),
+ 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD'),
+ },
)
assert r.status_code == 200, r.text
try:
has_cookie = any(c.name == 'access_token_cookie' for c in client.cookies.jar)
if not has_cookie:
- body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ body = (
+ r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ )
token = body.get('access_token')
if token:
client.cookies.set(
@@ -126,43 +151,52 @@ async def authed_client():
except Exception:
pass
try:
- await client.put('/platform/user/admin', json={
- 'bandwidth_limit_bytes': 0,
- 'bandwidth_limit_window': 'day',
- 'rate_limit_duration': 1000000,
- 'rate_limit_duration_type': 'second',
- 'throttle_duration': 1000000,
- 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 1000000,
- 'throttle_wait_duration': 0,
- 'throttle_wait_duration_type': 'second'
- })
+ await client.put(
+ '/platform/user/admin',
+ json={
+ 'bandwidth_limit_bytes': 0,
+ 'bandwidth_limit_window': 'day',
+ 'rate_limit_duration': 1000000,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 1000000,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 1000000,
+ 'throttle_wait_duration': 0,
+ 'throttle_wait_duration_type': 'second',
+ },
+ )
except Exception:
pass
return client
+
@pytest.fixture
def client():
from doorman import doorman
+
return AsyncClient(app=doorman, base_url='http://testserver')
+
@pytest.fixture
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
+
@pytest_asyncio.fixture(autouse=True)
async def reset_http_client():
"""Reset the pooled httpx client between tests to prevent connection pool exhaustion."""
try:
from services.gateway_service import GatewayService
+
await GatewayService.aclose_http_client()
except Exception:
pass
try:
from utils.limit_throttle_util import reset_counters
+
reset_counters()
except Exception:
pass
@@ -170,10 +204,12 @@ async def reset_http_client():
yield
try:
from services.gateway_service import GatewayService
+
await GatewayService.aclose_http_client()
except Exception:
pass
+
@pytest_asyncio.fixture(autouse=True, scope='module')
async def reset_in_memory_db_state():
"""Restore in-memory DB and caches before each test to ensure isolation.
@@ -184,23 +220,29 @@ async def reset_in_memory_db_state():
try:
if _INITIAL_DB_SNAPSHOT is not None:
from utils.database import database as _db
+
_db.db.load_data(_INITIAL_DB_SNAPSHOT)
try:
- from utils.database import user_collection
from utils import password_util as _pw
+ from utils.database import user_collection
+
pwd = os.environ.get('DOORMAN_ADMIN_PASSWORD') or 'test-only-password-12chars'
- user_collection.update_one({'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}})
+ user_collection.update_one(
+ {'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}}
+ )
except Exception:
pass
except Exception:
pass
try:
from utils.doorman_cache_util import doorman_cache
+
doorman_cache.clear_all()
except Exception:
pass
yield
+
async def create_api(client: AsyncClient, api_name: str, api_version: str):
payload = {
'api_name': api_name,
@@ -216,7 +258,10 @@ async def create_api(client: AsyncClient, api_name: str, api_version: str):
assert r.status_code in (200, 201), r.text
return r
-async def create_endpoint(client: AsyncClient, api_name: str, api_version: str, method: str, uri: str):
+
+async def create_endpoint(
+ client: AsyncClient, api_name: str, api_version: str, method: str, uri: str
+):
payload = {
'api_name': api_name,
'api_version': api_version,
@@ -228,10 +273,10 @@ async def create_endpoint(client: AsyncClient, api_name: str, api_version: str,
assert r.status_code in (200, 201), r.text
return r
-async def subscribe_self(client: AsyncClient, api_name: str, api_version: str):
+async def subscribe_self(client: AsyncClient, api_name: str, api_version: str):
r_me = await client.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
r = await client.post(
'/platform/subscription/subscribe',
json={'username': username, 'api_name': api_name, 'api_version': api_version},
diff --git a/backend-services/tests/sitecustomize.py b/backend-services/tests/sitecustomize.py
index 42d4cf2..133e21c 100644
--- a/backend-services/tests/sitecustomize.py
+++ b/backend-services/tests/sitecustomize.py
@@ -2,6 +2,7 @@ import logging
import re
import sys
+
class _RedactFilter(logging.Filter):
PATTERNS = [
re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
@@ -17,13 +18,21 @@ class _RedactFilter(logging.Filter):
msg = str(record.getMessage())
red = msg
for pat in self.PATTERNS:
- red = pat.sub(lambda m: (m.group(1) + '[REDACTED]' + (m.group(3) if m.lastindex and m.lastindex >= 3 else '')), red)
+ red = pat.sub(
+ lambda m: (
+ m.group(1)
+ + '[REDACTED]'
+ + (m.group(3) if m.lastindex and m.lastindex >= 3 else '')
+ ),
+ red,
+ )
if red != msg:
record.msg = red
except Exception:
pass
return True
+
def _ensure_logger(name: str):
logger = logging.getLogger(name)
for h in logger.handlers:
@@ -34,9 +43,9 @@ def _ensure_logger(name: str):
h.addFilter(_RedactFilter())
logger.addHandler(h)
+
try:
_ensure_logger('doorman.gateway')
_ensure_logger('doorman.logging')
except Exception:
pass
-
diff --git a/backend-services/tests/test_access_control.py b/backend-services/tests/test_access_control.py
index 816c98f..47baa1e 100644
--- a/backend-services/tests/test_access_control.py
+++ b/backend-services/tests/test_access_control.py
@@ -1,22 +1,24 @@
-import os
import pytest
import pytest_asyncio
from httpx import AsyncClient
+
@pytest_asyncio.fixture
-async def login_client() :
+async def login_client():
async def _login(username: str, password: str, email: str = None) -> AsyncClient:
from doorman import doorman
+
client = AsyncClient(app=doorman, base_url='http://testserver')
cred = {'email': email or f'{username}@example.com', 'password': password}
r = await client.post('/platform/authorization', json=cred)
assert r.status_code == 200, r.text
return client
+
return _login
+
@pytest.mark.asyncio
async def test_roles_and_permissions_enforced(authed_client, login_client):
-
cr = await authed_client.post(
'/platform/role',
json={
@@ -78,9 +80,9 @@ async def test_roles_and_permissions_enforced(authed_client, login_client):
await viewer_client.aclose()
+
@pytest.mark.asyncio
async def test_group_and_subscription_enforcement(login_client, authed_client, monkeypatch):
-
c = await authed_client.post(
'/platform/api',
json={
@@ -140,16 +142,23 @@ async def test_group_and_subscription_enforcement(login_client, authed_client, m
assert s2.status_code in (401, 403)
import services.gateway_service as gs
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None):
self.status_code = status_code
self._json_body = json_body or {'ok': True}
self.headers = {'Content-Type': 'application/json'}
+
def json(self):
return self._json_body
+
class _FakeAsyncClient:
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -167,13 +176,25 @@ async def test_group_and_subscription_enforcement(login_client, authed_client, m
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
- async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200)
- async def post(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def put(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def delete(self, url, **kwargs): return _FakeHTTPResponse(200)
+
+ async def get(self, url, params=None, headers=None, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def post(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
import routes.gateway_routes as gr
- async def _no_limit(req): return None
+
+ async def _no_limit(req):
+ return None
+
monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit)
viewer1_client = await login_client('viewer1', 'StrongViewerPwd!1234')
diff --git a/backend-services/tests/test_admin_bootstrap_parity.py b/backend-services/tests/test_admin_bootstrap_parity.py
index d49a04f..5355949 100644
--- a/backend-services/tests/test_admin_bootstrap_parity.py
+++ b/backend-services/tests/test_admin_bootstrap_parity.py
@@ -1,6 +1,8 @@
import os
+
import pytest
+
@pytest.mark.asyncio
async def test_admin_seed_fields_memory_mode(monkeypatch):
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
@@ -8,9 +10,16 @@ async def test_admin_seed_fields_memory_mode(monkeypatch):
monkeypatch.setenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
from utils import database as dbmod
+
dbmod.database.initialize_collections()
- from utils.database import user_collection, role_collection, group_collection, _build_admin_seed_doc
+ from utils.database import (
+ _build_admin_seed_doc,
+ group_collection,
+ role_collection,
+ user_collection,
+ )
+
admin = user_collection.find_one({'username': 'admin'})
assert admin is not None, 'Admin user should be seeded'
@@ -20,22 +29,37 @@ async def test_admin_seed_fields_memory_mode(monkeypatch):
assert '_id' in doc_keys
from utils import password_util
- assert password_util.verify_password(os.environ['DOORMAN_ADMIN_PASSWORD'], admin.get('password'))
+
+ assert password_util.verify_password(
+ os.environ['DOORMAN_ADMIN_PASSWORD'], admin.get('password')
+ )
assert set(admin.get('groups') or []) >= {'ALL', 'admin'}
role = role_collection.find_one({'role_name': 'admin'})
assert role is not None
for cap in (
- 'manage_users','manage_apis','manage_endpoints','manage_groups','manage_roles',
- 'manage_routings','manage_gateway','manage_subscriptions','manage_credits','manage_auth','manage_security','view_logs'
+ 'manage_users',
+ 'manage_apis',
+ 'manage_endpoints',
+ 'manage_groups',
+ 'manage_roles',
+ 'manage_routings',
+ 'manage_gateway',
+ 'manage_subscriptions',
+ 'manage_credits',
+ 'manage_auth',
+ 'manage_security',
+ 'view_logs',
):
assert role.get(cap) is True, f'Missing admin capability: {cap}'
grp_admin = group_collection.find_one({'group_name': 'admin'})
grp_all = group_collection.find_one({'group_name': 'ALL'})
assert grp_admin is not None and grp_all is not None
+
def test_admin_seed_helper_is_canonical():
from utils.database import _build_admin_seed_doc
+
doc = _build_admin_seed_doc('a@b.c', 'hash')
assert doc['username'] == 'admin'
assert doc['role'] == 'admin'
@@ -49,4 +73,3 @@ def test_admin_seed_helper_is_canonical():
assert doc['throttle_wait_duration_type'] == 'second'
assert doc['throttle_queue_limit'] == 1
assert set(doc['groups']) == {'ALL', 'admin'}
-
diff --git a/backend-services/tests/test_admin_positive_paths.py b/backend-services/tests/test_admin_positive_paths.py
index aa6ae1a..56f9016 100644
--- a/backend-services/tests/test_admin_positive_paths.py
+++ b/backend-services/tests/test_admin_positive_paths.py
@@ -1,9 +1,10 @@
import uuid
+
import pytest
+
@pytest.mark.asyncio
async def test_admin_can_view_admin_role_and_user(authed_client):
-
r_role = await authed_client.get('/platform/role/admin')
assert r_role.status_code == 200, r_role.text
role = r_role.json()
@@ -12,7 +13,7 @@ async def test_admin_can_view_admin_role_and_user(authed_client):
r_roles = await authed_client.get('/platform/role/all?page=1&page_size=50')
assert r_roles.status_code == 200
roles = r_roles.json().get('roles') or r_roles.json().get('response', {}).get('roles') or []
- names = { (r.get('role_name') or '').lower() for r in roles }
+ names = {(r.get('role_name') or '').lower() for r in roles}
assert 'admin' in names
# Super admin user (username='admin') should be hidden from ALL users (ghost user)
@@ -23,13 +24,17 @@ async def test_admin_can_view_admin_role_and_user(authed_client):
r_users = await authed_client.get('/platform/user/all?page=1&page_size=100')
assert r_users.status_code == 200
users = r_users.json()
- user_list = users if isinstance(users, list) else (users.get('users') or users.get('response', {}).get('users') or [])
+ user_list = (
+ users
+ if isinstance(users, list)
+ else (users.get('users') or users.get('response', {}).get('users') or [])
+ )
usernames = {u.get('username') for u in user_list}
assert 'admin' not in usernames, 'Super admin should not appear in user list'
+
@pytest.mark.asyncio
async def test_admin_can_update_admin_role_description(authed_client):
-
desc = f'Administrator role ({uuid.uuid4().hex[:6]})'
up = await authed_client.put('/platform/role/admin', json={'role_description': desc})
assert up.status_code in (200, 201), up.text
@@ -39,6 +44,7 @@ async def test_admin_can_update_admin_role_description(authed_client):
body = r.json().get('response') or r.json()
assert body.get('role_description') == desc
+
@pytest.mark.asyncio
async def test_admin_can_create_and_delete_admin_user(authed_client):
uname = f'adm_{uuid.uuid4().hex[:8]}'
@@ -64,4 +70,3 @@ async def test_admin_can_create_and_delete_admin_user(authed_client):
r2 = await authed_client.get(f'/platform/user/{uname}')
assert r2.status_code in (404, 500)
-
diff --git a/backend-services/tests/test_api_active_and_patch.py b/backend-services/tests/test_api_active_and_patch.py
index 7b19c5f..e068141 100644
--- a/backend-services/tests/test_api_active_and_patch.py
+++ b/backend-services/tests/test_api_active_and_patch.py
@@ -1,20 +1,42 @@
import pytest
+
@pytest.mark.asyncio
async def test_api_disabled_rest_blocks(monkeypatch, authed_client):
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _create_endpoint(c, n, v, m, u):
- payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'endpoint_method': m,
+ 'endpoint_uri': u,
+ 'endpoint_description': f'{m} {u}',
+ }
rr = await c.post('/platform/endpoint', json=payload)
assert rr.status_code in (200, 201)
+
async def _subscribe_self(c, n, v):
r_me = await c.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
- rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v})
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
+ rr = await c.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': n, 'api_version': v},
+ )
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'disabled', 'v1')
r_upd = await authed_client.put('/platform/api/disabled/v1', json={'active': False})
@@ -24,99 +46,192 @@ async def test_api_disabled_rest_blocks(monkeypatch, authed_client):
r = await authed_client.get('/api/rest/disabled/v1/status')
assert r.status_code in (403, 500)
+
@pytest.mark.asyncio
async def test_api_disabled_graphql_blocks(authed_client):
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _subscribe_self(c, n, v):
r_me = await c.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
- rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v})
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
+ rr = await c.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': n, 'api_version': v},
+ )
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'gqlx', 'v1')
ru = await authed_client.put('/platform/api/gqlx/v1', json={'active': False})
assert ru.status_code == 200
await _subscribe_self(authed_client, 'gqlx', 'v1')
- r = await authed_client.post('/api/graphql/gqlx', headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'}, json={'query': '{__typename}'})
+ r = await authed_client.post(
+ '/api/graphql/gqlx',
+ headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'},
+ json={'query': '{__typename}'},
+ )
assert r.status_code == 403
+
@pytest.mark.asyncio
async def test_api_disabled_grpc_blocks(authed_client):
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _subscribe_self(c, n, v):
r_me = await c.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
- rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v})
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
+ rr = await c.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': n, 'api_version': v},
+ )
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'grpcx', 'v1')
ru = await authed_client.put('/platform/api/grpcx/v1', json={'active': False})
assert ru.status_code == 200
await _subscribe_self(authed_client, 'grpcx', 'v1')
- r = await authed_client.post('/api/grpc/grpcx', headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'}, json={'method': 'X', 'message': {}})
+ r = await authed_client.post(
+ '/api/grpc/grpcx',
+ headers={'X-API-Version': 'v1', 'Content-Type': 'application/json'},
+ json={'method': 'X', 'message': {}},
+ )
assert r.status_code in (400, 403, 404)
+
@pytest.mark.asyncio
async def test_api_disabled_soap_blocks(authed_client):
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _create_endpoint(c, n, v, m, u):
- payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'endpoint_method': m,
+ 'endpoint_uri': u,
+ 'endpoint_description': f'{m} {u}',
+ }
rr = await c.post('/platform/endpoint', json=payload)
assert rr.status_code in (200, 201)
+
async def _subscribe_self(c, n, v):
r_me = await c.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
- rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v})
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
+ rr = await c.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': n, 'api_version': v},
+ )
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'soapx', 'v1')
await _create_endpoint(authed_client, 'soapx', 'v1', 'POST', '/op')
ru = await authed_client.put('/platform/api/soapx/v1', json={'active': False})
assert ru.status_code == 200
await _subscribe_self(authed_client, 'soapx', 'v1')
- r = await authed_client.post('/api/soap/soapx/v1/op', headers={'Content-Type': 'text/xml'}, content='')
+ r = await authed_client.post(
+ '/api/soap/soapx/v1/op', headers={'Content-Type': 'text/xml'}, content=''
+ )
assert r.status_code in (403, 400, 404, 500)
+
@pytest.mark.asyncio
async def test_gateway_patch_support(monkeypatch, authed_client):
-
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _create_endpoint(c, n, v, m, u):
- payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'endpoint_method': m,
+ 'endpoint_uri': u,
+ 'endpoint_description': f'{m} {u}',
+ }
rr = await c.post('/platform/endpoint', json=payload)
assert rr.status_code in (200, 201)
+
async def _subscribe_self(c, n, v):
r_me = await c.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
- rr = await c.post('/platform/subscription/subscribe', json={'username': username, 'api_name': n, 'api_version': v})
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
+ rr = await c.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': n, 'api_version': v},
+ )
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'patchy', 'v1')
await _create_endpoint(authed_client, 'patchy', 'v1', 'PATCH', '/item')
await _subscribe_self(authed_client, 'patchy', 'v1')
import services.gateway_service as gs
+
class _Resp:
def __init__(self):
self.status_code = 200
self.headers = {'Content-Type': 'application/json'}
+
def json(self):
return {'ok': True}
+
@property
def text(self):
return '{"ok":true}'
+
class _Client:
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
- async def patch(self, url, json=None, params=None, headers=None, **kw): return _Resp()
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+ async def patch(self, url, json=None, params=None, headers=None, **kw):
+ return _Resp()
+
monkeypatch.setattr(gs, 'httpx', type('X', (), {'AsyncClient': lambda timeout=None: _Client()}))
r = await authed_client.patch('/api/rest/patchy/v1/item', json={'x': 1})
assert r.status_code in (200, 500)
diff --git a/backend-services/tests/test_api_and_endpoint_crud.py b/backend-services/tests/test_api_and_endpoint_crud.py
index f0d45f0..8fa6cf6 100644
--- a/backend-services/tests/test_api_and_endpoint_crud.py
+++ b/backend-services/tests/test_api_and_endpoint_crud.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_api_crud_flow(authed_client):
-
payload = {
'api_name': 'customer',
'api_version': 'v1',
@@ -28,8 +28,7 @@ async def test_api_crud_flow(authed_client):
assert any(a.get('api_name') == 'customer' and a.get('api_version') == 'v1' for a in apis)
upd = await authed_client.put(
- '/platform/api/customer/v1',
- json={'api_description': 'Customer API Updated'},
+ '/platform/api/customer/v1', json={'api_description': 'Customer API Updated'}
)
assert upd.status_code == 200
diff --git a/backend-services/tests/test_api_cors_control.py b/backend-services/tests/test_api_cors_control.py
index ffc143f..2dc5f42 100644
--- a/backend-services/tests/test_api_cors_control.py
+++ b/backend-services/tests/test_api_cors_control.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_rest_preflight_per_api_cors_allows_blocks(authed_client):
name, ver = 'corsrest', 'v1'
@@ -48,6 +49,7 @@ async def test_rest_preflight_per_api_cors_allows_blocks(authed_client):
assert blocked.headers.get('access-control-allow-origin') in (None, '')
+
@pytest.mark.asyncio
async def test_graphql_preflight_per_api_cors_allows(authed_client):
name, ver = 'corsgql', 'v1'
@@ -82,6 +84,7 @@ async def test_graphql_preflight_per_api_cors_allows(authed_client):
assert 'Content-Type' in (pre.headers.get('access-control-allow-headers') or '')
assert pre.headers.get('access-control-allow-credentials') == 'true'
+
@pytest.mark.asyncio
async def test_soap_preflight_per_api_cors_allows(authed_client):
name, ver = 'corssoap', 'v1'
@@ -115,4 +118,3 @@ async def test_soap_preflight_per_api_cors_allows(authed_client):
assert 'Content-Type' in (pre.headers.get('access-control-allow-headers') or '')
assert pre.headers.get('access-control-allow-credentials') in (None, 'false')
-
diff --git a/backend-services/tests/test_api_cors_headers_matrix.py b/backend-services/tests/test_api_cors_headers_matrix.py
index cc806b7..13b5d57 100644
--- a/backend-services/tests/test_api_cors_headers_matrix.py
+++ b/backend-services/tests/test_api_cors_headers_matrix.py
@@ -1,6 +1,9 @@
import pytest
-async def _setup_api_and_endpoint(client, name, ver, api_overrides=None, method='GET', uri='/status'):
+
+async def _setup_api_and_endpoint(
+ client, name, ver, api_overrides=None, method='GET', uri='/status'
+):
payload = {
'api_name': name,
'api_version': ver,
@@ -15,126 +18,186 @@ async def _setup_api_and_endpoint(client, name, ver, api_overrides=None, method=
payload.update(api_overrides)
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201), r.text
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': method,
- 'endpoint_uri': uri,
- 'endpoint_description': f'{method} {uri}',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': method,
+ 'endpoint_uri': uri,
+ 'endpoint_description': f'{method} {uri}',
+ },
+ )
assert r2.status_code in (200, 201), r2.text
+
@pytest.mark.asyncio
async def test_api_cors_allow_origins_exact_match_allowed(authed_client):
name, ver = 'corsm1', 'v1'
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['Content-Type'],
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'Content-Type'}
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'Content-Type',
+ },
)
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example'
assert r.headers.get('Vary') == 'Origin'
+
@pytest.mark.asyncio
async def test_api_cors_allow_origins_wildcard_allowed(authed_client):
name, ver = 'corsm2', 'v1'
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['*'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['Content-Type'],
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['*'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://any.example', 'Access-Control-Request-Method': 'GET'}
+ headers={'Origin': 'http://any.example', 'Access-Control-Request-Method': 'GET'},
)
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') == 'http://any.example'
+
@pytest.mark.asyncio
async def test_api_cors_allow_methods_contains_options_appended(authed_client):
name, ver = 'corsm3', 'v1'
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['Content-Type'],
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'}
+ headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'},
)
assert r.status_code == 204
- methods = [m.strip().upper() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()]
+ methods = [
+ m.strip().upper()
+ for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',')
+ if m.strip()
+ ]
assert 'OPTIONS' in methods
+
@pytest.mark.asyncio
async def test_api_cors_allow_headers_asterisk_allows_any(authed_client):
name, ver = 'corsm4', 'v1'
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['*'],
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['*'],
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-Random-Header'}
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-Random-Header',
+ },
)
assert r.status_code == 204
ach = r.headers.get('Access-Control-Allow-Headers') or ''
assert '*' in ach
+
@pytest.mark.asyncio
async def test_api_cors_allow_headers_specific_disallows_others(authed_client):
name, ver = 'corsm5', 'v1'
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['Content-Type'],
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-Other'}
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-Other',
+ },
)
assert r.status_code == 204
ach = r.headers.get('Access-Control-Allow-Headers') or ''
assert 'Content-Type' in ach and 'X-Other' not in ach
+
@pytest.mark.asyncio
async def test_api_cors_allow_credentials_true_sets_header(authed_client):
name, ver = 'corsm6', 'v1'
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['Content-Type'],
- 'api_cors_allow_credentials': True,
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_cors_allow_credentials': True,
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'}
+ headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'},
)
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Credentials') == 'true'
+
@pytest.mark.asyncio
async def test_api_cors_expose_headers_propagated(authed_client):
name, ver = 'corsm7', 'v1'
expose = ['X-Resp-Id', 'X-Trace-Id']
- await _setup_api_and_endpoint(authed_client, name, ver, api_overrides={
- 'api_cors_allow_origins': ['http://ok.example'],
- 'api_cors_allow_methods': ['GET'],
- 'api_cors_allow_headers': ['Content-Type'],
- 'api_cors_expose_headers': expose,
- })
+ await _setup_api_and_endpoint(
+ authed_client,
+ name,
+ ver,
+ api_overrides={
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_cors_expose_headers': expose,
+ },
+ )
r = await authed_client.options(
f'/api/rest/{name}/{ver}/status',
- headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'}
+ headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET'},
)
assert r.status_code == 204
aceh = r.headers.get('Access-Control-Expose-Headers') or ''
for h in expose:
assert h in aceh
-
diff --git a/backend-services/tests/test_api_crud_failures.py b/backend-services/tests/test_api_crud_failures.py
index ea82d16..b57a044 100644
--- a/backend-services/tests/test_api_crud_failures.py
+++ b/backend-services/tests/test_api_crud_failures.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_update_delete_nonexistent_api(authed_client):
u = await authed_client.put('/platform/api/doesnot/v9', json={'api_description': 'x'})
@@ -7,4 +8,3 @@ async def test_update_delete_nonexistent_api(authed_client):
d = await authed_client.delete('/platform/api/doesnot/v9')
assert d.status_code in (400, 404)
-
diff --git a/backend-services/tests/test_api_get_by_name_version.py b/backend-services/tests/test_api_get_by_name_version.py
new file mode 100644
index 0000000..55f2512
--- /dev/null
+++ b/backend-services/tests/test_api_get_by_name_version.py
@@ -0,0 +1,27 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_api_get_by_name_version_returns_200(authed_client):
+ name, ver = 'apiget', 'v1'
+ r = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'demo',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream.invalid'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ assert r.status_code in (200, 201), r.text
+
+ g = await authed_client.get(f'/platform/api/{name}/{ver}')
+ assert g.status_code == 200, g.text
+ body = g.json().get('response', g.json())
+ assert body.get('api_name') == name
+ assert body.get('api_version') == ver
+ assert '_id' not in body
diff --git a/backend-services/tests/test_async_collection.py b/backend-services/tests/test_async_collection.py
index 1267573..aa73405 100644
--- a/backend-services/tests/test_async_collection.py
+++ b/backend-services/tests/test_async_collection.py
@@ -1,14 +1,17 @@
#!/usr/bin/env python3
"""Test script to verify async collection behavior"""
+
import asyncio
import os
+
os.environ['MEM_OR_EXTERNAL'] = 'MEM'
from utils.database_async import async_database
+
async def test_async_collection():
- print("Testing async collection...")
-
+ print('Testing async collection...')
+
# Insert a test tier
test_tier = {
'tier_id': 'test',
@@ -17,31 +20,32 @@ async def test_async_collection():
'limits': {},
'features': [],
'is_default': False,
- 'enabled': True
+ 'enabled': True,
}
-
+
await async_database.db.tiers.insert_one(test_tier)
- print("Inserted test tier")
-
+ print('Inserted test tier')
+
# Try to list tiers
cursor = async_database.db.tiers.find({})
- print(f"Cursor type: {type(cursor)}")
- print(f"Cursor has skip: {hasattr(cursor, 'skip')}")
- print(f"Cursor has limit: {hasattr(cursor, 'limit')}")
-
+ print(f'Cursor type: {type(cursor)}')
+ print(f'Cursor has skip: {hasattr(cursor, "skip")}')
+ print(f'Cursor has limit: {hasattr(cursor, "limit")}')
+
# Test skip/limit
cursor = cursor.skip(0).limit(10)
- print("Applied skip/limit")
-
+ print('Applied skip/limit')
+
# Iterate
tiers = []
async for tier_data in cursor:
- print(f"Found tier: {tier_data.get('tier_id')}")
+ print(f'Found tier: {tier_data.get("tier_id")}')
tiers.append(tier_data)
-
- print(f"Total tiers found: {len(tiers)}")
+
+ print(f'Total tiers found: {len(tiers)}')
return tiers
+
if __name__ == '__main__':
result = asyncio.run(test_async_collection())
- print(f"Result: {result}")
+ print(f'Result: {result}')
diff --git a/backend-services/tests/test_async_db_wrappers.py b/backend-services/tests/test_async_db_wrappers.py
new file mode 100644
index 0000000..c85da11
--- /dev/null
+++ b/backend-services/tests/test_async_db_wrappers.py
@@ -0,0 +1,94 @@
+import pytest
+
+from utils.async_db import db_delete_one, db_find_one, db_insert_one, db_update_one
+from utils.database_async import async_database
+
+
+@pytest.mark.asyncio
+async def test_async_wrappers_with_inmemory_async_collections():
+ coll = async_database.db.tiers # AsyncInMemoryCollection
+
+ # Insert
+ await db_insert_one(coll, {'tier_id': 't1', 'name': 'Tier 1'})
+ doc = await db_find_one(coll, {'tier_id': 't1'})
+ assert doc and doc.get('name') == 'Tier 1'
+
+ # Update
+ await db_update_one(coll, {'tier_id': 't1'}, {'$set': {'name': 'Tier One'}})
+ doc2 = await db_find_one(coll, {'tier_id': 't1'})
+ assert doc2 and doc2.get('name') == 'Tier One'
+
+ # Delete
+ await db_delete_one(coll, {'tier_id': 't1'})
+ assert await db_find_one(coll, {'tier_id': 't1'}) is None
+
+
+class _SyncColl:
+ def __init__(self):
+ self._docs = []
+
+ def find_one(self, q):
+ for d in self._docs:
+ match = all(d.get(k) == v for k, v in q.items())
+ if match:
+ return dict(d)
+ return None
+
+ def insert_one(self, doc):
+ self._docs.append(dict(doc))
+
+ class R:
+ acknowledged = True
+ inserted_id = 'x'
+
+ return R()
+
+ def update_one(self, q, upd):
+ for i, d in enumerate(self._docs):
+ if all(d.get(k) == v for k, v in q.items()):
+ setv = upd.get('$set', {})
+ nd = dict(d)
+ nd.update(setv)
+ self._docs[i] = nd
+
+ class R:
+ acknowledged = True
+ modified_count = 1
+
+ return R()
+
+ class R2:
+ acknowledged = True
+ modified_count = 0
+
+ return R2()
+
+ def delete_one(self, q):
+ for i, d in enumerate(self._docs):
+ if all(d.get(k) == v for k, v in q.items()):
+ del self._docs[i]
+
+ class R:
+ acknowledged = True
+ deleted_count = 1
+
+ return R()
+
+ class R2:
+ acknowledged = True
+ deleted_count = 0
+
+ return R2()
+
+
+@pytest.mark.asyncio
+async def test_async_wrappers_fallback_to_thread_for_sync_collections():
+ coll = _SyncColl()
+ await db_insert_one(coll, {'k': 1, 'v': 'a'})
+ d = await db_find_one(coll, {'k': 1})
+ assert d and d['v'] == 'a'
+ await db_update_one(coll, {'k': 1}, {'$set': {'v': 'b'}})
+ d2 = await db_find_one(coll, {'k': 1})
+ assert d2 and d2['v'] == 'b'
+ await db_delete_one(coll, {'k': 1})
+ assert await db_find_one(coll, {'k': 1}) is None
diff --git a/backend-services/tests/test_async_endpoints.py b/backend-services/tests/test_async_endpoints.py
index 2c6abef..551d649 100644
--- a/backend-services/tests/test_async_endpoints.py
+++ b/backend-services/tests/test_async_endpoints.py
@@ -5,55 +5,47 @@ The contents of this file are property of Doorman Dev, LLC
Review the Apache License 2.0 for valid authorization of use
"""
-from fastapi import APIRouter, HTTPException
-from typing import Dict, Any
import asyncio
import time
+from typing import Any
-from utils.database_async import (
- user_collection as async_user_collection,
- api_collection as async_api_collection,
- async_database
-)
+from fastapi import APIRouter, HTTPException
+
+from utils.database import api_collection as sync_api_collection
+from utils.database import user_collection as sync_user_collection
+from utils.database_async import api_collection as async_api_collection
+from utils.database_async import async_database
+from utils.database_async import user_collection as async_user_collection
from utils.doorman_cache_async import async_doorman_cache
-
-from utils.database import (
- user_collection as sync_user_collection,
- api_collection as sync_api_collection
-)
from utils.doorman_cache_util import doorman_cache
-router = APIRouter(prefix="/test/async", tags=["Async Testing"])
+router = APIRouter(prefix='/test/async', tags=['Async Testing'])
-@router.get("/health")
-async def async_health_check() -> Dict[str, Any]:
+
+@router.get('/health')
+async def async_health_check() -> dict[str, Any]:
"""Test async database and cache health."""
try:
if async_database.is_memory_only():
- db_status = "memory_only"
+ db_status = 'memory_only'
else:
await async_user_collection.find_one({'username': 'admin'})
- db_status = "connected"
+ db_status = 'connected'
cache_operational = await async_doorman_cache.is_operational()
cache_info = await async_doorman_cache.get_cache_info()
return {
- "status": "healthy",
- "database": {
- "status": db_status,
- "mode": async_database.get_mode_info()
- },
- "cache": {
- "operational": cache_operational,
- "info": cache_info
- }
+ 'status': 'healthy',
+ 'database': {'status': db_status, 'mode': async_database.get_mode_info()},
+ 'cache': {'operational': cache_operational, 'info': cache_info},
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
+ raise HTTPException(status_code=500, detail=f'Health check failed: {str(e)}')
-@router.get("/performance/sync")
-async def test_sync_performance() -> Dict[str, Any]:
+
+@router.get('/performance/sync')
+async def test_sync_performance() -> dict[str, Any]:
"""Test SYNC (blocking) database operations - SLOW under load."""
start_time = time.time()
@@ -68,17 +60,18 @@ async def test_sync_performance() -> Dict[str, Any]:
elapsed = time.time() - start_time
return {
- "method": "sync (blocking)",
- "elapsed_ms": round(elapsed * 1000, 2),
- "user_found": user is not None,
- "apis_count": len(apis),
- "warning": "This endpoint blocks the event loop and causes poor performance under load"
+ 'method': 'sync (blocking)',
+ 'elapsed_ms': round(elapsed * 1000, 2),
+ 'user_found': user is not None,
+ 'apis_count': len(apis),
+ 'warning': 'This endpoint blocks the event loop and causes poor performance under load',
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Sync test failed: {str(e)}")
+ raise HTTPException(status_code=500, detail=f'Sync test failed: {str(e)}')
-@router.get("/performance/async")
-async def test_async_performance() -> Dict[str, Any]:
+
+@router.get('/performance/async')
+async def test_async_performance() -> dict[str, Any]:
"""Test ASYNC (non-blocking) database operations - FAST under load."""
start_time = time.time()
@@ -98,17 +91,18 @@ async def test_async_performance() -> Dict[str, Any]:
elapsed = time.time() - start_time
return {
- "method": "async (non-blocking)",
- "elapsed_ms": round(elapsed * 1000, 2),
- "user_found": user is not None,
- "apis_count": len(apis),
- "note": "This endpoint does NOT block the event loop and performs well under load"
+ 'method': 'async (non-blocking)',
+ 'elapsed_ms': round(elapsed * 1000, 2),
+ 'user_found': user is not None,
+ 'apis_count': len(apis),
+ 'note': 'This endpoint does NOT block the event loop and performs well under load',
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Async test failed: {str(e)}")
+ raise HTTPException(status_code=500, detail=f'Async test failed: {str(e)}')
-@router.get("/performance/parallel")
-async def test_parallel_performance() -> Dict[str, Any]:
+
+@router.get('/performance/parallel')
+async def test_parallel_performance() -> dict[str, Any]:
"""Test PARALLEL async operations - Maximum performance."""
start_time = time.time()
@@ -116,19 +110,13 @@ async def test_parallel_performance() -> Dict[str, Any]:
user_task = async_user_collection.find_one({'username': 'admin'})
if async_database.is_memory_only():
- apis_task = asyncio.to_thread(
- lambda: list(async_api_collection.find({}).limit(10))
- )
+ apis_task = asyncio.to_thread(lambda: list(async_api_collection.find({}).limit(10)))
else:
apis_task = async_api_collection.find({}).limit(10).to_list(length=10)
cache_task = async_doorman_cache.get_cache('user_cache', 'admin')
- user, apis, cached_user = await asyncio.gather(
- user_task,
- apis_task,
- cache_task
- )
+ user, apis, cached_user = await asyncio.gather(user_task, apis_task, cache_task)
if not cached_user and user:
await async_doorman_cache.set_cache('user_cache', 'admin', user)
@@ -136,25 +124,22 @@ async def test_parallel_performance() -> Dict[str, Any]:
elapsed = time.time() - start_time
return {
- "method": "async parallel (non-blocking + concurrent)",
- "elapsed_ms": round(elapsed * 1000, 2),
- "user_found": user is not None,
- "apis_count": len(apis) if apis else 0,
- "note": "Operations executed in parallel for maximum performance"
+ 'method': 'async parallel (non-blocking + concurrent)',
+ 'elapsed_ms': round(elapsed * 1000, 2),
+ 'user_found': user is not None,
+ 'apis_count': len(apis) if apis else 0,
+ 'note': 'Operations executed in parallel for maximum performance',
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Parallel test failed: {str(e)}")
+ raise HTTPException(status_code=500, detail=f'Parallel test failed: {str(e)}')
-@router.get("/cache/test")
-async def test_cache_operations() -> Dict[str, Any]:
+
+@router.get('/cache/test')
+async def test_cache_operations() -> dict[str, Any]:
"""Test async cache operations."""
try:
- test_key = "test_user_123"
- test_value = {
- "username": "test_user_123",
- "email": "test@example.com",
- "role": "user"
- }
+ test_key = 'test_user_123'
+ test_value = {'username': 'test_user_123', 'email': 'test@example.com', 'role': 'user'}
await async_doorman_cache.set_cache('user_cache', test_key, test_value)
@@ -165,16 +150,17 @@ async def test_cache_operations() -> Dict[str, Any]:
after_delete = await async_doorman_cache.get_cache('user_cache', test_key)
return {
- "set": "success",
- "get": "success" if retrieved == test_value else "failed",
- "delete": "success" if after_delete is None else "failed",
- "cache_info": await async_doorman_cache.get_cache_info()
+ 'set': 'success',
+ 'get': 'success' if retrieved == test_value else 'failed',
+ 'delete': 'success' if after_delete is None else 'failed',
+ 'cache_info': await async_doorman_cache.get_cache_info(),
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Cache test failed: {str(e)}")
+ raise HTTPException(status_code=500, detail=f'Cache test failed: {str(e)}')
-@router.get("/load-test-compare")
-async def load_test_comparison() -> Dict[str, Any]:
+
+@router.get('/load-test-compare')
+async def load_test_comparison() -> dict[str, Any]:
"""
Compare sync vs async performance under simulated load.
@@ -183,33 +169,30 @@ async def load_test_comparison() -> Dict[str, Any]:
try:
sync_start = time.time()
sync_results = []
- for i in range(10):
+ for _i in range(10):
user = sync_user_collection.find_one({'username': 'admin'})
sync_results.append(user is not None)
sync_elapsed = time.time() - sync_start
async_start = time.time()
- async_tasks = [
- async_user_collection.find_one({'username': 'admin'})
- for i in range(10)
- ]
- async_results = await asyncio.gather(*async_tasks)
+ async_tasks = [async_user_collection.find_one({'username': 'admin'}) for i in range(10)]
+ await asyncio.gather(*async_tasks)
async_elapsed = time.time() - async_start
speedup = sync_elapsed / async_elapsed if async_elapsed > 0 else 0
return {
- "test": "10 concurrent user lookups",
- "sync": {
- "elapsed_ms": round(sync_elapsed * 1000, 2),
- "queries_per_second": round(10 / sync_elapsed, 2)
+ 'test': '10 concurrent user lookups',
+ 'sync': {
+ 'elapsed_ms': round(sync_elapsed * 1000, 2),
+ 'queries_per_second': round(10 / sync_elapsed, 2),
},
- "async": {
- "elapsed_ms": round(async_elapsed * 1000, 2),
- "queries_per_second": round(10 / async_elapsed, 2)
+ 'async': {
+ 'elapsed_ms': round(async_elapsed * 1000, 2),
+ 'queries_per_second': round(10 / async_elapsed, 2),
},
- "speedup": f"{round(speedup, 2)}x faster",
- "note": "Async shows significant improvement with concurrent operations"
+ 'speedup': f'{round(speedup, 2)}x faster',
+ 'note': 'Async shows significant improvement with concurrent operations',
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Load test failed: {str(e)}")
+ raise HTTPException(status_code=500, detail=f'Load test failed: {str(e)}')
diff --git a/backend-services/tests/test_audit_and_export_negatives.py b/backend-services/tests/test_audit_and_export_negatives.py
index b8f29c1..69ba26c 100644
--- a/backend-services/tests/test_audit_and_export_negatives.py
+++ b/backend-services/tests/test_audit_and_export_negatives.py
@@ -1,25 +1,35 @@
import pytest
+
class _AuditSpy:
def __init__(self):
self.calls = []
+
def info(self, msg):
self.calls.append(msg)
+
@pytest.mark.asyncio
async def test_audit_api_create_update_delete(monkeypatch, authed_client):
-
import utils.audit_util as au
+
orig = au._logger
spy = _AuditSpy()
au._logger = spy
try:
-
- r = await authed_client.post('/platform/api', json={
- 'api_name': 'aud1', 'api_version': 'v1', 'api_description': 'd',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0,
- })
+ r = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': 'aud1',
+ 'api_version': 'v1',
+ 'api_description': 'd',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
assert r.status_code in (200, 201)
r = await authed_client.put('/platform/api/aud1/v1', json={'api_description': 'd2'})
@@ -34,27 +44,44 @@ async def test_audit_api_create_update_delete(monkeypatch, authed_client):
finally:
au._logger = orig
+
@pytest.mark.asyncio
async def test_audit_user_credits_and_subscriptions(monkeypatch, authed_client):
import utils.audit_util as au
+
orig = au._logger
spy = _AuditSpy()
au._logger = spy
try:
-
- r = await authed_client.post('/platform/credit/admin', json={'username': 'admin', 'users_credits': {}})
+ r = await authed_client.post(
+ '/platform/credit/admin', json={'username': 'admin', 'users_credits': {}}
+ )
assert r.status_code in (200, 201)
- await authed_client.post('/platform/api', json={
- 'api_name': 'aud2', 'api_version': 'v1', 'api_description': 'd',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0,
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': 'aud2',
+ 'api_version': 'v1',
+ 'api_description': 'd',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
r_me = await authed_client.get('/platform/user/me')
- username = (r_me.json().get('username') if r_me.status_code == 200 else 'admin')
- r = await authed_client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'})
+ username = r_me.json().get('username') if r_me.status_code == 200 else 'admin'
+ r = await authed_client.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'},
+ )
assert r.status_code in (200, 201)
- r = await authed_client.post('/platform/subscription/unsubscribe', json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'})
+ r = await authed_client.post(
+ '/platform/subscription/unsubscribe',
+ json={'username': username, 'api_name': 'aud2', 'api_version': 'v1'},
+ )
assert r.status_code in (200, 201, 400)
assert any('user_credits.save' in c for c in spy.calls)
@@ -63,14 +90,20 @@ async def test_audit_user_credits_and_subscriptions(monkeypatch, authed_client):
finally:
au._logger = orig
+
@pytest.mark.asyncio
async def test_export_not_found_cases(authed_client):
- r = await authed_client.get('/platform/config/export/apis', params={'api_name': 'nope', 'api_version': 'v9'})
+ r = await authed_client.get(
+ '/platform/config/export/apis', params={'api_name': 'nope', 'api_version': 'v9'}
+ )
assert r.status_code == 404
r = await authed_client.get('/platform/config/export/roles', params={'role_name': 'nope-role'})
assert r.status_code == 404
- r = await authed_client.get('/platform/config/export/groups', params={'group_name': 'nope-group'})
+ r = await authed_client.get(
+ '/platform/config/export/groups', params={'group_name': 'nope-group'}
+ )
assert r.status_code == 404
- r = await authed_client.get('/platform/config/export/routings', params={'client_key': 'nope-key'})
+ r = await authed_client.get(
+ '/platform/config/export/routings', params={'client_key': 'nope-key'}
+ )
assert r.status_code == 404
-
diff --git a/backend-services/tests/test_auth.py b/backend-services/tests/test_auth.py
index 35a1dcd..04df3ae 100644
--- a/backend-services/tests/test_auth.py
+++ b/backend-services/tests/test_auth.py
@@ -1,6 +1,8 @@
import os
+
import pytest
+
@pytest.mark.asyncio
async def test_authorization_login_and_status(client):
resp = await client.post(
@@ -19,9 +21,9 @@ async def test_authorization_login_and_status(client):
body = status.json()
assert body.get('message') == 'Token is valid'
+
@pytest.mark.asyncio
async def test_auth_refresh_and_invalidate(authed_client):
-
r = await authed_client.post('/platform/authorization/refresh')
assert r.status_code == 200
@@ -31,10 +33,10 @@ async def test_auth_refresh_and_invalidate(authed_client):
status = await authed_client.get('/platform/authorization/status')
assert status.status_code in (401, 500)
+
@pytest.mark.asyncio
async def test_authorization_invalid_login(client):
resp = await client.post(
- '/platform/authorization',
- json={'email': 'unknown@example.com', 'password': 'bad'},
+ '/platform/authorization', json={'email': 'unknown@example.com', 'password': 'bad'}
)
assert resp.status_code in (400, 401)
diff --git a/backend-services/tests/test_auth_admin.py b/backend-services/tests/test_auth_admin.py
index 16c4f81..ef443ac 100644
--- a/backend-services/tests/test_auth_admin.py
+++ b/backend-services/tests/test_auth_admin.py
@@ -1,20 +1,18 @@
import os
+
import pytest
+
@pytest.mark.asyncio
async def test_auth_admin_endpoints(authed_client):
-
- r = await authed_client.put(
- '/platform/role/admin',
- json={'manage_auth': True},
- )
+ r = await authed_client.put('/platform/role/admin', json={'manage_auth': True})
assert r.status_code == 200, r.text
relog = await authed_client.post(
'/platform/authorization',
json={
'email': os.environ.get('DOORMAN_ADMIN_EMAIL'),
- 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD')
+ 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD'),
},
)
assert relog.status_code == 200, relog.text
@@ -27,7 +25,7 @@ async def test_auth_admin_endpoints(authed_client):
'password': 'VerySecurePassword!123',
'role': 'admin',
'groups': ['ALL'],
- 'active': True
+ 'active': True,
},
)
assert create.status_code in (200, 201), create.text
diff --git a/backend-services/tests/test_auth_csrf_https.py b/backend-services/tests/test_auth_csrf_https.py
index f7c42e5..7663c7e 100644
--- a/backend-services/tests/test_auth_csrf_https.py
+++ b/backend-services/tests/test_auth_csrf_https.py
@@ -1,17 +1,20 @@
import pytest
+
@pytest.mark.asyncio
async def test_auth_rejects_missing_csrf_when_https_only(monkeypatch, authed_client):
monkeypatch.setenv('HTTPS_ONLY', 'true')
r = await authed_client.get('/platform/user/me')
assert r.status_code == 401
+
@pytest.mark.asyncio
async def test_auth_rejects_mismatched_csrf_when_https_only(monkeypatch, authed_client):
monkeypatch.setenv('HTTPS_ONLY', 'true')
r = await authed_client.get('/platform/user/me', headers={'X-CSRF-Token': 'not-the-cookie'})
assert r.status_code == 401
+
@pytest.mark.asyncio
async def test_auth_accepts_matching_csrf(monkeypatch, authed_client):
monkeypatch.setenv('HTTPS_ONLY', 'true')
@@ -24,9 +27,9 @@ async def test_auth_accepts_matching_csrf(monkeypatch, authed_client):
r = await authed_client.get('/platform/user/me', headers={'X-CSRF-Token': csrf})
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_auth_http_mode_skips_csrf_validation(monkeypatch, authed_client):
monkeypatch.setenv('HTTPS_ONLY', 'false')
r = await authed_client.get('/platform/user/me')
assert r.status_code == 200
-
diff --git a/backend-services/tests/test_auth_guard.py b/backend-services/tests/test_auth_guard.py
index 1651331..0ffa82c 100644
--- a/backend-services/tests/test_auth_guard.py
+++ b/backend-services/tests/test_auth_guard.py
@@ -1,7 +1,7 @@
import pytest
+
@pytest.mark.asyncio
async def test_unauthorized_access_rejected(client):
-
me = await client.get('/platform/user/me')
assert me.status_code in (401, 500)
diff --git a/backend-services/tests/test_auth_refresh_unauthenticated.py b/backend-services/tests/test_auth_refresh_unauthenticated.py
index 1024476..b5c5801 100644
--- a/backend-services/tests/test_auth_refresh_unauthenticated.py
+++ b/backend-services/tests/test_auth_refresh_unauthenticated.py
@@ -1,7 +1,7 @@
import pytest
+
@pytest.mark.asyncio
async def test_refresh_requires_auth(client):
r = await client.post('/platform/authorization/refresh')
assert r.status_code == 401
-
diff --git a/backend-services/tests/test_bandwidth_and_monitor.py b/backend-services/tests/test_bandwidth_and_monitor.py
index 0fd1475..d2654d7 100644
--- a/backend-services/tests/test_bandwidth_and_monitor.py
+++ b/backend-services/tests/test_bandwidth_and_monitor.py
@@ -1,16 +1,22 @@
import json
+
import pytest
+
@pytest.mark.asyncio
async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_client):
try:
- upd = await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 80, 'bandwidth_limit_window': 'second'})
+ upd = await authed_client.put(
+ '/platform/user/admin',
+ json={'bandwidth_limit_bytes': 80, 'bandwidth_limit_window': 'second'},
+ )
assert upd.status_code in (200, 204)
except AssertionError:
await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': None})
raise
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'bwapi', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/p')
@@ -24,16 +30,20 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie
self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))}
self.text = body.decode('utf-8')
self.content = body
+
def json(self):
return json.loads(self.text)
class _FakeAsyncClient:
def __init__(self, timeout=None, limits=None, http2=False):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -51,12 +61,16 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405)
+
async def get(self, url, **kwargs):
return _FakeHTTPResponse(200)
+
async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs):
return _FakeHTTPResponse(200)
+
async def put(self, url, **kwargs):
return _FakeHTTPResponse(200)
+
async def delete(self, url, **kwargs):
return _FakeHTTPResponse(200)
@@ -82,9 +96,11 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie
await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 0})
+
@pytest.mark.asyncio
async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'bwmon', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/echo')
@@ -100,13 +116,20 @@ async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client):
self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))}
self.text = body.decode('utf-8')
self.content = body
+
def json(self):
return json.loads(self.text)
class _FakeAsyncClient:
- def __init__(self, timeout=None, limits=None, http2=False): pass
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ def __init__(self, timeout=None, limits=None, http2=False):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -124,10 +147,18 @@ async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405)
- async def get(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): return _FakeHTTPResponse(200)
- async def put(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def delete(self, url, **kwargs): return _FakeHTTPResponse(200)
+
+ async def get(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
diff --git a/backend-services/tests/test_bandwidth_limit_windows.py b/backend-services/tests/test_bandwidth_limit_windows.py
index 43a8dbf..7d07407 100644
--- a/backend-services/tests/test_bandwidth_limit_windows.py
+++ b/backend-services/tests/test_bandwidth_limit_windows.py
@@ -1,26 +1,36 @@
-import pytest
import time
+import pytest
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
async def _setup_basic_rest(client, name='bw', ver='v1', method='GET', uri='/p'):
from conftest import create_api, create_endpoint, subscribe_self
+
await create_api(client, name, ver)
await create_endpoint(client, name, ver, method, uri)
await subscribe_self(client, name, ver)
return name, ver, uri
+
@pytest.mark.asyncio
async def test_bandwidth_limit_blocks_when_exceeded(monkeypatch, authed_client):
name, ver, uri = await _setup_basic_rest(authed_client, name='bw1', method='GET', uri='/g')
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {
- 'bandwidth_limit_bytes': 1,
- 'bandwidth_limit_window': 'second',
- 'bandwidth_limit_enabled': True,
- }})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {
+ '$set': {
+ 'bandwidth_limit_bytes': 1,
+ 'bandwidth_limit_window': 'second',
+ 'bandwidth_limit_enabled': True,
+ }
+ },
+ )
await authed_client.delete('/api/caches')
from utils.doorman_cache_util import doorman_cache
+
try:
for k in list(doorman_cache.cache.keys('bandwidth_usage:admin*')):
doorman_cache.cache.delete(k)
@@ -28,6 +38,7 @@ async def test_bandwidth_limit_blocks_when_exceeded(monkeypatch, authed_client):
pass
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
ok = await authed_client.get(f'/api/rest/{name}/{ver}{uri}')
@@ -35,17 +46,25 @@ async def test_bandwidth_limit_blocks_when_exceeded(monkeypatch, authed_client):
blocked = await authed_client.get(f'/api/rest/{name}/{ver}{uri}')
assert blocked.status_code == 429
+
@pytest.mark.asyncio
async def test_bandwidth_limit_resets_after_window(monkeypatch, authed_client):
name, ver, uri = await _setup_basic_rest(authed_client, name='bw2', method='GET', uri='/g')
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {
- 'bandwidth_limit_bytes': 1,
- 'bandwidth_limit_window': 'second',
- 'bandwidth_limit_enabled': True,
- }})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {
+ '$set': {
+ 'bandwidth_limit_bytes': 1,
+ 'bandwidth_limit_window': 'second',
+ 'bandwidth_limit_enabled': True,
+ }
+ },
+ )
await authed_client.delete('/api/caches')
from utils.doorman_cache_util import doorman_cache
+
try:
for k in list(doorman_cache.cache.keys('bandwidth_usage:admin*')):
doorman_cache.cache.delete(k)
@@ -53,6 +72,7 @@ async def test_bandwidth_limit_resets_after_window(monkeypatch, authed_client):
pass
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r1 = await authed_client.get(f'/api/rest/{name}/{ver}{uri}')
@@ -63,17 +83,25 @@ async def test_bandwidth_limit_resets_after_window(monkeypatch, authed_client):
r3 = await authed_client.get(f'/api/rest/{name}/{ver}{uri}')
assert r3.status_code == 200
+
@pytest.mark.asyncio
async def test_bandwidth_limit_counts_request_and_response_bytes(monkeypatch, authed_client):
name, ver, uri = await _setup_basic_rest(authed_client, name='bw3', method='POST', uri='/p')
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {
- 'bandwidth_limit_bytes': 1_000_000,
- 'bandwidth_limit_window': 'second',
- 'bandwidth_limit_enabled': True,
- }})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {
+ '$set': {
+ 'bandwidth_limit_bytes': 1_000_000,
+ 'bandwidth_limit_window': 'second',
+ 'bandwidth_limit_enabled': True,
+ }
+ },
+ )
await authed_client.delete('/api/caches')
from utils.doorman_cache_util import doorman_cache
+
try:
for k in list(doorman_cache.cache.keys('bandwidth_usage:admin*')):
doorman_cache.cache.delete(k)
@@ -81,9 +109,11 @@ async def test_bandwidth_limit_counts_request_and_response_bytes(monkeypatch, au
pass
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
from utils.bandwidth_util import get_current_usage
+
before = get_current_usage('admin', 'second')
payload = 'x' * 1234
r = await authed_client.post(f'/api/rest/{name}/{ver}{uri}', json={'data': payload})
diff --git a/backend-services/tests/test_cache_bytes_serialization.py b/backend-services/tests/test_cache_bytes_serialization.py
new file mode 100644
index 0000000..7da9e78
--- /dev/null
+++ b/backend-services/tests/test_cache_bytes_serialization.py
@@ -0,0 +1,18 @@
+import bcrypt
+
+from utils.doorman_cache_util import doorman_cache
+
+
+def test_cache_serializes_bytes_password_to_json_string():
+ doc = {
+ 'username': 'u1',
+ 'password': bcrypt.hashpw(b'super-secret', bcrypt.gensalt()), # bytes
+ 'role': 'user',
+ }
+
+ doorman_cache.set_cache('user_cache', 'u1', doc)
+ out = doorman_cache.get_cache('user_cache', 'u1')
+
+ assert isinstance(out, dict)
+ assert isinstance(out.get('password'), str)
+ assert out['password'].startswith('$2b$') or out['password'].startswith('$2a$')
diff --git a/backend-services/tests/test_chunked_encoding_body_limit.py b/backend-services/tests/test_chunked_encoding_body_limit.py
index 4defa3f..d5284ad 100644
--- a/backend-services/tests/test_chunked_encoding_body_limit.py
+++ b/backend-services/tests/test_chunked_encoding_body_limit.py
@@ -7,20 +7,23 @@ vulnerability where attackers could stream unlimited data without a
Content-Length header.
"""
-import pytest
-from fastapi.testclient import TestClient
import os
import sys
+import pytest
+from fastapi.testclient import TestClient
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from doorman import doorman
+
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(doorman)
+
class TestChunkedEncodingBodyLimit:
"""Test suite for chunked encoding body size limit enforcement."""
@@ -31,10 +34,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/platform/authorization',
data=small_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code != 413
@@ -49,10 +49,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/platform/authorization',
data=large_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -71,10 +68,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/api/rest/test/v1/endpoint',
data=large_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -93,10 +87,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/api/soap/test/v1/service',
data=medium_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'text/xml'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'text/xml'},
)
assert response.status_code != 413
@@ -115,9 +106,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/platform/authorization',
data=large_payload,
- headers={
- 'Content-Type': 'application/json'
- }
+ headers={'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -139,8 +128,8 @@ class TestChunkedEncodingBodyLimit:
headers={
'Transfer-Encoding': 'chunked',
'Content-Length': '100',
- 'Content-Type': 'application/json'
- }
+ 'Content-Type': 'application/json',
+ },
)
assert response.status_code == 413
@@ -151,10 +140,7 @@ class TestChunkedEncodingBodyLimit:
def test_get_request_with_chunked_ignored(self, client):
"""Test that GET requests with Transfer-Encoding: chunked are not limited."""
response = client.get(
- '/platform/authorization/status',
- headers={
- 'Transfer-Encoding': 'chunked'
- }
+ '/platform/authorization/status', headers={'Transfer-Encoding': 'chunked'}
)
assert response.status_code != 413
@@ -169,10 +155,7 @@ class TestChunkedEncodingBodyLimit:
response = client.put(
'/platform/user/testuser',
data=large_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -190,10 +173,7 @@ class TestChunkedEncodingBodyLimit:
response = client.patch(
'/platform/user/testuser',
data=large_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -211,10 +191,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/api/graphql/test',
data=large_query.encode(),
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -241,10 +218,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
route,
data=large_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413, f'Route {route} not protected'
@@ -262,10 +236,7 @@ class TestChunkedEncodingBodyLimit:
response = client.post(
'/platform/authorization',
data=large_payload,
- headers={
- 'Transfer-Encoding': 'chunked',
- 'Content-Type': 'application/json'
- }
+ headers={'Transfer-Encoding': 'chunked', 'Content-Type': 'application/json'},
)
assert response.status_code == 413
@@ -273,5 +244,6 @@ class TestChunkedEncodingBodyLimit:
finally:
os.environ['MAX_BODY_SIZE_BYTES'] = '1048576'
+
if __name__ == '__main__':
pytest.main([__file__, '-v'])
diff --git a/backend-services/tests/test_compression_content_types.py b/backend-services/tests/test_compression_content_types.py
index f4092aa..3317838 100644
--- a/backend-services/tests/test_compression_content_types.py
+++ b/backend-services/tests/test_compression_content_types.py
@@ -9,13 +9,14 @@ Verifies compression works correctly with:
- Various HTTP methods
"""
-import pytest
import gzip
-import json
import io
+import json
import os
import time
+import pytest
+
@pytest.mark.asyncio
async def test_compression_with_rest_gateway_json(client):
@@ -26,7 +27,7 @@ async def test_compression_with_rest_gateway_json(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -34,8 +35,7 @@ async def test_compression_with_rest_gateway_json(client):
# Test JSON response (typical REST response)
r = await client.get(
- '/platform/api',
- headers={'Accept-Encoding': 'gzip', 'Accept': 'application/json'}
+ '/platform/api', headers={'Accept-Encoding': 'gzip', 'Accept': 'application/json'}
)
assert r.status_code == 200
@@ -54,10 +54,10 @@ async def test_compression_with_rest_gateway_json(client):
compressed_size = len(compressed_buffer.getvalue())
ratio = (1 - (compressed_size / uncompressed_size)) * 100
- print(f"\nREST JSON Compression:")
- print(f" Original: {uncompressed_size} bytes")
- print(f" Compressed: {compressed_size} bytes")
- print(f" Ratio: {ratio:.1f}% reduction")
+ print('\nREST JSON Compression:')
+ print(f' Original: {uncompressed_size} bytes')
+ print(f' Compressed: {compressed_size} bytes')
+ print(f' Ratio: {ratio:.1f}% reduction')
# JSON typically compresses well
assert ratio > 20
@@ -69,7 +69,8 @@ async def test_compression_with_xml_content(client):
# XML is very verbose and should compress extremely well
# Create a mock XML response
- xml_content = """
+ xml_content = (
+ """
@@ -83,7 +84,9 @@ async def test_compression_with_xml_content(client):
- """ * 3 # Repeat to make it larger
+ """
+ * 3
+ ) # Repeat to make it larger
uncompressed_size = len(xml_content.encode('utf-8'))
@@ -95,13 +98,13 @@ async def test_compression_with_xml_content(client):
ratio = (1 - (compressed_size / uncompressed_size)) * 100
- print(f"\nXML/SOAP Compression:")
- print(f" Original: {uncompressed_size} bytes")
- print(f" Compressed: {compressed_size} bytes")
- print(f" Ratio: {ratio:.1f}% reduction")
+ print('\nXML/SOAP Compression:')
+ print(f' Original: {uncompressed_size} bytes')
+ print(f' Compressed: {compressed_size} bytes')
+ print(f' Ratio: {ratio:.1f}% reduction')
# XML should compress very well (lots of repetitive tags)
- assert ratio > 60, f"XML should compress >60%, got {ratio:.1f}%"
+ assert ratio > 60, f'XML should compress >60%, got {ratio:.1f}%'
@pytest.mark.asyncio
@@ -111,7 +114,7 @@ async def test_compression_with_graphql_style_response(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -132,10 +135,10 @@ async def test_compression_with_graphql_style_response(client):
compressed_size = len(compressed_buffer.getvalue())
ratio = (1 - (compressed_size / uncompressed_size)) * 100
- print(f"\nGraphQL-style Response Compression:")
- print(f" Original: {uncompressed_size} bytes")
- print(f" Compressed: {compressed_size} bytes")
- print(f" Ratio: {ratio:.1f}% reduction")
+ print('\nGraphQL-style Response Compression:')
+ print(f' Original: {uncompressed_size} bytes')
+ print(f' Compressed: {compressed_size} bytes')
+ print(f' Ratio: {ratio:.1f}% reduction')
@pytest.mark.asyncio
@@ -150,24 +153,20 @@ async def test_compression_with_post_requests(client):
'api_allowed_groups': ['ALL'],
'api_servers': ['http://example.com'],
'api_type': 'REST',
- 'active': True
+ 'active': True,
}
# Authenticate first
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
assert r_auth.status_code == 200
# POST with compression
- r = await client.post(
- '/platform/api',
- json=api_payload,
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.post('/platform/api', json=api_payload, headers={'Accept-Encoding': 'gzip'})
assert r.status_code in (200, 201)
# Should get valid response
@@ -176,7 +175,7 @@ async def test_compression_with_post_requests(client):
# Cleanup
try:
- await client.delete(f"/platform/api/{api_payload['api_name']}/v1")
+ await client.delete(f'/platform/api/{api_payload["api_name"]}/v1')
except Exception:
pass
@@ -187,7 +186,7 @@ async def test_compression_with_put_requests(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -203,21 +202,17 @@ async def test_compression_with_put_requests(client):
'api_allowed_groups': ['ALL'],
'api_servers': ['http://example.com'],
'api_type': 'REST',
- 'active': True
+ 'active': True,
}
r = await client.post('/platform/api', json=api_payload)
assert r.status_code in (200, 201)
# Update with PUT
- update_payload = {
- 'api_description': 'Updated description with compression test'
- }
+ update_payload = {'api_description': 'Updated description with compression test'}
r = await client.put(
- f'/platform/api/{api_name}/v1',
- json=update_payload,
- headers={'Accept-Encoding': 'gzip'}
+ f'/platform/api/{api_name}/v1', json=update_payload, headers={'Accept-Encoding': 'gzip'}
)
assert r.status_code == 200
@@ -238,7 +233,7 @@ async def test_compression_with_delete_requests(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -254,17 +249,14 @@ async def test_compression_with_delete_requests(client):
'api_allowed_groups': ['ALL'],
'api_servers': ['http://example.com'],
'api_type': 'REST',
- 'active': True
+ 'active': True,
}
r = await client.post('/platform/api', json=api_payload)
assert r.status_code in (200, 201)
# Delete with compression
- r = await client.delete(
- f'/platform/api/{api_name}/v1',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.delete(f'/platform/api/{api_name}/v1', headers={'Accept-Encoding': 'gzip'})
assert r.status_code in (200, 204)
@@ -278,10 +270,7 @@ async def test_compression_with_error_responses(client):
pass
# Try to access protected endpoint without auth
- r = await client.get(
- '/platform/api',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/platform/api', headers={'Accept-Encoding': 'gzip'})
# Should be unauthorized
assert r.status_code in (401, 403)
@@ -296,17 +285,14 @@ async def test_compression_with_list_responses(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
assert r_auth.status_code == 200
# Get list of users
- r = await client.get(
- '/platform/user',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/platform/user', headers={'Accept-Encoding': 'gzip'})
assert r.status_code == 200
data = r.json()
@@ -322,7 +308,7 @@ async def test_compression_consistent_across_content_types(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -351,8 +337,8 @@ async def test_compression_consistent_across_content_types(client):
if ratios:
avg_ratio = sum(ratios) / len(ratios)
- print(f"\nAverage compression ratio across endpoints: {avg_ratio:.1f}%")
- print(f"Individual ratios: {[f'{r:.1f}%' for r in ratios]}")
+ print(f'\nAverage compression ratio across endpoints: {avg_ratio:.1f}%')
+ print(f'Individual ratios: {[f"{r:.1f}%" for r in ratios]}')
# Mark all tests
diff --git a/backend-services/tests/test_compression_cpu_impact.py b/backend-services/tests/test_compression_cpu_impact.py
index 8565325..4a3f241 100644
--- a/backend-services/tests/test_compression_cpu_impact.py
+++ b/backend-services/tests/test_compression_cpu_impact.py
@@ -8,12 +8,12 @@ Tests realistic scenarios:
4. Memory allocation patterns during compression
"""
-import pytest
import gzip
-import json
import io
+import json
import time
-import os
+
+import pytest
def create_realistic_response(size_category):
@@ -23,7 +23,7 @@ def create_realistic_response(size_category):
return {
'status': 'success',
'data': {'id': 123, 'name': 'John Doe'},
- 'timestamp': '2025-01-18T10:30:00Z'
+ 'timestamp': '2025-01-18T10:30:00Z',
}
elif size_category == 'medium':
# Typical REST API response (1-10 KB)
@@ -35,11 +35,11 @@ def create_realistic_response(size_category):
'name': f'User {i}',
'email': f'user{i}@example.com',
'role': 'developer',
- 'created_at': '2025-01-15T10:00:00Z'
+ 'created_at': '2025-01-15T10:00:00Z',
}
for i in range(50)
],
- 'pagination': {'page': 1, 'total': 500}
+ 'pagination': {'page': 1, 'total': 500},
}
elif size_category == 'large':
# Large API list (10-50 KB)
@@ -58,12 +58,12 @@ def create_realistic_response(size_category):
'updated_at': '2025-01-18T15:30:00Z',
'views': 1234,
'likes': 567,
- 'reviews_count': 42
- }
+ 'reviews_count': 42,
+ },
}
for i in range(200)
],
- 'pagination': {'page': 1, 'per_page': 200, 'total': 2000}
+ 'pagination': {'page': 1, 'per_page': 200, 'total': 2000},
}
else: # very_large
# Very large response (50-100 KB)
@@ -74,10 +74,10 @@ def create_realistic_response(size_category):
'id': i,
'name': f'Item {i}',
'description': 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. ' * 5,
- 'attributes': {f'attr_{j}': f'value_{j}' for j in range(20)}
+ 'attributes': {f'attr_{j}': f'value_{j}' for j in range(20)},
}
for i in range(500)
- ]
+ ],
}
@@ -111,7 +111,7 @@ def benchmark_compression(data, level, iterations=1000):
'compressed_size': avg_compressed_size,
'compression_ratio': compression_ratio,
'avg_time_ms': avg_time_ms,
- 'throughput_per_core': 1000 / avg_time_ms if avg_time_ms > 0 else 0
+ 'throughput_per_core': 1000 / avg_time_ms if avg_time_ms > 0 else 0,
}
@@ -119,53 +119,59 @@ def benchmark_compression(data, level, iterations=1000):
async def test_cpu_impact_by_response_size():
"""Measure CPU impact across different response sizes"""
- print(f"\n{'='*80}")
- print(f"CPU IMPACT ANALYSIS - BY RESPONSE SIZE")
- print(f"{'='*80}\n")
+ print(f'\n{"=" * 80}')
+ print('CPU IMPACT ANALYSIS - BY RESPONSE SIZE')
+ print(f'{"=" * 80}\n')
sizes = ['small', 'medium', 'large', 'very_large']
levels = [1, 4, 6, 9]
for size in sizes:
data = create_realistic_response(size)
- print(f"\n{size.upper()} Response:")
- print(f"{'-'*80}")
+ print(f'\n{size.upper()} Response:')
+ print(f'{"-" * 80}')
for level in levels:
result = benchmark_compression(data, level, iterations=500)
- print(f"\n Level {level}:")
- print(f" Size: {result['uncompressed_size']:,} → {result['compressed_size']:,} bytes")
- print(f" Compression ratio: {result['compression_ratio']:.1f}%")
- print(f" CPU time: {result['avg_time_ms']:.3f} ms/request")
- print(f" Max throughput: {result['throughput_per_core']:.0f} req/sec (single core)")
+ print(f'\n Level {level}:')
+ print(
+ f' Size: {result["uncompressed_size"]:,} → {result["compressed_size"]:,} bytes'
+ )
+ print(f' Compression ratio: {result["compression_ratio"]:.1f}%')
+ print(f' CPU time: {result["avg_time_ms"]:.3f} ms/request')
+ print(
+ f' Max throughput: {result["throughput_per_core"]:.0f} req/sec (single core)'
+ )
@pytest.mark.asyncio
async def test_cpu_overhead_on_total_request_time():
"""Calculate compression overhead as % of total request time"""
- print(f"\n{'='*80}")
- print(f"COMPRESSION OVERHEAD AS % OF TOTAL REQUEST TIME")
- print(f"{'='*80}\n")
+ print(f'\n{"=" * 80}')
+ print('COMPRESSION OVERHEAD AS % OF TOTAL REQUEST TIME')
+ print(f'{"=" * 80}\n')
# Realistic request times for different operations
base_times = {
- 'health_check': 2, # Very fast
- 'auth': 50, # JWT generation/verification
- 'simple_query': 30, # Database lookup
- 'list_query': 80, # Multiple DB queries
- 'complex_query': 150, # Joins, aggregations
+ 'health_check': 2, # Very fast
+ 'auth': 50, # JWT generation/verification
+ 'simple_query': 30, # Database lookup
+ 'list_query': 80, # Multiple DB queries
+ 'complex_query': 150, # Joins, aggregations
'upstream_proxy': 200, # Proxying to upstream API
}
data = create_realistic_response('medium')
- print(f"Medium response size: {len(json.dumps(data, separators=(',', ':')).encode('utf-8')):,} bytes\n")
+ print(
+ f'Medium response size: {len(json.dumps(data, separators=(",", ":")).encode("utf-8")):,} bytes\n'
+ )
for operation, base_time_ms in base_times.items():
- print(f"\n{operation.replace('_', ' ').title()} (base: {base_time_ms}ms):")
- print(f"{'-'*80}")
+ print(f'\n{operation.replace("_", " ").title()} (base: {base_time_ms}ms):')
+ print(f'{"-" * 80}')
for level in [1, 4, 6, 9]:
result = benchmark_compression(data, level, iterations=500)
@@ -174,37 +180,39 @@ async def test_cpu_overhead_on_total_request_time():
overhead_pct = (compression_time / total_time) * 100
throughput_reduction = (compression_time / base_time_ms) * 100
- print(f" Level {level}: {compression_time:.3f}ms → "
- f"total {total_time:.1f}ms "
- f"({overhead_pct:.1f}% overhead, "
- f"{throughput_reduction:.1f}% slower)")
+ print(
+ f' Level {level}: {compression_time:.3f}ms → '
+ f'total {total_time:.1f}ms '
+ f'({overhead_pct:.1f}% overhead, '
+ f'{throughput_reduction:.1f}% slower)'
+ )
@pytest.mark.asyncio
async def test_realistic_production_scenario():
"""Simulate production workload with compression"""
- print(f"\n{'='*80}")
- print(f"REALISTIC PRODUCTION SCENARIO")
- print(f"{'='*80}\n")
+ print(f'\n{"=" * 80}')
+ print('REALISTIC PRODUCTION SCENARIO')
+ print(f'{"=" * 80}\n')
# Realistic traffic mix
workload = [
- ('small', 0.30, 10), # 30% small responses (health checks, simple GETs)
- ('medium', 0.50, 40), # 50% medium responses (typical API calls)
- ('large', 0.15, 100), # 15% large responses (list endpoints)
- ('very_large', 0.05, 200) # 5% very large (export/reports)
+ ('small', 0.30, 10), # 30% small responses (health checks, simple GETs)
+ ('medium', 0.50, 40), # 50% medium responses (typical API calls)
+ ('large', 0.15, 100), # 15% large responses (list endpoints)
+ ('very_large', 0.05, 200), # 5% very large (export/reports)
]
- print("Traffic Mix:")
+ print('Traffic Mix:')
for size, percentage, base_time in workload:
- print(f" {size:12s}: {percentage*100:>5.1f}% (base processing: {base_time}ms)")
+ print(f' {size:12s}: {percentage * 100:>5.1f}% (base processing: {base_time}ms)')
- print(f"\n{'='*80}")
+ print(f'\n{"=" * 80}')
for level in [1, 4, 6, 9]:
- print(f"\nCompression Level {level}:")
- print(f"{'-'*80}")
+ print(f'\nCompression Level {level}:')
+ print(f'{"-" * 80}')
total_time_without_compression = 0
total_time_with_compression = 0
@@ -230,43 +238,49 @@ async def test_realistic_production_scenario():
total_time_with_compression += weighted_base_time + weighted_compression_time
total_bytes_saved += weighted_bytes_saved
- overhead_pct = ((total_time_with_compression - total_time_without_compression) /
- total_time_without_compression) * 100
+ overhead_pct = (
+ (total_time_with_compression - total_time_without_compression)
+ / total_time_without_compression
+ ) * 100
# Calculate max RPS reduction
rps_without = 1000 / total_time_without_compression
rps_with = 1000 / total_time_with_compression
rps_reduction_pct = ((rps_without - rps_with) / rps_without) * 100
- print(f" Avg request time: {total_time_without_compression:.1f}ms → {total_time_with_compression:.1f}ms")
- print(f" CPU overhead: {overhead_pct:.1f}%")
- print(f" Max RPS (1 core): {rps_without:.1f} → {rps_with:.1f} ({rps_reduction_pct:.1f}% reduction)")
- print(f" Avg bytes saved: {total_bytes_saved:.0f} bytes/request")
+ print(
+ f' Avg request time: {total_time_without_compression:.1f}ms → {total_time_with_compression:.1f}ms'
+ )
+ print(f' CPU overhead: {overhead_pct:.1f}%')
+ print(
+ f' Max RPS (1 core): {rps_without:.1f} → {rps_with:.1f} ({rps_reduction_pct:.1f}% reduction)'
+ )
+ print(f' Avg bytes saved: {total_bytes_saved:.0f} bytes/request')
@pytest.mark.asyncio
async def test_two_vcpu_capacity_analysis():
"""Calculate realistic capacity for 2 vCPU instance"""
- print(f"\n{'='*80}")
- print(f"2 vCPU AWS LIGHTSAIL CAPACITY ANALYSIS")
- print(f"{'='*80}\n")
+ print(f'\n{"=" * 80}')
+ print('2 vCPU AWS LIGHTSAIL CAPACITY ANALYSIS')
+ print(f'{"=" * 80}\n')
# Workload parameters
workload = [
- ('small', 0.30, 10, False), # Not compressed
- ('medium', 0.50, 40, True), # Compressed
- ('large', 0.15, 100, True), # Compressed
- ('very_large', 0.05, 200, True) # Compressed
+ ('small', 0.30, 10, False), # Not compressed
+ ('medium', 0.50, 40, True), # Compressed
+ ('large', 0.15, 100, True), # Compressed
+ ('very_large', 0.05, 200, True), # Compressed
]
- print("AWS Lightsail 1GB RAM, 2 vCPUs")
- print("Single worker mode (MEM_OR_EXTERNAL=MEM)")
- print(f"\n{'='*80}\n")
+ print('AWS Lightsail 1GB RAM, 2 vCPUs')
+ print('Single worker mode (MEM_OR_EXTERNAL=MEM)')
+ print(f'\n{"=" * 80}\n')
for level in [1, 4, 6, 9]:
- print(f"Compression Level {level}:")
- print(f"{'-'*80}")
+ print(f'Compression Level {level}:')
+ print(f'{"-" * 80}')
total_cpu_time = 0
total_bytes_uncompressed = 0
@@ -302,24 +316,28 @@ async def test_two_vcpu_capacity_analysis():
realistic_rps = max_rps_two_cores * 1.3 # Async bonus
# Transfer limits
- avg_transfer_per_request = (total_bytes_uncompressed + total_bytes_compressed) / 2 # In + Out
+ avg_transfer_per_request = (
+ total_bytes_uncompressed + total_bytes_compressed
+ ) / 2 # In + Out
monthly_requests_at_1tb = (1000 * 1024 * 1024 * 1024) / avg_transfer_per_request
monthly_rps_limit = monthly_requests_at_1tb / (30 * 24 * 60 * 60)
compression_ratio = (1 - total_bytes_compressed / total_bytes_uncompressed) * 100
- print(f" CPU time per request: {total_cpu_time:.1f} ms")
- print(f" Max RPS (CPU-limited): {realistic_rps:.1f} RPS")
- print(f" Avg response size: {total_bytes_uncompressed:.0f} → {total_bytes_compressed:.0f} bytes")
- print(f" Compression ratio: {compression_ratio:.1f}%")
- print(f" Transfer per request: {avg_transfer_per_request:.0f} bytes (req+resp)")
- print(f" Max RPS (transfer-limited): {monthly_rps_limit:.1f} RPS")
- print(f" Monthly capacity (1TB): {monthly_requests_at_1tb/1_000_000:.1f}M requests")
+ print(f' CPU time per request: {total_cpu_time:.1f} ms')
+ print(f' Max RPS (CPU-limited): {realistic_rps:.1f} RPS')
+ print(
+ f' Avg response size: {total_bytes_uncompressed:.0f} → {total_bytes_compressed:.0f} bytes'
+ )
+ print(f' Compression ratio: {compression_ratio:.1f}%')
+ print(f' Transfer per request: {avg_transfer_per_request:.0f} bytes (req+resp)')
+ print(f' Max RPS (transfer-limited): {monthly_rps_limit:.1f} RPS')
+ print(f' Monthly capacity (1TB): {monthly_requests_at_1tb / 1_000_000:.1f}M requests')
if monthly_rps_limit < realistic_rps:
- print(f" ⚠️ BOTTLENECK: Transfer (CPU can handle {realistic_rps:.1f} RPS)")
+ print(f' ⚠️ BOTTLENECK: Transfer (CPU can handle {realistic_rps:.1f} RPS)')
else:
- print(f" ⚠️ BOTTLENECK: CPU (transfer allows {monthly_rps_limit:.1f} RPS)")
+ print(f' ⚠️ BOTTLENECK: CPU (transfer allows {monthly_rps_limit:.1f} RPS)')
print()
@@ -327,17 +345,17 @@ async def test_two_vcpu_capacity_analysis():
async def test_recommended_production_level():
"""Determine optimal compression level for production"""
- print(f"\n{'='*80}")
- print(f"PRODUCTION COMPRESSION LEVEL RECOMMENDATION")
- print(f"{'='*80}\n")
+ print(f'\n{"=" * 80}')
+ print('PRODUCTION COMPRESSION LEVEL RECOMMENDATION')
+ print(f'{"=" * 80}\n')
data = create_realistic_response('medium')
- print("Criteria:")
- print(" 1. Maximize bandwidth savings")
- print(" 2. Minimize CPU overhead")
- print(" 3. Balance throughput vs. transfer")
- print(f"\n{'='*80}\n")
+ print('Criteria:')
+ print(' 1. Maximize bandwidth savings')
+ print(' 2. Minimize CPU overhead')
+ print(' 3. Balance throughput vs. transfer')
+ print(f'\n{"=" * 80}\n')
recommendations = []
@@ -348,26 +366,28 @@ async def test_recommended_production_level():
# Higher compression ratio is good, lower CPU time is good
efficiency = result['compression_ratio'] / result['avg_time_ms']
- recommendations.append({
- 'level': level,
- 'ratio': result['compression_ratio'],
- 'time': result['avg_time_ms'],
- 'efficiency': efficiency
- })
+ recommendations.append(
+ {
+ 'level': level,
+ 'ratio': result['compression_ratio'],
+ 'time': result['avg_time_ms'],
+ 'efficiency': efficiency,
+ }
+ )
- print(f"Level {level}:")
- print(f" Compression ratio: {result['compression_ratio']:.1f}%")
- print(f" CPU time: {result['avg_time_ms']:.3f} ms")
- print(f" Efficiency score: {efficiency:.1f}")
+ print(f'Level {level}:')
+ print(f' Compression ratio: {result["compression_ratio"]:.1f}%')
+ print(f' CPU time: {result["avg_time_ms"]:.3f} ms')
+ print(f' Efficiency score: {efficiency:.1f}')
print()
# Find best efficiency
best = max(recommendations, key=lambda x: x['efficiency'])
- print(f"{'='*80}")
- print(f"RECOMMENDATION: Level {best['level']}")
- print(f" Best efficiency score: {best['efficiency']:.1f}")
- print(f" Compression: {best['ratio']:.1f}%")
- print(f" CPU cost: {best['time']:.3f} ms")
+ print(f'{"=" * 80}')
+ print(f'RECOMMENDATION: Level {best["level"]}')
+ print(f' Best efficiency score: {best["efficiency"]:.1f}')
+ print(f' Compression: {best["ratio"]:.1f}%')
+ print(f' CPU cost: {best["time"]:.3f} ms')
pytestmark = [pytest.mark.benchmark, pytest.mark.cpu]
diff --git a/backend-services/tests/test_compression_size_reduction.py b/backend-services/tests/test_compression_size_reduction.py
index f70dd55..8f416ca 100644
--- a/backend-services/tests/test_compression_size_reduction.py
+++ b/backend-services/tests/test_compression_size_reduction.py
@@ -7,11 +7,12 @@ These tests verify that:
3. Compression settings affect the compression ratio
"""
-import pytest
import gzip
+import io
import json
import os
-import io
+
+import pytest
@pytest.mark.asyncio
@@ -20,7 +21,7 @@ async def test_json_compression_ratio(client):
# Authenticate to get access to endpoints
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -44,15 +45,15 @@ async def test_json_compression_ratio(client):
# Calculate compression ratio
compression_ratio = (1 - (compressed_size / uncompressed_size)) * 100
- print(f"\nJSON Compression Stats:")
- print(f" Uncompressed: {uncompressed_size} bytes")
- print(f" Compressed: {compressed_size} bytes")
- print(f" Ratio: {compression_ratio:.1f}% reduction")
+ print('\nJSON Compression Stats:')
+ print(f' Uncompressed: {uncompressed_size} bytes')
+ print(f' Compressed: {compressed_size} bytes')
+ print(f' Ratio: {compression_ratio:.1f}% reduction')
# JSON should compress well (typically 60-80%)
# But only if response is large enough
if uncompressed_size > 500:
- assert compression_ratio > 30, f"Expected >30% compression, got {compression_ratio:.1f}%"
+ assert compression_ratio > 30, f'Expected >30% compression, got {compression_ratio:.1f}%'
@pytest.mark.asyncio
@@ -61,7 +62,7 @@ async def test_large_list_compression(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -78,11 +79,11 @@ async def test_large_list_compression(client):
'api_allowed_groups': ['ALL'],
'api_servers': ['http://example.com'],
'api_type': 'REST',
- 'active': True
+ 'active': True,
}
r = await client.post('/platform/api', json=api_payload)
if r.status_code in (200, 201):
- test_apis.append(f"compression-test-{i}")
+ test_apis.append(f'compression-test-{i}')
# Get the full list
r = await client.get('/platform/api')
@@ -100,10 +101,10 @@ async def test_large_list_compression(client):
compression_ratio = (1 - (compressed_size / uncompressed_size)) * 100
- print(f"\nLarge List Compression Stats:")
- print(f" Uncompressed: {uncompressed_size} bytes")
- print(f" Compressed: {compressed_size} bytes")
- print(f" Ratio: {compression_ratio:.1f}% reduction")
+ print('\nLarge List Compression Stats:')
+ print(f' Uncompressed: {uncompressed_size} bytes')
+ print(f' Compressed: {compressed_size} bytes')
+ print(f' Ratio: {compression_ratio:.1f}% reduction')
# Cleanup
for api_name in test_apis:
@@ -114,7 +115,7 @@ async def test_large_list_compression(client):
# Should achieve good compression on repeated data
if uncompressed_size > 1000:
- assert compression_ratio > 40, f"Expected >40% compression for large list"
+ assert compression_ratio > 40, 'Expected >40% compression for large list'
@pytest.mark.asyncio
@@ -123,19 +124,14 @@ async def test_compression_bandwidth_savings(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
assert r_auth.status_code == 200
# Test multiple endpoints
- endpoints = [
- '/platform/api',
- '/platform/user',
- '/platform/role',
- '/platform/group',
- ]
+ endpoints = ['/platform/api', '/platform/user', '/platform/role', '/platform/group']
total_uncompressed = 0
total_compressed = 0
@@ -161,14 +157,14 @@ async def test_compression_bandwidth_savings(client):
if total_uncompressed > 0:
overall_ratio = (1 - (total_compressed / total_uncompressed)) * 100
- print(f"\nOverall Bandwidth Savings:")
- print(f" Total uncompressed: {total_uncompressed} bytes")
- print(f" Total compressed: {total_compressed} bytes")
- print(f" Overall ratio: {overall_ratio:.1f}% reduction")
- print(f" Bandwidth saved: {total_uncompressed - total_compressed} bytes")
+ print('\nOverall Bandwidth Savings:')
+ print(f' Total uncompressed: {total_uncompressed} bytes')
+ print(f' Total compressed: {total_compressed} bytes')
+ print(f' Overall ratio: {overall_ratio:.1f}% reduction')
+ print(f' Bandwidth saved: {total_uncompressed - total_compressed} bytes')
# Should see significant savings
- assert overall_ratio > 0, "Compression should reduce size"
+ assert overall_ratio > 0, 'Compression should reduce size'
@pytest.mark.asyncio
@@ -177,7 +173,7 @@ async def test_compression_level_affects_ratio(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -199,15 +195,12 @@ async def test_compression_level_affects_ratio(client):
gz.write(json_str.encode('utf-8'))
compressed_size = len(compressed_buffer.getvalue())
ratio = (1 - (compressed_size / uncompressed_size)) * 100
- compression_results[level] = {
- 'size': compressed_size,
- 'ratio': ratio
- }
+ compression_results[level] = {'size': compressed_size, 'ratio': ratio}
- print(f"\nCompression Level Comparison:")
- print(f" Uncompressed: {uncompressed_size} bytes")
+ print('\nCompression Level Comparison:')
+ print(f' Uncompressed: {uncompressed_size} bytes')
for level, result in compression_results.items():
- print(f" Level {level}: {result['size']} bytes ({result['ratio']:.1f}% reduction)")
+ print(f' Level {level}: {result["size"]} bytes ({result["ratio"]:.1f}% reduction)')
# Higher compression levels should achieve better (or equal) compression
# Level 9 should be <= Level 6 <= Level 1 in size
@@ -226,7 +219,7 @@ async def test_minimum_size_threshold(client):
response_content = r.content
response_size = len(response_content)
- print(f"\nSmall Response Size: {response_size} bytes")
+ print(f'\nSmall Response Size: {response_size} bytes')
# If response is smaller than 500 bytes (default minimum_size),
# compressing it may not be worth the CPU overhead
@@ -239,7 +232,7 @@ async def test_compression_transfer_savings_calculation(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -269,12 +262,12 @@ async def test_compression_transfer_savings_calculation(client):
monthly_savings_mb = monthly_savings_bytes / (1024 * 1024)
monthly_savings_gb = monthly_savings_mb / 1024
- print(f"\nTransfer Savings Estimate:")
- print(f" Compression ratio: {compression_ratio:.1f}%")
- print(f" Bytes saved per request: {bytes_saved_per_request}")
- print(f" Monthly requests: {requests_per_month:,}")
- print(f" Monthly bandwidth saved: {monthly_savings_gb:.2f} GB")
- print(f" Annual bandwidth saved: {monthly_savings_gb * 12:.2f} GB")
+ print('\nTransfer Savings Estimate:')
+ print(f' Compression ratio: {compression_ratio:.1f}%')
+ print(f' Bytes saved per request: {bytes_saved_per_request}')
+ print(f' Monthly requests: {requests_per_month:,}')
+ print(f' Monthly bandwidth saved: {monthly_savings_gb:.2f} GB')
+ print(f' Annual bandwidth saved: {monthly_savings_gb * 12:.2f} GB')
# Should save significant bandwidth
assert bytes_saved_per_request >= 0
diff --git a/backend-services/tests/test_config_import_export_extended.py b/backend-services/tests/test_config_import_export_extended.py
index 691d253..2de71e1 100644
--- a/backend-services/tests/test_config_import_export_extended.py
+++ b/backend-services/tests/test_config_import_export_extended.py
@@ -1,6 +1,6 @@
-import json
import pytest
+
@pytest.mark.asyncio
async def test_export_all_basic(authed_client):
r = await authed_client.get('/platform/config/export/all')
@@ -12,79 +12,132 @@ async def test_export_all_basic(authed_client):
assert isinstance(data.get('routings'), list)
assert isinstance(data.get('endpoints'), list)
+
@pytest.mark.asyncio
-@pytest.mark.parametrize('path', [
- '/platform/config/export/apis',
- '/platform/config/export/roles',
- '/platform/config/export/groups',
- '/platform/config/export/routings',
- '/platform/config/export/endpoints',
-])
+@pytest.mark.parametrize(
+ 'path',
+ [
+ '/platform/config/export/apis',
+ '/platform/config/export/roles',
+ '/platform/config/export/groups',
+ '/platform/config/export/routings',
+ '/platform/config/export/endpoints',
+ ],
+)
async def test_export_lists(authed_client, path):
r = await authed_client.get(path)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_export_single_api_with_endpoints(authed_client):
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _create_endpoint(c, n, v, m, u):
- payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'endpoint_method': m,
+ 'endpoint_uri': u,
+ 'endpoint_description': f'{m} {u}',
+ }
rr = await c.post('/platform/endpoint', json=payload)
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'exapi', 'v1')
await _create_endpoint(authed_client, 'exapi', 'v1', 'GET', '/status')
- r = await authed_client.get('/platform/config/export/apis', params={'api_name': 'exapi', 'api_version': 'v1'})
+ r = await authed_client.get(
+ '/platform/config/export/apis', params={'api_name': 'exapi', 'api_version': 'v1'}
+ )
assert r.status_code == 200
payload = r.json().get('response') or r.json()
assert payload.get('api') and payload.get('endpoints') is not None
+
@pytest.mark.asyncio
async def test_export_endpoints_filter(authed_client):
async def _create_api(c, n, v):
- payload = {'api_name': n, 'api_version': v, 'api_description': f'{n} {v}', 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://upstream'], 'api_type': 'REST', 'api_allowed_retry_count': 0}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'api_description': f'{n} {v}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://upstream'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ }
rr = await c.post('/platform/api', json=payload)
assert rr.status_code in (200, 201)
+
async def _create_endpoint(c, n, v, m, u):
- payload = {'api_name': n, 'api_version': v, 'endpoint_method': m, 'endpoint_uri': u, 'endpoint_description': f'{m} {u}'}
+ payload = {
+ 'api_name': n,
+ 'api_version': v,
+ 'endpoint_method': m,
+ 'endpoint_uri': u,
+ 'endpoint_description': f'{m} {u}',
+ }
rr = await c.post('/platform/endpoint', json=payload)
assert rr.status_code in (200, 201)
+
await _create_api(authed_client, 'filterapi', 'v1')
await _create_endpoint(authed_client, 'filterapi', 'v1', 'GET', '/x')
- r = await authed_client.get('/platform/config/export/endpoints', params={'api_name': 'filterapi', 'api_version': 'v1'})
+ r = await authed_client.get(
+ '/platform/config/export/endpoints', params={'api_name': 'filterapi', 'api_version': 'v1'}
+ )
assert r.status_code == 200
eps = (r.json().get('response') or r.json()).get('endpoints')
assert isinstance(eps, list) and len(eps) >= 1
+
@pytest.mark.asyncio
-@pytest.mark.parametrize('sections', [
- {'apis': []},
- {'roles': []},
- {'groups': []},
- {'routings': []},
- {'endpoints': []},
- {'apis': [], 'endpoints': []},
- {'roles': [], 'groups': []},
-])
+@pytest.mark.parametrize(
+ 'sections',
+ [
+ {'apis': []},
+ {'roles': []},
+ {'groups': []},
+ {'routings': []},
+ {'endpoints': []},
+ {'apis': [], 'endpoints': []},
+ {'roles': [], 'groups': []},
+ ],
+)
async def test_import_various_sections(authed_client, sections):
r = await authed_client.post('/platform/config/import', json=sections)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_security_restart_pid_missing(authed_client):
r = await authed_client.post('/platform/security/restart')
assert r.status_code in (202, 409, 403)
+
@pytest.mark.asyncio
async def test_audit_called_on_export(monkeypatch, authed_client):
calls = []
import utils.audit_util as au
+
orig = au._logger
+
class _L:
def info(self, msg):
calls.append(msg)
+
au._logger = _L()
try:
r = await authed_client.get('/platform/config/export/all')
diff --git a/backend-services/tests/test_config_import_tolerates_malformed.py b/backend-services/tests/test_config_import_tolerates_malformed.py
new file mode 100644
index 0000000..1e1ff86
--- /dev/null
+++ b/backend-services/tests/test_config_import_tolerates_malformed.py
@@ -0,0 +1,31 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_config_import_ignores_malformed_entries(authed_client):
+ body = {
+ 'apis': [
+ {'api_name': 'x-only'}, # missing version; ignored
+ {'api_version': 'v1'}, # missing name; ignored
+ ],
+ 'endpoints': [
+ {'api_name': 'x', 'endpoint_method': 'GET'} # missing api_version/uri
+ ],
+ 'roles': [
+ {'bad': 'doc'} # missing role_name
+ ],
+ 'groups': [
+ {'bad': 'doc'} # missing group_name
+ ],
+ 'routings': [
+ {'bad': 'doc'} # missing client_key
+ ],
+ }
+ r = await authed_client.post('/platform/config/import', json=body)
+ assert r.status_code == 200
+    # Response parsed for debugging only; 'imported' counts are not asserted because
+    # the import endpoint reports items processed, not items actually upserted.
+ # Import reports how many items were processed, not how many were actually upserted.
+ # Verify no valid API was created as a result of malformed entries.
+ bad_get = await authed_client.get('/platform/api/x-only/v1')
+ assert bad_get.status_code in (400, 404)
diff --git a/backend-services/tests/test_config_permission_specific.py b/backend-services/tests/test_config_permission_specific.py
new file mode 100644
index 0000000..f441e0d
--- /dev/null
+++ b/backend-services/tests/test_config_permission_specific.py
@@ -0,0 +1,52 @@
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200, r.text
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_export_apis_allowed_but_roles_forbidden(authed_client):
+ uname = 'mgr_apis_only'
+ pwd = 'PerM1ssionsMore!!'
+ # Create role granting manage_apis only
+ role_name = 'apis_manager'
+ cr = await authed_client.post(
+ '/platform/role',
+ json={'role_name': role_name, 'role_description': 'Can export APIs', 'manage_apis': True},
+ )
+ assert cr.status_code in (200, 201), cr.text
+
+ # Create user assigned to that role
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': 'mgr_apis@example.com',
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+
+ client = await _login('mgr_apis@example.com', pwd)
+
+ # Allowed
+ ok = await client.get('/platform/config/export/apis')
+ assert ok.status_code == 200
+
+ # Forbidden
+ no = await client.get('/platform/config/export/roles')
+ assert no.status_code == 403
diff --git a/backend-services/tests/test_config_permission_specific_more.py b/backend-services/tests/test_config_permission_specific_more.py
new file mode 100644
index 0000000..b3bffa1
--- /dev/null
+++ b/backend-services/tests/test_config_permission_specific_more.py
@@ -0,0 +1,137 @@
+import time
+
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200, r.text
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_export_roles_allowed_manage_roles_only(authed_client):
+ uname = f'roles_mgr_{int(time.time())}'
+ role_name = f'roles_manager_{int(time.time())}'
+ pwd = 'ManAgeRoleSStrong1!!'
+
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': role_name, 'manage_roles': True}
+ )
+ assert cr.status_code in (200, 201)
+
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201), cu.text
+
+ client = await _login(f'{uname}@example.com', pwd)
+ # Allowed
+ ok = await client.get('/platform/config/export/roles')
+ assert ok.status_code == 200
+ # Forbidden
+ no = await client.get('/platform/config/export/apis')
+ assert no.status_code == 403
+
+
+@pytest.mark.asyncio
+async def test_export_groups_allowed_manage_groups_only(authed_client):
+ uname = f'groups_mgr_{int(time.time())}'
+ role_name = f'groups_manager_{int(time.time())}'
+ pwd = 'ManAgeGroupSStrong1!!'
+
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': role_name, 'manage_groups': True}
+ )
+ assert cr.status_code in (200, 201)
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ client = await _login(f'{uname}@example.com', pwd)
+ ok = await client.get('/platform/config/export/groups')
+ assert ok.status_code == 200
+ no = await client.get('/platform/config/export/roles')
+ assert no.status_code == 403
+
+
+@pytest.mark.asyncio
+async def test_export_routings_allowed_manage_routings_only(authed_client):
+ uname = f'rout_mgr_{int(time.time())}'
+ role_name = f'rout_manager_{int(time.time())}'
+ pwd = 'ManAgeRoutIngS1!!'
+
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': role_name, 'manage_routings': True}
+ )
+ assert cr.status_code in (200, 201)
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ client = await _login(f'{uname}@example.com', pwd)
+ ok = await client.get('/platform/config/export/routings')
+ assert ok.status_code == 200
+ no = await client.get('/platform/config/export/endpoints')
+ assert no.status_code == 403
+
+
+@pytest.mark.asyncio
+async def test_export_all_and_import_allowed_manage_gateway(authed_client):
+ uname = f'gate_mgr_{int(time.time())}'
+ role_name = f'gate_manager_{int(time.time())}'
+ pwd = 'ManAgeGateWayStrong1!!'
+
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': role_name, 'manage_gateway': True}
+ )
+ assert cr.status_code in (200, 201)
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': role_name,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ client = await _login(f'{uname}@example.com', pwd)
+ ok = await client.get('/platform/config/export/all')
+ assert ok.status_code == 200
+ imp = await client.post('/platform/config/import', json={'apis': []})
+ assert imp.status_code == 200
diff --git a/backend-services/tests/test_config_permissions_matrix.py b/backend-services/tests/test_config_permissions_matrix.py
new file mode 100644
index 0000000..dd09764
--- /dev/null
+++ b/backend-services/tests/test_config_permissions_matrix.py
@@ -0,0 +1,61 @@
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200, r.text
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_config_export_import_forbidden_for_non_admin(authed_client):
+ # Create a limited user
+ uname = 'limited_user'
+ pwd = 'limited-password-1A!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': 'limited@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ # No manage_* permissions
+ 'ui_access': True,
+ 'manage_users': False,
+ 'manage_apis': False,
+ 'manage_endpoints': False,
+ 'manage_groups': False,
+ 'manage_roles': False,
+ 'manage_routings': False,
+ 'manage_gateway': False,
+ },
+ )
+ assert cu.status_code in (200, 201), cu.text
+
+ user_client = await _login('limited@example.com', pwd)
+
+ # Export endpoints require manage_* permissions – assert 403s
+ endpoints = [
+ '/platform/config/export/all',
+ '/platform/config/export/apis',
+ '/platform/config/export/roles',
+ '/platform/config/export/groups',
+ '/platform/config/export/routings',
+ '/platform/config/export/endpoints',
+ ]
+ for ep in endpoints:
+ r = await user_client.get(ep)
+ assert r.status_code == 403
+
+ # Import also requires manage_gateway
+ imp = await user_client.post('/platform/config/import', json={'apis': []})
+ assert imp.status_code == 403
diff --git a/backend-services/tests/test_cookie_domain.py b/backend-services/tests/test_cookie_domain.py
index d60ebc5..be36a60 100644
--- a/backend-services/tests/test_cookie_domain.py
+++ b/backend-services/tests/test_cookie_domain.py
@@ -1,9 +1,10 @@
import os
+
import pytest
+
@pytest.mark.asyncio
async def test_cookie_domain_set_on_login(client, monkeypatch):
-
monkeypatch.setenv('COOKIE_DOMAIN', 'testserver')
resp = await client.post(
'/platform/authorization',
diff --git a/backend-services/tests/test_cookies_policy.py b/backend-services/tests/test_cookies_policy.py
index 6a5cc68..e5d8a3b 100644
--- a/backend-services/tests/test_cookies_policy.py
+++ b/backend-services/tests/test_cookies_policy.py
@@ -1,7 +1,9 @@
import os
import re
+
import pytest
+
def _collect_set_cookie_headers(resp):
try:
return resp.headers.get_list('set-cookie')
@@ -9,64 +11,97 @@ def _collect_set_cookie_headers(resp):
raw = resp.headers.get('set-cookie') or ''
return [h.strip() for h in raw.split(',') if 'Expires=' not in h] or ([raw] if raw else [])
+
def _find_cookie_lines(lines, name):
name_l = name.lower() + '='
return [l for l in lines if name_l in l.lower()]
+
@pytest.mark.asyncio
async def test_default_samesite_strict_and_secure_false(monkeypatch, client):
monkeypatch.delenv('COOKIE_SAMESITE', raising=False)
monkeypatch.setenv('HTTPS_ONLY', 'false')
- r = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']})
+ r = await client.post(
+ '/platform/authorization',
+ json={
+ 'email': os.environ['DOORMAN_ADMIN_EMAIL'],
+ 'password': os.environ['DOORMAN_ADMIN_PASSWORD'],
+ },
+ )
assert r.status_code == 200
cookies = _collect_set_cookie_headers(r)
atk = _find_cookie_lines(cookies, 'access_token_cookie')
csrf = _find_cookie_lines(cookies, 'csrf_token')
assert atk and csrf
+
def has_attr(lines, pattern):
return any(re.search(pattern, l, flags=re.I) for l in lines)
- assert has_attr(atk, r"samesite\s*=\s*strict")
- assert has_attr(csrf, r"samesite\s*=\s*strict")
- assert not has_attr(atk, r";\s*secure(\s*;|$)")
- assert not has_attr(csrf, r";\s*secure(\s*;|$)")
+
+ assert has_attr(atk, r'samesite\s*=\s*strict')
+ assert has_attr(csrf, r'samesite\s*=\s*strict')
+ assert not has_attr(atk, r';\s*secure(\s*;|$)')
+ assert not has_attr(csrf, r';\s*secure(\s*;|$)')
+
@pytest.mark.asyncio
async def test_cookies_samesite_lax_override(monkeypatch, client):
monkeypatch.setenv('COOKIE_SAMESITE', 'Lax')
monkeypatch.setenv('HTTPS_ONLY', 'false')
- r = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']})
+ r = await client.post(
+ '/platform/authorization',
+ json={
+ 'email': os.environ['DOORMAN_ADMIN_EMAIL'],
+ 'password': os.environ['DOORMAN_ADMIN_PASSWORD'],
+ },
+ )
assert r.status_code == 200
cookies = _collect_set_cookie_headers(r)
atk = _find_cookie_lines(cookies, 'access_token_cookie')
csrf = _find_cookie_lines(cookies, 'csrf_token')
assert atk and csrf
+
def has_attr(lines, pattern):
return any(re.search(pattern, l, flags=re.I) for l in lines)
- assert has_attr(atk, r"samesite\s*=\s*lax")
- assert has_attr(csrf, r"samesite\s*=\s*lax")
+
+ assert has_attr(atk, r'samesite\s*=\s*lax')
+ assert has_attr(csrf, r'samesite\s*=\s*lax')
+
@pytest.mark.asyncio
async def test_secure_flag_toggles_with_https(monkeypatch, client):
monkeypatch.setenv('COOKIE_SAMESITE', 'None')
monkeypatch.setenv('HTTPS_ONLY', 'false')
- r1 = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']})
+ r1 = await client.post(
+ '/platform/authorization',
+ json={
+ 'email': os.environ['DOORMAN_ADMIN_EMAIL'],
+ 'password': os.environ['DOORMAN_ADMIN_PASSWORD'],
+ },
+ )
assert r1.status_code == 200
cookies1 = _collect_set_cookie_headers(r1)
atk1 = _find_cookie_lines(cookies1, 'access_token_cookie')
csrf1 = _find_cookie_lines(cookies1, 'csrf_token')
+
def has_attr(lines, pattern):
return any(re.search(pattern, l, flags=re.I) for l in lines)
- assert not has_attr(atk1, r";\s*secure(\s*;|$)")
- assert not has_attr(csrf1, r";\s*secure(\s*;|$)")
+
+ assert not has_attr(atk1, r';\s*secure(\s*;|$)')
+ assert not has_attr(csrf1, r';\s*secure(\s*;|$)')
monkeypatch.setenv('HTTPS_ONLY', 'true')
- r2 = await client.post('/platform/authorization', json={'email': os.environ['DOORMAN_ADMIN_EMAIL'], 'password': os.environ['DOORMAN_ADMIN_PASSWORD']})
+ r2 = await client.post(
+ '/platform/authorization',
+ json={
+ 'email': os.environ['DOORMAN_ADMIN_EMAIL'],
+ 'password': os.environ['DOORMAN_ADMIN_PASSWORD'],
+ },
+ )
assert r2.status_code == 200
cookies2 = _collect_set_cookie_headers(r2)
atk2 = _find_cookie_lines(cookies2, 'access_token_cookie')
csrf2 = _find_cookie_lines(cookies2, 'csrf_token')
- assert has_attr(atk2, r";\s*secure(\s*;|$)")
- assert has_attr(csrf2, r";\s*secure(\s*;|$)")
-
+ assert has_attr(atk2, r';\s*secure(\s*;|$)')
+ assert has_attr(csrf2, r';\s*secure(\s*;|$)')
diff --git a/backend-services/tests/test_credit_definition_masking.py b/backend-services/tests/test_credit_definition_masking.py
new file mode 100644
index 0000000..5dad242
--- /dev/null
+++ b/backend-services/tests/test_credit_definition_masking.py
@@ -0,0 +1,35 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_credit_definition_masking(authed_client):
+ group = 'maskgroup'
+ create = await authed_client.post(
+ '/platform/credit',
+ json={
+ 'api_credit_group': group,
+ 'api_key': 'VERY-SECRET-KEY',
+ 'api_key_header': 'x-api-key',
+ 'credit_tiers': [
+ {
+ 'tier_name': 'default',
+ 'credits': 5,
+ 'input_limit': 0,
+ 'output_limit': 0,
+ 'reset_frequency': 'monthly',
+ }
+ ],
+ },
+ )
+ assert create.status_code in (200, 201), create.text
+
+ r = await authed_client.get(f'/platform/credit/defs/{group}')
+ assert r.status_code == 200, r.text
+ body = r.json().get('response', r.json())
+
+ # Masking rules
+ assert body.get('api_credit_group') == group
+ assert body.get('api_key_header') == 'x-api-key'
+ assert body.get('api_key_present') is True
+ # Under no circumstance should the API key material be returned
+ assert 'api_key' not in body
diff --git a/backend-services/tests/test_credit_key_rotation.py b/backend-services/tests/test_credit_key_rotation.py
new file mode 100644
index 0000000..881ef7e
--- /dev/null
+++ b/backend-services/tests/test_credit_key_rotation.py
@@ -0,0 +1,36 @@
+from datetime import UTC, datetime, timedelta
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_credit_key_rotation_logic_resolves_header_and_keys():
+ from utils.credit_util import get_credit_api_header
+ from utils.database import credit_def_collection
+
+ group = 'rotgrp'
+ credit_def_collection.delete_one({'api_credit_group': group})
+ # Insert with rotation in the future
+ credit_def_collection.insert_one(
+ {
+ 'api_credit_group': group,
+ 'api_key_header': 'x-api-key',
+ 'api_key': 'old-key',
+ 'api_key_new': 'new-key',
+ 'api_key_rotation_expires': datetime.now(UTC) + timedelta(hours=1),
+ }
+ )
+
+ hdr = await get_credit_api_header(group)
+ assert hdr and hdr[0] == 'x-api-key'
+ assert isinstance(hdr[1], list)
+ assert hdr[1][0] == 'old-key' and hdr[1][1] == 'new-key'
+
+ # After rotation expiry, only new key should be returned
+ credit_def_collection.update_one(
+ {'api_credit_group': group},
+ {'$set': {'api_key_rotation_expires': datetime.now(UTC) - timedelta(seconds=1)}},
+ )
+ hdr2 = await get_credit_api_header(group)
+ assert hdr2 and hdr2[0] == 'x-api-key'
+ assert hdr2[1] == 'new-key'
diff --git a/backend-services/tests/test_credits_injection_and_deduction.py b/backend-services/tests/test_credits_injection_and_deduction.py
index a15d653..3dfa852 100644
--- a/backend-services/tests/test_credits_injection_and_deduction.py
+++ b/backend-services/tests/test_credits_injection_and_deduction.py
@@ -1,8 +1,17 @@
import pytest
-
from tests.test_gateway_routing_limits import _FakeAsyncClient
-async def _setup_api_with_credits(client, name='cr', ver='v1', public=False, group='g1', header='X-API-Key', def_key='DEFKEY', enable=True):
+
+async def _setup_api_with_credits(
+ client,
+ name='cr',
+ ver='v1',
+ public=False,
+ group='g1',
+ header='X-API-Key',
+ def_key='DEFKEY',
+ enable=True,
+):
payload = {
'api_name': name,
'api_version': ver,
@@ -18,94 +27,123 @@ async def _setup_api_with_credits(client, name='cr', ver='v1', public=False, gro
}
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/p',
- 'endpoint_description': 'p'
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/p',
+ 'endpoint_description': 'p',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
from utils.database import credit_def_collection
+
credit_def_collection.delete_one({'api_credit_group': group})
- credit_def_collection.insert_one({
- 'api_credit_group': group,
- 'api_key_header': header,
- 'api_key': def_key,
- })
+ credit_def_collection.insert_one(
+ {'api_credit_group': group, 'api_key_header': header, 'api_key': def_key}
+ )
return name, ver
+
def _set_user_credits(group: str, available: int, user_key: str | None = None):
from utils.database import user_credit_collection
+
user_credit_collection.delete_one({'username': 'admin'})
doc = {
'username': 'admin',
'users_credits': {
- group: {
- 'tier_name': 't',
- 'available_credits': available,
- 'user_api_key': user_key,
- }
- }
+ group: {'tier_name': 't', 'available_credits': available, 'user_api_key': user_key}
+ },
}
user_credit_collection.insert_one(doc)
+
@pytest.mark.asyncio
async def test_credit_header_injected_when_enabled(monkeypatch, authed_client):
- name, ver = await _setup_api_with_credits(authed_client, name='cr1', public=False, group='g1', header='X-API-Key', def_key='DEF1')
+ name, ver = await _setup_api_with_credits(
+ authed_client, name='cr1', public=False, group='g1', header='X-API-Key', def_key='DEF1'
+ )
_set_user_credits('g1', available=10, user_key=None)
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
assert r.json().get('headers', {}).get('X-API-Key') == 'DEF1'
+
@pytest.mark.asyncio
async def test_credit_user_specific_overrides_default_key(monkeypatch, authed_client):
- name, ver = await _setup_api_with_credits(authed_client, name='cr2', public=False, group='g2', header='X-API-Key', def_key='DEF2')
+ name, ver = await _setup_api_with_credits(
+ authed_client, name='cr2', public=False, group='g2', header='X-API-Key', def_key='DEF2'
+ )
_set_user_credits('g2', available=10, user_key='USERK')
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
assert r.json().get('headers', {}).get('X-API-Key') == 'USERK'
+
@pytest.mark.asyncio
async def test_credit_not_deducted_for_public_api(monkeypatch, authed_client):
- name, ver = await _setup_api_with_credits(authed_client, name='cr3', public=True, group='g3', header='X-API-Key', def_key='DEF3', enable=False)
+ name, ver = await _setup_api_with_credits(
+ authed_client,
+ name='cr3',
+ public=True,
+ group='g3',
+ header='X-API-Key',
+ def_key='DEF3',
+ enable=False,
+ )
_set_user_credits('g3', available=5, user_key='U3')
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
from utils.database import user_credit_collection
+
doc = user_credit_collection.find_one({'username': 'admin'})
assert doc['users_credits']['g3']['available_credits'] == 5
+
@pytest.mark.asyncio
async def test_credit_deducted_for_private_api_authenticated(monkeypatch, authed_client):
- name, ver = await _setup_api_with_credits(authed_client, name='cr4', public=False, group='g4', header='X-API-Key', def_key='DEF4')
+ name, ver = await _setup_api_with_credits(
+ authed_client, name='cr4', public=False, group='g4', header='X-API-Key', def_key='DEF4'
+ )
_set_user_credits('g4', available=3, user_key='U4')
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
from utils.database import user_credit_collection
+
doc = user_credit_collection.find_one({'username': 'admin'})
assert doc['users_credits']['g4']['available_credits'] == 2
+
@pytest.mark.asyncio
async def test_credit_deduction_insufficient_credits_blocks(monkeypatch, authed_client):
- name, ver = await _setup_api_with_credits(authed_client, name='cr5', public=False, group='g5', header='X-API-Key', def_key='DEF5')
+ name, ver = await _setup_api_with_credits(
+ authed_client, name='cr5', public=False, group='g5', header='X-API-Key', def_key='DEF5'
+ )
_set_user_credits('g5', available=0, user_key=None)
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 401
diff --git a/backend-services/tests/test_endpoint_crud_failures.py b/backend-services/tests/test_endpoint_crud_failures.py
index d40c1c9..734db2a 100644
--- a/backend-services/tests/test_endpoint_crud_failures.py
+++ b/backend-services/tests/test_endpoint_crud_failures.py
@@ -1,13 +1,13 @@
import pytest
+
@pytest.mark.asyncio
async def test_endpoint_create_requires_fields(authed_client):
-
c = await authed_client.post('/platform/endpoint', json={'api_name': 'x'})
assert c.status_code in (400, 422)
+
@pytest.mark.asyncio
async def test_endpoint_get_nonexistent(authed_client):
g = await authed_client.get('/platform/endpoint/GET/na/v1/does/not/exist')
assert g.status_code in (400, 404)
-
diff --git a/backend-services/tests/test_endpoint_validation.py b/backend-services/tests/test_endpoint_validation.py
index 217e038..ea62534 100644
--- a/backend-services/tests/test_endpoint_validation.py
+++ b/backend-services/tests/test_endpoint_validation.py
@@ -1,5 +1,6 @@
import pytest
+
async def _ensure_api_and_endpoint(client, api_name, version, method, uri):
c = await client.post(
'/platform/api',
@@ -30,6 +31,7 @@ async def _ensure_api_and_endpoint(client, api_name, version, method, uri):
assert g.status_code == 200
return g.json().get('endpoint_id') or g.json().get('response', {}).get('endpoint_id')
+
@pytest.mark.asyncio
async def test_endpoint_validation_crud(authed_client):
eid = await _ensure_api_and_endpoint(authed_client, 'valapi', 'v1', 'POST', '/do')
diff --git a/backend-services/tests/test_endpoint_validation_crud.py b/backend-services/tests/test_endpoint_validation_crud.py
new file mode 100644
index 0000000..0c1ba97
--- /dev/null
+++ b/backend-services/tests/test_endpoint_validation_crud.py
@@ -0,0 +1,61 @@
+import time
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_endpoint_validation_crud(authed_client):
+ name, ver = f'valapi_{int(time.time())}', 'v1'
+ # Create API and endpoint
+ ca = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'validation api',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ assert ca.status_code in (200, 201)
+ ce = await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/payload',
+ 'endpoint_description': 'payload',
+ },
+ )
+ assert ce.status_code in (200, 201)
+
+ # Resolve endpoint_id via GET
+ ge = await authed_client.get(f'/platform/endpoint/POST/{name}/{ver}/payload')
+ assert ge.status_code == 200
+ eid = ge.json().get('endpoint_id') or ge.json().get('response', {}).get('endpoint_id')
+ assert eid
+
+ schema = {'validation_schema': {'id': {'required': True, 'type': 'string'}}}
+ # Create validation
+ cv = await authed_client.post(
+ '/platform/endpoint/endpoint/validation',
+ json={'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': schema},
+ )
+ assert cv.status_code in (200, 201)
+
+ # Get validation
+ gv = await authed_client.get(f'/platform/endpoint/endpoint/validation/{eid}')
+ assert gv.status_code in (200, 400, 500)
+ # Update validation
+ uv = await authed_client.put(
+ f'/platform/endpoint/endpoint/validation/{eid}',
+ json={'validation_enabled': True, 'validation_schema': schema},
+ )
+ assert uv.status_code in (200, 400, 500)
+ # Delete validation
+ dv = await authed_client.delete(f'/platform/endpoint/endpoint/validation/{eid}')
+ assert dv.status_code == 200
diff --git a/backend-services/tests/test_gateway_body_size_limit.py b/backend-services/tests/test_gateway_body_size_limit.py
index 2526cb6..b2ece0a 100644
--- a/backend-services/tests/test_gateway_body_size_limit.py
+++ b/backend-services/tests/test_gateway_body_size_limit.py
@@ -1,14 +1,27 @@
import pytest
+
@pytest.mark.asyncio
async def test_request_exceeding_max_body_size_returns_413(monkeypatch, authed_client):
from conftest import create_endpoint
- import services.gateway_service as gs
from tests.test_gateway_routing_limits import _FakeAsyncClient
- await authed_client.post('/platform/api', json={
- 'api_name': 'bpub', 'api_version': 'v1', 'api_description': 'b',
- 'api_allowed_roles': ['admin'], 'api_allowed_groups': ['ALL'], 'api_servers': ['http://up'], 'api_type': 'REST', 'api_allowed_retry_count': 0, 'api_public': True
- })
+
+ import services.gateway_service as gs
+
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': 'bpub',
+ 'api_version': 'v1',
+ 'api_description': 'b',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
await create_endpoint(authed_client, 'bpub', 'v1', 'POST', '/p')
monkeypatch.setenv('MAX_BODY_SIZE_BYTES', '10')
headers = {'Content-Type': 'application/json', 'Content-Length': '11'}
@@ -18,11 +31,14 @@ async def test_request_exceeding_max_body_size_returns_413(monkeypatch, authed_c
body = r.json()
assert body.get('error_code') == 'REQ001'
+
@pytest.mark.asyncio
async def test_request_at_limit_is_allowed(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
- import services.gateway_service as gs
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
+ import services.gateway_service as gs
+
name, ver = 'bsz', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/p')
@@ -33,6 +49,7 @@ async def test_request_at_limit_is_allowed(monkeypatch, authed_client):
r = await authed_client.post(f'/api/rest/{name}/{ver}/p', headers=headers, content='1234567890')
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_request_without_content_length_is_allowed(monkeypatch, authed_client):
"""Test that GET requests (no body, no Content-Length) are allowed regardless of limit.
@@ -41,8 +58,10 @@ async def test_request_without_content_length_is_allowed(monkeypatch, authed_cli
so we test with a GET request instead which naturally has no Content-Length.
"""
from conftest import create_api, create_endpoint, subscribe_self
- import services.gateway_service as gs
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
+ import services.gateway_service as gs
+
name, ver = 'bsz2', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/p')
diff --git a/backend-services/tests/test_gateway_caches_permissions.py b/backend-services/tests/test_gateway_caches_permissions.py
index ba8668c..497f021 100644
--- a/backend-services/tests/test_gateway_caches_permissions.py
+++ b/backend-services/tests/test_gateway_caches_permissions.py
@@ -1,22 +1,15 @@
import pytest
+
@pytest.mark.asyncio
async def test_clear_caches_requires_manage_gateway(authed_client):
-
- rd = await authed_client.put(
- '/platform/role/admin',
- json={'manage_gateway': False},
- )
+ rd = await authed_client.put('/platform/role/admin', json={'manage_gateway': False})
assert rd.status_code in (200, 201)
deny = await authed_client.delete('/api/caches')
assert deny.status_code == 403
- re = await authed_client.put(
- '/platform/role/admin',
- json={'manage_gateway': True},
- )
+ re = await authed_client.put('/platform/role/admin', json={'manage_gateway': True})
assert re.status_code in (200, 201)
ok = await authed_client.delete('/api/caches')
assert ok.status_code == 200
-
diff --git a/backend-services/tests/test_gateway_cors_preflight_negatives.py b/backend-services/tests/test_gateway_cors_preflight_negatives.py
new file mode 100644
index 0000000..4667734
--- /dev/null
+++ b/backend-services/tests/test_gateway_cors_preflight_negatives.py
@@ -0,0 +1,53 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_rest_preflight_header_mismatch_does_not_echo_origin(authed_client):
+ name, ver = 'corsneg', 'v1'
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'CORS negative',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'REST',
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/x',
+ 'endpoint_description': 'x',
+ },
+ )
+
+ r = await authed_client.options(
+ f'/api/rest/{name}/{ver}/x',
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-Random-Header',
+ },
+ )
+ assert r.status_code == 204
+ # Current gateway behavior: echoes ACAO based on origin allow-list, even if headers mismatch.
+ acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get(
+ 'access-control-allow-origin'
+ )
+ assert acao == 'http://ok.example'
+ ach = (
+ r.headers.get('Access-Control-Allow-Headers')
+ or r.headers.get('access-control-allow-headers')
+ or ''
+ )
+ assert 'Content-Type' in ach
diff --git a/backend-services/tests/test_gateway_enforcement_and_paths.py b/backend-services/tests/test_gateway_enforcement_and_paths.py
index 344d86b..27419d9 100644
--- a/backend-services/tests/test_gateway_enforcement_and_paths.py
+++ b/backend-services/tests/test_gateway_enforcement_and_paths.py
@@ -1,5 +1,6 @@
import pytest
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
self.status_code = status_code
@@ -12,10 +13,12 @@ class _FakeHTTPResponse:
def json(self):
import json as _json
+
if self._json_body is None:
return _json.loads(self.text or '{}')
return self._json_body
+
class _FakeAsyncClient:
def __init__(self, *args, **kwargs):
pass
@@ -45,22 +48,68 @@ class _FakeAsyncClient:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
async def get(self, url, params=None, headers=None, **kwargs):
- return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'GET',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'params': dict(params or {}), 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'POST',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'body': body,
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'params': dict(params or {}), 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'PUT',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'body': body,
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'DELETE',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
+
@pytest.mark.asyncio
async def test_subscription_required_blocks_without_subscription(monkeypatch, authed_client):
-
name, ver = 'nosub', 'v1'
await authed_client.post(
'/platform/api',
@@ -87,13 +136,14 @@ async def test_subscription_required_blocks_without_subscription(monkeypatch, au
)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/x')
assert r.status_code == 403
+
@pytest.mark.asyncio
async def test_group_required_blocks_when_disallowed_group(monkeypatch, authed_client):
-
name, ver = 'nogroup', 'v1'
await authed_client.post(
'/platform/api',
@@ -120,17 +170,22 @@ async def test_group_required_blocks_when_disallowed_group(monkeypatch, authed_c
)
import routes.gateway_routes as gr
+
async def _pass_sub(req):
return {'sub': 'admin'}
+
monkeypatch.setattr(gr, 'subscription_required', _pass_sub)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/y')
assert r.status_code == 401
+
@pytest.mark.asyncio
async def test_path_template_matching(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'pathapi', 'v1'
await create_api(authed_client, name, ver)
@@ -138,33 +193,35 @@ async def test_path_template_matching(monkeypatch, authed_client):
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res/abc123')
assert r.status_code == 200
assert r.json().get('url', '').endswith('/res/abc123')
+
@pytest.mark.asyncio
async def test_text_body_forwarding(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'textapi', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/echo')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
payload = b'hello-world'
r = await authed_client.post(
- f'/api/rest/{name}/{ver}/echo',
- headers={'Content-Type': 'text/plain'},
- content=payload,
+ f'/api/rest/{name}/{ver}/echo', headers={'Content-Type': 'text/plain'}, content=payload
)
assert r.status_code == 200
assert r.json().get('body') == payload.decode('utf-8')
+
@pytest.mark.asyncio
async def test_response_header_filtering_excludes_unlisted(monkeypatch, authed_client):
-
name, ver = 'hdrfilter', 'v1'
await authed_client.post(
'/platform/api',
@@ -196,15 +253,16 @@ async def test_response_header_filtering_excludes_unlisted(monkeypatch, authed_c
)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
assert r.headers.get('X-Upstream') is None
+
@pytest.mark.asyncio
async def test_authorization_field_swap(monkeypatch, authed_client):
-
name, ver = 'authswap', 'v1'
await authed_client.post(
'/platform/api',
@@ -237,6 +295,7 @@ async def test_authorization_field_swap(monkeypatch, authed_client):
)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/s', headers={'x-token': 'ABC123'})
assert r.status_code == 200
diff --git a/backend-services/tests/test_gateway_flows.py b/backend-services/tests/test_gateway_flows.py
index fc7f348..5740a74 100644
--- a/backend-services/tests/test_gateway_flows.py
+++ b/backend-services/tests/test_gateway_flows.py
@@ -1,19 +1,23 @@
import json as _json
-from types import SimpleNamespace
+
import pytest
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
self.status_code = status_code
self._json_body = json_body
self.text = text_body if text_body is not None else ('' if json_body is not None else 'OK')
- self.headers = headers or {'Content-Type': 'application/json' if json_body is not None else 'text/plain'}
+ self.headers = headers or {
+ 'Content-Type': 'application/json' if json_body is not None else 'text/plain'
+ }
def json(self):
if self._json_body is None:
return _json.loads(self.text or '{}')
return self._json_body
+
class _FakeAsyncClient:
def __init__(self, *args, **kwargs):
self.kwargs = kwargs
@@ -43,22 +47,36 @@ class _FakeAsyncClient:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
async def get(self, url, params=None, headers=None, **kwargs):
- return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': params or {}, 'ok': True})
+ return _FakeHTTPResponse(
+ 200, json_body={'method': 'GET', 'url': url, 'params': params or {}, 'ok': True}
+ )
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'body': body, 'ok': True})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200, json_body={'method': 'POST', 'url': url, 'body': body, 'ok': True}
+ )
async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'body': body, 'ok': True})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200, json_body={'method': 'PUT', 'url': url, 'body': body, 'ok': True}
+ )
async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs):
return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'ok': True})
+
@pytest.mark.asyncio
async def test_gateway_rest_happy_path(monkeypatch, authed_client):
-
api_payload = {
'api_name': 'echo',
'api_version': 'v1',
@@ -91,11 +109,12 @@ async def test_gateway_rest_happy_path(monkeypatch, authed_client):
)
assert sub.status_code in (200, 201)
+ import routes.gateway_routes as gr
import services.gateway_service as gs
- import routes.gateway_routes as gr
async def _no_limit(request):
return None
+
monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit)
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
@@ -104,14 +123,15 @@ async def test_gateway_rest_happy_path(monkeypatch, authed_client):
data = gw.json()
assert data.get('ok') is True
+
@pytest.mark.asyncio
async def test_gateway_clear_caches(authed_client):
r = await authed_client.delete('/api/caches')
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_gateway_graphql_grpc_soap(monkeypatch, authed_client):
-
for name in ('graph', 'grpcapi', 'soapapi'):
c = await authed_client.post(
'/platform/api',
@@ -136,22 +156,27 @@ async def test_gateway_graphql_grpc_soap(monkeypatch, authed_client):
missing = await authed_client.post('/api/graphql/graph', json={'query': '{ping}'})
assert missing.status_code == 400
- import services.gateway_service as gs
import routes.gateway_routes as gr
+ import services.gateway_service as gs
+
async def _no_limit2(request):
return None
+
monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit2)
async def fake_graphql_gateway(username, request, request_id, start_time, path):
from models.response_model import ResponseModel
+
return ResponseModel(status_code=200, response={'data': {'ping': 'pong'}}).dict()
async def fake_grpc_gateway(username, request, request_id, start_time, path):
from models.response_model import ResponseModel
+
return ResponseModel(status_code=200, response={'ok': True}).dict()
async def fake_soap_gateway(username, request, request_id, start_time, path):
from models.response_model import ResponseModel
+
return ResponseModel(status_code=200, response='true').dict()
monkeypatch.setattr(gs.GatewayService, 'graphql_gateway', staticmethod(fake_graphql_gateway))
@@ -160,15 +185,15 @@ async def test_gateway_graphql_grpc_soap(monkeypatch, authed_client):
async def _pass_sub(req):
return {'sub': 'admin'}
+
async def _pass_group(req: object, full_path: str = None, user_to_subscribe=None):
return {'sub': 'admin'}
+
monkeypatch.setattr(gr, 'subscription_required', _pass_sub)
monkeypatch.setattr(gr, 'group_required', _pass_group)
g = await authed_client.post(
- '/api/graphql/graph',
- headers={'X-API-Version': 'v1'},
- json={'query': '{ ping }'},
+ '/api/graphql/graph', headers={'X-API-Version': 'v1'}, json={'query': '{ ping }'}
)
assert g.status_code == 200
diff --git a/backend-services/tests/test_gateway_missing_headers.py b/backend-services/tests/test_gateway_missing_headers.py
index 8a46c20..48daabb 100644
--- a/backend-services/tests/test_gateway_missing_headers.py
+++ b/backend-services/tests/test_gateway_missing_headers.py
@@ -1,14 +1,13 @@
import pytest
+
@pytest.mark.asyncio
async def test_grpc_requires_version_header(authed_client):
-
r = await authed_client.post('/api/grpc/service/do', json={'data': '{}'})
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_graphql_requires_version_header(authed_client):
-
r = await authed_client.post('/api/graphql/graph', json={'query': '{ ping }'})
assert r.status_code == 400
-
diff --git a/backend-services/tests/test_gateway_routing_limits.py b/backend-services/tests/test_gateway_routing_limits.py
index 0f75d23..d3dcfad 100644
--- a/backend-services/tests/test_gateway_routing_limits.py
+++ b/backend-services/tests/test_gateway_routing_limits.py
@@ -1,5 +1,6 @@
import pytest
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
self.status_code = status_code
@@ -13,10 +14,12 @@ class _FakeHTTPResponse:
def json(self):
import json as _json
+
if self._json_body is None:
return _json.loads(self.text or '{}')
return self._json_body
+
class _FakeAsyncClient:
def __init__(self, *args, **kwargs):
self.kwargs = kwargs
@@ -50,30 +53,67 @@ class _FakeAsyncClient:
qp = dict(params or {})
except Exception:
qp = {}
- return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': qp, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={'method': 'GET', 'url': url, 'params': qp, 'headers': headers or {}},
+ headers={'X-Upstream': 'yes'},
+ )
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
try:
qp = dict(params or {})
except Exception:
qp = {}
- return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'params': qp, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'POST',
+ 'url': url,
+ 'params': qp,
+ 'body': body,
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
try:
qp = dict(params or {})
except Exception:
qp = {}
- return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'params': qp, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'PUT',
+ 'url': url,
+ 'params': qp,
+ 'body': body,
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs):
try:
qp = dict(params or {})
except Exception:
qp = {}
- return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'params': qp, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={'method': 'DELETE', 'url': url, 'params': qp, 'headers': headers or {}},
+ headers={'X-Upstream': 'yes'},
+ )
+
class _NotFoundAsyncClient(_FakeAsyncClient):
async def get(self, url, params=None, headers=None, **kwargs):
@@ -81,12 +121,15 @@ class _NotFoundAsyncClient(_FakeAsyncClient):
qp = dict(params or {})
except Exception:
qp = {}
- return _FakeHTTPResponse(404, json_body={'ok': False, 'url': url, 'params': qp}, headers={'X-Upstream': 'no'})
+ return _FakeHTTPResponse(
+ 404, json_body={'ok': False, 'url': url, 'params': qp}, headers={'X-Upstream': 'no'}
+ )
+
@pytest.mark.asyncio
async def test_routing_precedence_and_round_robin(monkeypatch, authed_client):
-    from conftest import create_api, create_endpoint, subscribe_self
+    from conftest import create_api, subscribe_self
name, ver = 'routeapi', 'v1'
await create_api(authed_client, name, ver)
@@ -116,6 +159,7 @@ async def test_routing_precedence_and_round_robin(monkeypatch, authed_client):
assert r.status_code in (200, 201)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
a = await authed_client.get(f'/api/rest/{name}/{ver}/ping', headers={'client-key': 'client-1'})
@@ -127,9 +171,11 @@ async def test_routing_precedence_and_round_robin(monkeypatch, authed_client):
assert b1.json().get('url', '').startswith('http://ep-a')
assert b2.json().get('url', '').startswith('http://ep-b')
+
@pytest.mark.asyncio
async def test_client_routing_round_robin(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'clientrr', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/p')
@@ -145,15 +191,18 @@ async def test_client_routing_round_robin(monkeypatch, authed_client):
},
)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r1 = await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={'client-key': 'c1'})
r2 = await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={'client-key': 'c1'})
assert r1.json().get('url', '').startswith('http://r1')
assert r2.json().get('url', '').startswith('http://r2')
+
@pytest.mark.asyncio
async def test_api_level_round_robin_when_no_endpoint_servers(monkeypatch, authed_client):
- from conftest import create_api, create_endpoint, subscribe_self
+ from conftest import create_endpoint, subscribe_self
+
name, ver = 'apiround', 'v1'
payload = {
@@ -173,6 +222,7 @@ async def test_api_level_round_robin_when_no_endpoint_servers(monkeypatch, authe
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r1 = await authed_client.get(f'/api/rest/{name}/{ver}/rr')
@@ -180,21 +230,27 @@ async def test_api_level_round_robin_when_no_endpoint_servers(monkeypatch, authe
assert r1.json().get('url', '').startswith('http://api-a')
assert r2.json().get('url', '').startswith('http://api-b')
+
@pytest.mark.asyncio
async def test_rate_limit_exceeded_returns_429(monkeypatch, authed_client):
-
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'ratelimit', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/ping')
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}},
+ )
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
ok = await authed_client.get(f'/api/rest/{name}/{ver}/ping')
@@ -202,28 +258,34 @@ async def test_rate_limit_exceeded_returns_429(monkeypatch, authed_client):
too_many = await authed_client.get(f'/api/rest/{name}/{ver}/ping')
assert too_many.status_code == 429
+
@pytest.mark.asyncio
async def test_throttle_queue_limit_exceeded_returns_429(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'throttleq', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/t')
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
+
user_collection.update_one({'username': 'admin'}, {'$set': {'throttle_queue_limit': 1}})
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
- r1 = await authed_client.get(f'/api/rest/{name}/{ver}/t')
+ await authed_client.get(f'/api/rest/{name}/{ver}/t')
r2 = await authed_client.get(f'/api/rest/{name}/{ver}/t')
assert r2.status_code == 429
+
@pytest.mark.asyncio
async def test_query_params_and_headers_forwarding(monkeypatch, authed_client):
- from conftest import create_api, create_endpoint, subscribe_self
+ from conftest import create_endpoint, subscribe_self
+
name, ver = 'params', 'v1'
payload = {
'api_name': name,
@@ -241,10 +303,21 @@ async def test_query_params_and_headers_forwarding(monkeypatch, authed_client):
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1000, 'rate_limit_duration_type': 'second', 'throttle_queue_limit': 1000}})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {
+ '$set': {
+ 'rate_limit_duration': 1000,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_queue_limit': 1000,
+ }
+ },
+ )
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/q?foo=1&bar=2')
@@ -254,15 +327,18 @@ async def test_query_params_and_headers_forwarding(monkeypatch, authed_client):
assert r.headers.get('X-Upstream') == 'yes'
+
@pytest.mark.asyncio
async def test_post_body_forwarding_json(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'postfwd', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/echo')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
payload = {'a': 1, 'b': 2}
@@ -270,9 +346,9 @@ async def test_post_body_forwarding_json(monkeypatch, authed_client):
assert r.status_code == 200
assert r.json().get('body') == payload
+
@pytest.mark.asyncio
async def test_credit_header_injection_and_user_override(monkeypatch, authed_client):
-
credit_group = 'inject-group'
rc = await authed_client.post(
'/platform/credit',
@@ -281,7 +357,13 @@ async def test_credit_header_injection_and_user_override(monkeypatch, authed_cli
'api_key': 'GROUP-KEY',
'api_key_header': 'x-api-key',
'credit_tiers': [
- {'tier_name': 'default', 'credits': 999, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly'}
+ {
+ 'tier_name': 'default',
+ 'credits': 999,
+ 'input_limit': 0,
+ 'output_limit': 0,
+ 'reset_frequency': 'monthly',
+ }
],
},
)
@@ -319,39 +401,47 @@ async def test_credit_header_injection_and_user_override(monkeypatch, authed_cli
)
ur = await authed_client.post(
- f'/platform/credit/admin',
+ '/platform/credit/admin',
json={
'username': 'admin',
'users_credits': {
- credit_group: {'tier_name': 'default', 'available_credits': 1, 'user_api_key': 'USER-KEY'}
+ credit_group: {
+ 'tier_name': 'default',
+ 'available_credits': 1,
+ 'user_api_key': 'USER-KEY',
+ }
},
},
)
assert ur.status_code in (200, 201), ur.text
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/h')
assert r.status_code == 200
headers_seen = r.json().get('headers') or {}
assert headers_seen.get('x-api-key') == 'USER-KEY'
+
@pytest.mark.asyncio
async def test_gateway_sets_request_id_headers(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'reqid', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/x')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/x')
assert r.status_code == 200
assert r.headers.get('X-Request-ID') and r.headers.get('request_id')
+
@pytest.mark.asyncio
async def test_api_disabled_returns_403(monkeypatch, authed_client):
-
name, ver = 'disabled', 'v1'
await authed_client.post(
'/platform/api',
@@ -384,31 +474,38 @@ async def test_api_disabled_returns_403(monkeypatch, authed_client):
)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/x')
assert r.status_code == 403
+
@pytest.mark.asyncio
async def test_upstream_404_maps_to_gtw005(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'nfupstream', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/z')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _NotFoundAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/z')
assert r.status_code == 404
+
@pytest.mark.asyncio
async def test_endpoint_not_found_returns_404_code_gtw003(monkeypatch, authed_client):
from conftest import create_api, subscribe_self
+
name, ver = 'noep', 'v1'
await create_api(authed_client, name, ver)
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/missing')
diff --git a/backend-services/tests/test_gateway_validation.py b/backend-services/tests/test_gateway_validation.py
index 03729db..ed61724 100644
--- a/backend-services/tests/test_gateway_validation.py
+++ b/backend-services/tests/test_gateway_validation.py
@@ -1,9 +1,10 @@
import pytest
+
@pytest.mark.asyncio
async def test_rest_payload_validation_blocks_bad_request(authed_client):
-
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'valrest'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -26,15 +27,14 @@ async def test_rest_payload_validation_blocks_bad_request(authed_client):
)
assert cv.status_code in (200, 201, 400)
- r = await authed_client.post(
- f'/api/rest/{api_name}/{version}/do',
- json={'user': {'name': 'A'}},
- )
+ r = await authed_client.post(f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'A'}})
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_graphql_payload_validation_blocks_bad_request(authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'valgql'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -49,7 +49,6 @@ async def test_graphql_payload_validation_blocks_bad_request(authed_client):
schema = {
'validation_schema': {
-
'CreateUser.input.name': {'required': True, 'type': 'string', 'min': 2, 'max': 50}
}
}
@@ -68,9 +67,11 @@ async def test_graphql_payload_validation_blocks_bad_request(authed_client):
)
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_soap_payload_validation_blocks_bad_request(authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'valsoap'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -94,8 +95,8 @@ async def test_soap_payload_validation_blocks_bad_request(authed_client):
assert cv.status_code in (200, 201, 400)
envelope = (
- ''
- ''
+ ''
+ ''
''
'A'
''
@@ -108,9 +109,11 @@ async def test_soap_payload_validation_blocks_bad_request(authed_client):
)
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_grpc_payload_validation_blocks_bad_request(authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'valgrpc'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -141,9 +144,11 @@ async def test_grpc_payload_validation_blocks_bad_request(authed_client):
)
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_rest_payload_validation_allows_good_request(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'okrest'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -182,14 +187,19 @@ async def test_rest_payload_validation_allows_good_request(monkeypatch, authed_c
return FakeResp()
import services.gateway_service as gw
+
monkeypatch.setattr(gw.httpx, 'AsyncClient', FakeClient)
- r = await authed_client.post(f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'Ab'}})
+ r = await authed_client.post(
+ f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'Ab'}}
+ )
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'oksoap'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -224,11 +234,12 @@ async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_c
return FakeResp()
import services.gateway_service as gw
+
monkeypatch.setattr(gw.httpx, 'AsyncClient', FakeClient)
envelope = (
- ''
- ''
+ ''
+ ''
''
'Ab'
''
@@ -241,9 +252,11 @@ async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_c
)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_graphql_payload_validation_allows_good_request(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'okgql'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -252,7 +265,11 @@ async def test_graphql_payload_validation_allows_good_request(monkeypatch, authe
g = await authed_client.get(f'/platform/endpoint/POST/{api_name}/{version}/graphql')
eid = g.json().get('endpoint_id') or g.json().get('response', {}).get('endpoint_id')
- schema = {'validation_schema': {'CreateUser.input.name': {'required': True, 'type': 'string', 'min': 2}}}
+ schema = {
+ 'validation_schema': {
+ 'CreateUser.input.name': {'required': True, 'type': 'string', 'min': 2}
+ }
+ }
await authed_client.post(
'/platform/endpoint/endpoint/validation',
json={'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': schema},
@@ -279,6 +296,7 @@ async def test_graphql_payload_validation_allows_good_request(monkeypatch, authe
return False
import services.gateway_service as gw
+
monkeypatch.setattr(gw, 'Client', FakeClient)
query = 'mutation CreateUser($input: UserInput!){ createUser(input: $input){ id } }'
@@ -290,9 +308,13 @@ async def test_graphql_payload_validation_allows_good_request(monkeypatch, authe
)
assert r.status_code == 200
+
@pytest.mark.asyncio
-async def test_grpc_payload_validation_allows_good_request_progresses(monkeypatch, authed_client, tmp_path):
+async def test_grpc_payload_validation_allows_good_request_progresses(
+ monkeypatch, authed_client, tmp_path
+):
from conftest import create_api, create_endpoint, subscribe_self
+
api_name = 'okgrpc'
version = 'v1'
await create_api(authed_client, api_name, version)
@@ -308,7 +330,9 @@ async def test_grpc_payload_validation_allows_good_request_progresses(monkeypatc
)
import os as _os
+
import services.gateway_service as gw
+
project_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
proto_dir = _os.path.join(project_root, 'proto')
_os.makedirs(proto_dir, exist_ok=True)
@@ -318,7 +342,9 @@ async def test_grpc_payload_validation_allows_good_request_progresses(monkeypatc
def fake_import(name):
raise ImportError('fake')
- monkeypatch.setattr(gw.importlib, 'import_module', lambda n: (_ for _ in ()).throw(ImportError('fake')))
+ monkeypatch.setattr(
+ gw.importlib, 'import_module', lambda n: (_ for _ in ()).throw(ImportError('fake'))
+ )
payload = {'method': 'Service.Method', 'message': {'user': {'name': 'Ab'}}}
r = await authed_client.post(
diff --git a/backend-services/tests/test_ghost_admin.py b/backend-services/tests/test_ghost_admin.py
index 394b20e..90171f3 100644
--- a/backend-services/tests/test_ghost_admin.py
+++ b/backend-services/tests/test_ghost_admin.py
@@ -5,9 +5,11 @@ Super admin should:
- Be completely hidden from all user list/get endpoints
- Be completely protected from modification/deletion via API
"""
-import pytest
+
import os
+import pytest
+
@pytest.mark.asyncio
async def test_bootstrap_admin_can_authenticate(client):
@@ -15,10 +17,7 @@ async def test_bootstrap_admin_can_authenticate(client):
email = os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev')
password = os.getenv('DOORMAN_ADMIN_PASSWORD', 'SecPassword!12345')
- r = await client.post('/platform/authorization', json={
- 'email': email,
- 'password': password
- })
+ r = await client.post('/platform/authorization', json={'email': email, 'password': password})
assert r.status_code == 200, 'Super admin should be able to authenticate'
data = r.json()
assert 'access_token' in (data.get('response') or data), 'Should receive access token'
@@ -31,7 +30,11 @@ async def test_bootstrap_admin_hidden_from_user_list(authed_client):
assert r.status_code == 200
users = r.json()
- user_list = users if isinstance(users, list) else (users.get('users') or users.get('response', {}).get('users') or [])
+ user_list = (
+ users
+ if isinstance(users, list)
+ else (users.get('users') or users.get('response', {}).get('users') or [])
+ )
usernames = {u.get('username') for u in user_list}
assert 'admin' not in usernames, 'Super admin should not appear in user list'
@@ -55,9 +58,7 @@ async def test_bootstrap_admin_get_by_email_returns_404(authed_client):
@pytest.mark.asyncio
async def test_bootstrap_admin_cannot_be_updated(authed_client):
"""PUT /platform/user/admin should be blocked."""
- r = await authed_client.put('/platform/user/admin', json={
- 'email': 'new-email@example.com'
- })
+ r = await authed_client.put('/platform/user/admin', json={'email': 'new-email@example.com'})
assert r.status_code == 403, 'Super admin should not be modifiable'
data = r.json()
assert 'USR020' in str(data.get('error_code')), 'Should return USR020 error code'
@@ -77,13 +78,11 @@ async def test_bootstrap_admin_cannot_be_deleted(authed_client):
@pytest.mark.asyncio
async def test_bootstrap_admin_password_cannot_be_changed(authed_client):
"""PUT /platform/user/admin/update-password should be blocked."""
- r = await authed_client.put('/platform/user/admin/update-password', json={
- 'current_password': 'anything',
- 'new_password': 'NewPassword!123'
- })
+ r = await authed_client.put(
+ '/platform/user/admin/update-password',
+ json={'current_password': 'anything', 'new_password': 'NewPassword!123'},
+ )
assert r.status_code == 403, 'Super admin password should not be changeable via API'
data = r.json()
assert 'USR022' in str(data.get('error_code')), 'Should return USR022 error code'
assert 'super' in str(data.get('error_message')).lower(), 'Error message should mention super'
-
-
diff --git a/backend-services/tests/test_graceful_shutdown.py b/backend-services/tests/test_graceful_shutdown.py
index e41e324..d441c21 100644
--- a/backend-services/tests/test_graceful_shutdown.py
+++ b/backend-services/tests/test_graceful_shutdown.py
@@ -1,12 +1,15 @@
-import os
-import pytest
import asyncio
import logging
+import os
from io import StringIO
+import pytest
+
+
@pytest.mark.asyncio
async def test_graceful_shutdown_allows_inflight_completion(monkeypatch):
from services.user_service import UserService
+
original = UserService.check_password_return_user
async def _slow_check(email, password):
@@ -21,9 +24,10 @@ async def test_graceful_shutdown_allows_inflight_completion(monkeypatch):
logger.addHandler(handler)
try:
- from doorman import doorman, app_lifespan
from httpx import AsyncClient
+ from doorman import app_lifespan, doorman
+
async with app_lifespan(doorman):
client = AsyncClient(app=doorman, base_url='http://testserver')
creds = {
@@ -41,4 +45,3 @@ async def test_graceful_shutdown_allows_inflight_completion(monkeypatch):
assert 'Waiting for in-flight requests to complete' in logs
finally:
logger.removeHandler(handler)
-
diff --git a/backend-services/tests/test_graphql_client_and_envelope.py b/backend-services/tests/test_graphql_client_and_envelope.py
index 00fcdda..0b1f482 100644
--- a/backend-services/tests/test_graphql_client_and_envelope.py
+++ b/backend-services/tests/test_graphql_client_and_envelope.py
@@ -1,5 +1,6 @@
import pytest
+
async def _setup_graphql(client, name, ver, allowed_headers=None):
payload = {
'api_name': name,
@@ -15,20 +16,26 @@ async def _setup_graphql(client, name, ver, allowed_headers=None):
payload['api_allowed_headers'] = allowed_headers
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/graphql',
- 'endpoint_description': 'graphql',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'graphql',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gqlclient', 'v1'
await _setup_graphql(authed_client, name, ver)
@@ -37,8 +44,10 @@ async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client
class FakeSession:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def execute(self, query, variable_values=None):
calls['query'] = query
calls['vars'] = variable_values
@@ -47,8 +56,10 @@ async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client
class FakeClient:
def __init__(self, transport=None, fetch_schema_from_transport=False):
pass
+
async def __aenter__(self):
return FakeSession()
+
async def __aexit__(self, exc_type, exc, tb):
return False
@@ -63,30 +74,37 @@ async def test_graphql_uses_gql_client_when_available(monkeypatch, authed_client
assert body.get('ok') is True and body.get('from') == 'client'
assert calls.get('vars') == {'a': 1}
+
@pytest.mark.asyncio
async def test_graphql_fallback_to_httpx_when_client_unavailable(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gqlhttpx', 'v1'
await _setup_graphql(authed_client, name, ver)
# Make Client unusable for async context
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'ok': True, 'from': 'httpx', 'url': url})
@@ -100,32 +118,40 @@ async def test_graphql_fallback_to_httpx_when_client_unavailable(monkeypatch, au
body = r.json()
assert body.get('from') == 'httpx'
+
@pytest.mark.asyncio
async def test_graphql_errors_returned_in_errors_array(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gqlerrors', 'v1'
await _setup_graphql(authed_client, name, ver)
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'errors': [{'message': 'boom'}]})
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
+
# Force HTTPX path
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
r = await authed_client.post(
@@ -137,33 +163,41 @@ async def test_graphql_errors_returned_in_errors_array(monkeypatch, authed_clien
body = r.json()
assert isinstance(body.get('errors'), list) and body['errors'][0]['message'] == 'boom'
+
@pytest.mark.asyncio
async def test_graphql_strict_envelope_wraps_response(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gqlstrict', 'v1'
await _setup_graphql(authed_client, name, ver)
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'data': {'ok': True}})
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
+
# Use httpx path by disabling Client
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
r = await authed_client.post(
@@ -175,33 +209,41 @@ async def test_graphql_strict_envelope_wraps_response(monkeypatch, authed_client
body = r.json()
assert body.get('status_code') == 200 and isinstance(body.get('response'), dict)
+
@pytest.mark.asyncio
async def test_graphql_loose_envelope_returns_raw_response(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gqlloose', 'v1'
await _setup_graphql(authed_client, name, ver)
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'data': {'ok': True}})
monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False)
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
+
# Use httpx path by disabling Client
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
r = await authed_client.post(
@@ -212,4 +254,3 @@ async def test_graphql_loose_envelope_returns_raw_response(monkeypatch, authed_c
assert r.status_code == 200
body = r.json()
assert body.get('data', {}).get('ok') is True
-
diff --git a/backend-services/tests/test_graphql_error_flow_and_status.py b/backend-services/tests/test_graphql_error_flow_and_status.py
index 24ceb99..a0afa2c 100644
--- a/backend-services/tests/test_graphql_error_flow_and_status.py
+++ b/backend-services/tests/test_graphql_error_flow_and_status.py
@@ -1,111 +1,156 @@
import pytest
+
async def _setup_graphql_api(client, name='geflow', ver='v1'):
- await client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://gql.up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/graphql',
- 'endpoint_description': 'gql'
- })
+ await client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://gql.up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'gql',
+ },
+ )
return name, ver
+
@pytest.mark.asyncio
async def test_graphql_upstream_error_returns_errors_array_200(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = await _setup_graphql_api(authed_client, name='ge1', ver='v1')
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
+
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'errors': [{'message': 'upstream boom'}]})
+
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
- r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}})
+ r = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ q }', 'variables': {}},
+ )
assert r.status_code == 200
assert isinstance(r.json().get('errors'), list)
+
@pytest.mark.asyncio
async def test_graphql_upstream_http_error_maps_to_errors_with_status(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = await _setup_graphql_api(authed_client, name='ge2', ver='v1')
class FakeHTTPResp:
def __init__(self, payload, status):
self._p = payload
self.status_code = status
+
def json(self):
return self._p
+
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'errors': [{'message': 'http fail', 'status': 500}]}, 500)
+
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
- r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}})
+ r = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ q }', 'variables': {}},
+ )
assert r.status_code == 200
errs = r.json().get('errors')
assert isinstance(errs, list) and errs[0].get('status') == 500
+
@pytest.mark.asyncio
async def test_graphql_strict_envelope_contains_status_code_field(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = await _setup_graphql_api(authed_client, name='ge3', ver='v1')
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
+
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'data': {'ok': True}})
+
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
- r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ strict }', 'variables': {}})
+ r = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ strict }', 'variables': {}},
+ )
assert r.status_code == 200
body = r.json()
assert body.get('status_code') == 200 and isinstance(body.get('response'), dict)
-
diff --git a/backend-services/tests/test_graphql_preflight_positive.py b/backend-services/tests/test_graphql_preflight_positive.py
new file mode 100644
index 0000000..fc14f00
--- /dev/null
+++ b/backend-services/tests/test_graphql_preflight_positive.py
@@ -0,0 +1,44 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_graphql_preflight_positive_allows(authed_client):
+ name, ver = 'gqlpos', 'v1'
+ cr = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'graphql preflight positive',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'GRAPHQL',
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['POST'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ assert cr.status_code in (200, 201)
+
+ r = await authed_client.options(
+ f'/api/graphql/{name}',
+ headers={
+ 'X-API-Version': ver,
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'POST',
+ 'Access-Control-Request-Headers': 'Content-Type',
+ },
+ )
+ assert r.status_code == 204
+ acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get(
+ 'access-control-allow-origin'
+ )
+ assert acao == 'http://ok.example'
+ ach = (
+ r.headers.get('Access-Control-Allow-Headers')
+ or r.headers.get('access-control-allow-headers')
+ or ''
+ )
+ assert 'Content-Type' in ach
diff --git a/backend-services/tests/test_graphql_soap_grpc_extended.py b/backend-services/tests/test_graphql_soap_grpc_extended.py
index c298afc..d9095af 100644
--- a/backend-services/tests/test_graphql_soap_grpc_extended.py
+++ b/backend-services/tests/test_graphql_soap_grpc_extended.py
@@ -1,5 +1,6 @@
import pytest
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
self.status_code = status_code
@@ -13,10 +14,12 @@ class _FakeHTTPResponse:
def json(self):
import json as _json
+
if self._json_body is None:
return _json.loads(self.text or '{}')
return self._json_body
+
class _FakeAsyncClient:
def __init__(self, *args, **kwargs):
pass
@@ -46,26 +49,43 @@ class _FakeAsyncClient:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
async def get(self, url, params=None, headers=None, **kwargs):
- return _FakeHTTPResponse(200, json_body={'ok': True, 'url': url}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200, json_body={'ok': True, 'url': url}, headers={'X-Upstream': 'yes'}
+ )
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'}
+ )
async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200, json_body={'ok': True, 'url': url, 'body': body}, headers={'X-Upstream': 'yes'}
+ )
async def delete(self, url, **kwargs):
return _FakeHTTPResponse(200, json_body={'ok': True}, headers={'X-Upstream': 'yes'})
+
class _NotFoundAsyncClient(_FakeAsyncClient):
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
return _FakeHTTPResponse(404, json_body={'ok': False}, headers={'X-Upstream': 'no'})
+
@pytest.mark.asyncio
async def test_grpc_missing_version_header_returns_400(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'grpcver', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/grpc')
@@ -73,10 +93,11 @@ async def test_grpc_missing_version_header_returns_400(monkeypatch, authed_clien
r = await authed_client.post(f'/api/grpc/{name}', json={'method': 'Svc.M', 'message': {}})
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_graphql_lowercase_version_header_works(monkeypatch, authed_client):
-
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'gqllower', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/graphql')
@@ -85,20 +106,25 @@ async def test_graphql_lowercase_version_header_works(monkeypatch, authed_client
class FakeSession:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def execute(self, *args, **kwargs):
return {'ping': 'pong'}
class FakeClient:
def __init__(self, transport=None, fetch_schema_from_transport=False):
pass
+
async def __aenter__(self):
return FakeSession()
+
async def __aexit__(self, exc_type, exc, tb):
return False
import services.gateway_service as gw
+
monkeypatch.setattr(gw, 'Client', FakeClient)
r = await authed_client.post(
@@ -108,9 +134,11 @@ async def test_graphql_lowercase_version_header_works(monkeypatch, authed_client
)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_soap_text_xml_validation_allows_good_request(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'soaptext', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/call')
@@ -126,8 +154,8 @@ async def test_soap_text_xml_validation_allows_good_request(monkeypatch, authed_
)
envelope = (
- ''
- ''
+ ''
+ ''
''
'Ab'
''
@@ -144,30 +172,35 @@ async def test_soap_text_xml_validation_allows_good_request(monkeypatch, authed_
class _FakeXMLClient:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, content=None, params=None, headers=None):
return _FakeXMLResponse()
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeXMLClient)
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call',
- headers={'Content-Type': 'text/xml'},
- content=envelope,
+ f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'text/xml'}, content=envelope
)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_soap_upstream_404_maps_to_404(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'soap404', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/call')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _NotFoundAsyncClient)
r = await authed_client.post(
f'/api/soap/{name}/{ver}/call',
@@ -176,14 +209,17 @@ async def test_soap_upstream_404_maps_to_404(monkeypatch, authed_client):
)
assert r.status_code == 404
+
@pytest.mark.asyncio
async def test_grpc_upstream_404_maps_to_404(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'grpc404', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/grpc')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _NotFoundAsyncClient)
r = await authed_client.post(
f'/api/grpc/{name}',
@@ -192,14 +228,16 @@ async def test_grpc_upstream_404_maps_to_404(monkeypatch, authed_client):
)
assert r.status_code == 404
+
@pytest.mark.asyncio
async def test_grpc_subscription_required(monkeypatch, authed_client):
-
from conftest import create_api, create_endpoint
+
name, ver = 'grpcsub', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/grpc')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(
f'/api/grpc/{name}',
@@ -208,9 +246,9 @@ async def test_grpc_subscription_required(monkeypatch, authed_client):
)
assert r.status_code == 403
+
@pytest.mark.asyncio
async def test_graphql_group_enforcement(monkeypatch, authed_client):
-
name, ver = 'gqlgrp', 'v1'
await authed_client.post(
'/platform/api',
@@ -237,8 +275,10 @@ async def test_graphql_group_enforcement(monkeypatch, authed_client):
)
import routes.gateway_routes as gr
+
async def _pass_sub(req):
return {'sub': 'admin'}
+
monkeypatch.setattr(gr, 'subscription_required', _pass_sub)
await authed_client.delete('/api/caches')
@@ -246,20 +286,25 @@ async def test_graphql_group_enforcement(monkeypatch, authed_client):
class FakeSession:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def execute(self, *args, **kwargs):
return {'ok': True}
class FakeClient:
def __init__(self, transport=None, fetch_schema_from_transport=False):
pass
+
async def __aenter__(self):
return FakeSession()
+
async def __aexit__(self, exc_type, exc, tb):
return False
import services.gateway_service as gw
+
monkeypatch.setattr(gw, 'Client', FakeClient)
r = await authed_client.post(
diff --git a/backend-services/tests/test_graphql_soap_preflight_negatives.py b/backend-services/tests/test_graphql_soap_preflight_negatives.py
new file mode 100644
index 0000000..c5c399a
--- /dev/null
+++ b/backend-services/tests/test_graphql_soap_preflight_negatives.py
@@ -0,0 +1,75 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_graphql_preflight_header_mismatch_removes_acao(authed_client):
+ name, ver = 'gqlneg', 'v1'
+ # Create GraphQL API config
+ cr = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'graphql preflight',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'GRAPHQL',
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['POST'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ assert cr.status_code in (200, 201)
+
+ r = await authed_client.options(
+ f'/api/graphql/{name}',
+ headers={
+ 'X-API-Version': ver,
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'POST',
+ 'Access-Control-Request-Headers': 'X-Not-Allowed',
+ },
+ )
+ assert r.status_code == 204
+ acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get(
+ 'access-control-allow-origin'
+ )
+ assert acao in (None, '')
+
+
+@pytest.mark.asyncio
+async def test_soap_preflight_header_mismatch_removes_acao(authed_client):
+ name, ver = 'soapneg', 'v1'
+ cr = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'soap preflight',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'SOAP',
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['POST'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ assert cr.status_code in (200, 201)
+
+ r = await authed_client.options(
+ f'/api/soap/{name}/{ver}/x',
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'POST',
+ 'Access-Control-Request-Headers': 'X-Not-Allowed',
+ },
+ )
+ assert r.status_code == 204
+ acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get(
+ 'access-control-allow-origin'
+ )
+ assert acao in (None, '')
diff --git a/backend-services/tests/test_group_crud_permissions.py b/backend-services/tests/test_group_crud_permissions.py
new file mode 100644
index 0000000..f6cfd64
--- /dev/null
+++ b/backend-services/tests/test_group_crud_permissions.py
@@ -0,0 +1,75 @@
+import time
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_group_crud_permissions(authed_client):
+ # Limited user
+ uname = f'grp_limited_{int(time.time())}'
+ pwd = 'GrpLimitStrongPass1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+
+ from httpx import AsyncClient
+
+ from doorman import doorman
+
+ limited = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await limited.post(
+ '/platform/authorization', json={'email': f'{uname}@example.com', 'password': pwd}
+ )
+ assert r.status_code == 200
+ gname = f'g_{int(time.time())}'
+ c403 = await limited.post(
+ '/platform/group', json={'group_name': gname, 'group_description': 'x'}
+ )
+ assert c403.status_code == 403
+
+ # Role with manage_groups
+ rname = f'grpmgr_{int(time.time())}'
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': rname, 'manage_groups': True}
+ )
+ assert cr.status_code in (200, 201)
+ uname2 = f'grp_mgr_user_{int(time.time())}'
+ cu2 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname2,
+ 'email': f'{uname2}@example.com',
+ 'password': pwd,
+ 'role': rname,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu2.status_code in (200, 201)
+ mgr = AsyncClient(app=doorman, base_url='http://testserver')
+ r2 = await mgr.post(
+ '/platform/authorization', json={'email': f'{uname2}@example.com', 'password': pwd}
+ )
+ assert r2.status_code == 200
+
+ # Create
+ c = await mgr.post('/platform/group', json={'group_name': gname, 'group_description': 'x'})
+ assert c.status_code in (200, 201)
+ # Get
+ g = await mgr.get(f'/platform/group/{gname}')
+ assert g.status_code == 200
+ # Update
+ u = await mgr.put(f'/platform/group/{gname}', json={'group_description': 'y'})
+ assert u.status_code == 200
+ # Delete
+ d = await mgr.delete(f'/platform/group/{gname}')
+ assert d.status_code == 200
diff --git a/backend-services/tests/test_group_role_not_found.py b/backend-services/tests/test_group_role_not_found.py
index 33523b5..dd77f0b 100644
--- a/backend-services/tests/test_group_role_not_found.py
+++ b/backend-services/tests/test_group_role_not_found.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_group_and_role_not_found(authed_client):
gg = await authed_client.get('/platform/group/not-a-group')
@@ -13,4 +14,3 @@ async def test_group_and_role_not_found(authed_client):
dr = await authed_client.delete('/platform/role/not-a-role')
assert dr.status_code in (400, 404)
-
diff --git a/backend-services/tests/test_grpc_allowlist.py b/backend-services/tests/test_grpc_allowlist.py
index a748c2c..b2211f3 100644
--- a/backend-services/tests/test_grpc_allowlist.py
+++ b/backend-services/tests/test_grpc_allowlist.py
@@ -1,6 +1,9 @@
import pytest
-async def _setup_api_with_allowlist(client, name, ver, allowed_pkgs=None, allowed_svcs=None, allowed_methods=None):
+
+async def _setup_api_with_allowlist(
+ client, name, ver, allowed_pkgs=None, allowed_svcs=None, allowed_methods=None
+):
payload = {
'api_name': name,
'api_version': ver,
@@ -19,17 +22,22 @@ async def _setup_api_with_allowlist(client, name, ver, allowed_pkgs=None, allowe
payload['api_grpc_allowed_methods'] = allowed_methods
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201), r.text
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201), r2.text
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
async def test_grpc_service_not_in_allowlist_returns_403(authed_client):
name, ver = 'gallow1', 'v1'
@@ -43,6 +51,7 @@ async def test_grpc_service_not_in_allowlist_returns_403(authed_client):
body = r.json()
assert body.get('error_code') == 'GTW013'
+
@pytest.mark.asyncio
async def test_grpc_method_not_in_allowlist_returns_403(authed_client):
name, ver = 'gallow2', 'v1'
@@ -56,6 +65,7 @@ async def test_grpc_method_not_in_allowlist_returns_403(authed_client):
body = r.json()
assert body.get('error_code') == 'GTW013'
+
@pytest.mark.asyncio
async def test_grpc_package_not_in_allowlist_returns_403(authed_client):
name, ver = 'gallow3', 'v1'
@@ -69,6 +79,7 @@ async def test_grpc_package_not_in_allowlist_returns_403(authed_client):
body = r.json()
assert body.get('error_code') == 'GTW013'
+
@pytest.mark.asyncio
async def test_grpc_invalid_traversal_rejected_400(authed_client):
name, ver = 'gallow4', 'v1'
@@ -81,4 +92,3 @@ async def test_grpc_invalid_traversal_rejected_400(authed_client):
assert r.status_code == 400
body = r.json()
assert body.get('error_code') == 'GTW011'
-
diff --git a/backend-services/tests/test_grpc_client_and_bidi_streaming.py b/backend-services/tests/test_grpc_client_and_bidi_streaming.py
index 443e90a..c34ea43 100644
--- a/backend-services/tests/test_grpc_client_and_bidi_streaming.py
+++ b/backend-services/tests/test_grpc_client_and_bidi_streaming.py
@@ -1,5 +1,6 @@
import pytest
+
async def _setup_api(client, name, ver):
payload = {
'api_name': name,
@@ -13,41 +14,58 @@ async def _setup_api(client, name, ver):
}
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
def _fake_import(grpc_module_name: str):
def _imp(n):
if n.endswith('_pb2'):
mod = type('PB2', (), {})
- setattr(mod, 'MRequest', type('Req', (), {}))
+ mod.MRequest = type('Req', (), {})
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
- def __init__(self, ok=True): self.ok = ok
+
+ def __init__(self, ok=True):
+ self.ok = ok
+
@staticmethod
- def FromString(b): return Reply(True)
- setattr(mod, 'MReply', Reply)
+ def FromString(b):
+ return Reply(True)
+
+ mod.MReply = Reply
return mod
if n.endswith('_pb2_grpc'):
+
class Stub:
- def __init__(self, ch): pass
+ def __init__(self, ch):
+ pass
+
return type('SVC', (), {'SvcStub': Stub})
raise ImportError(n)
+
return _imp
+
@pytest.mark.asyncio
async def test_grpc_client_streaming(monkeypatch, authed_client):
name, ver = 'gclstr', 'v1'
await _setup_api(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs'))
class Chan:
@@ -56,25 +74,41 @@ async def test_grpc_client_streaming(monkeypatch, authed_client):
count = 0
async for _ in req_iter:
count += 1
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
ok = True
+
return Reply()
+
return _call
+
class _Aio:
@staticmethod
- def insecure_channel(url): return Chan()
- monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}))
+ def insecure_channel(url):
+ return Chan()
+
+ monkeypatch.setattr(
+ gs,
+ 'grpc',
+ type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}),
+ )
body = {'method': 'Svc.M', 'message': {}, 'stream': 'client', 'messages': [{}, {}, {}]}
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body)
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json=body,
+ )
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_grpc_client_streaming_field_mapping(monkeypatch, authed_client):
name, ver = 'gclmap', 'v1'
await _setup_api(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs'))
class Chan:
@@ -86,28 +120,46 @@ async def test_grpc_client_streaming_field_mapping(monkeypatch, authed_client):
total += int(getattr(req, 'val', 0))
except Exception:
pass
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'sum'})()]})()
- def __init__(self, sum): self.sum = sum
+
+ def __init__(self, sum):
+ self.sum = sum
+
return Reply(total)
+
return _call
+
class _Aio:
@staticmethod
- def insecure_channel(url): return Chan()
- monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}))
+ def insecure_channel(url):
+ return Chan()
+
+ monkeypatch.setattr(
+ gs,
+ 'grpc',
+ type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}),
+ )
msgs = [{'val': 1}, {'val': 2}, {'val': 3}]
body = {'method': 'Svc.M', 'message': {}, 'stream': 'client', 'messages': msgs}
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body)
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json=body,
+ )
assert r.status_code == 200
data = r.json().get('response') or r.json()
assert int(data.get('sum', 0)) == 6
+
@pytest.mark.asyncio
async def test_grpc_bidi_streaming(monkeypatch, authed_client):
name, ver = 'gbidi', 'v1'
await _setup_api(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs'))
class Chan:
@@ -116,25 +168,46 @@ async def test_grpc_bidi_streaming(monkeypatch, authed_client):
class Msg:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
ok = True
+
async for _ in req_iter:
yield Msg()
+
return _call
+
class _Aio:
@staticmethod
- def insecure_channel(url): return Chan()
- monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}))
+ def insecure_channel(url):
+ return Chan()
- body = {'method': 'Svc.M', 'message': {}, 'stream': 'bidi', 'messages': [{}, {}, {}], 'max_items': 2}
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body)
+ monkeypatch.setattr(
+ gs,
+ 'grpc',
+ type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}),
+ )
+
+ body = {
+ 'method': 'Svc.M',
+ 'message': {},
+ 'stream': 'bidi',
+ 'messages': [{}, {}, {}],
+ 'max_items': 2,
+ }
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json=body,
+ )
assert r.status_code == 200
data = r.json().get('response') or r.json()
assert isinstance(data.get('items'), list) and len(data['items']) == 2
+
@pytest.mark.asyncio
async def test_grpc_bidi_streaming_field_echo(monkeypatch, authed_client):
name, ver = 'gbidimap', 'v1'
await _setup_api(authed_client, name, ver)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.importlib, 'import_module', _fake_import('gs'))
class Chan:
@@ -143,18 +216,32 @@ async def test_grpc_bidi_streaming_field_echo(monkeypatch, authed_client):
class Msg:
def __init__(self, v):
self.val = v
+
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'val'})()]})()
+
async for req in req_iter:
yield Msg(getattr(req, 'val', None))
+
return _call
+
class _Aio:
@staticmethod
- def insecure_channel(url): return Chan()
- monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}))
+ def insecure_channel(url):
+ return Chan()
+
+ monkeypatch.setattr(
+ gs,
+ 'grpc',
+ type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}),
+ )
msgs = [{'val': 7}, {'val': 8}]
body = {'method': 'Svc.M', 'message': {}, 'stream': 'bidi', 'messages': msgs, 'max_items': 10}
- r = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body)
+ r = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json=body,
+ )
assert r.status_code == 200
data = r.json().get('response') or r.json()
vals = [it.get('val') for it in (data.get('items') or [])]
diff --git a/backend-services/tests/test_grpc_errors_and_retries.py b/backend-services/tests/test_grpc_errors_and_retries.py
index 77c9a02..36a661e 100644
--- a/backend-services/tests/test_grpc_errors_and_retries.py
+++ b/backend-services/tests/test_grpc_errors_and_retries.py
@@ -1,5 +1,6 @@
import pytest
+
async def _setup_api(client, name, ver, retry=0):
payload = {
'api_name': name,
@@ -13,31 +14,41 @@ async def _setup_api(client, name, ver, retry=0):
}
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
def _fake_pb2_module(method_name='M'):
class Req:
pass
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
+
def __init__(self, ok=True):
self.ok = ok
+
@staticmethod
def FromString(b):
return Reply(True)
- setattr(Req, '__name__', f'{method_name}Request')
- setattr(Reply, '__name__', f'{method_name}Reply')
+
+ Req.__name__ = f'{method_name}Request'
+ Reply.__name__ = f'{method_name}Reply'
return Req, Reply
+
def _make_import_module_recorder(record, pb2_map):
def _imp(name):
record.append(name)
@@ -46,27 +57,42 @@ def _make_import_module_recorder(record, pb2_map):
mapping = pb2_map.get(name)
if mapping is None:
req_cls, rep_cls = _fake_pb2_module('M')
- setattr(mod, 'MRequest', req_cls)
- setattr(mod, 'MReply', rep_cls)
+ mod.MRequest = req_cls
+ mod.MReply = rep_cls
else:
req_cls, rep_cls = mapping
if req_cls:
- setattr(mod, 'MRequest', req_cls)
+ mod.MRequest = req_cls
if rep_cls:
- setattr(mod, 'MReply', rep_cls)
+ mod.MReply = rep_cls
return mod
if name.endswith('_pb2_grpc'):
+
class Stub:
def __init__(self, ch):
self._ch = ch
+
async def M(self, req):
- return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
+ return type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type(
+ 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]}
+ )(),
+ 'ok': True,
+ },
+ )()
+
return type('SVC', (), {'SvcStub': Stub})
raise ImportError(name)
+
return _imp
+
def _make_fake_grpc_unary(sequence_codes, grpc_mod):
counter = {'i': 0}
+
class Chan:
def unary_unary(self, method, request_serializer=None, response_deserializer=None):
async def _call(req, metadata=None):
@@ -74,29 +100,50 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod):
code = sequence_codes[idx]
counter['i'] += 1
if code is None:
- return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
+ return type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type(
+ 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]}
+ )(),
+ 'ok': True,
+ },
+ )()
+
class E(grpc_mod.RpcError):
def code(self):
return code
+
def details(self):
return f'{code.name}'
+
raise E()
+
return _call
+
class aio:
@staticmethod
def insecure_channel(url):
return Chan()
+
return type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception})
+
@pytest.mark.asyncio
async def test_grpc_status_mappings_basic(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gmap', 'v1'
await _setup_api(authed_client, name, ver, retry=0)
rec = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}))
+ monkeypatch.setattr(
+ gs.importlib,
+ 'import_module',
+ _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}),
+ )
cases = [
(gs.grpc.StatusCode.UNAUTHENTICATED, 401),
@@ -110,89 +157,142 @@ async def test_grpc_status_mappings_basic(monkeypatch, authed_client):
fake = _make_fake_grpc_unary([code], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == expect
+
@pytest.mark.asyncio
async def test_grpc_unavailable_with_retry_still_fails_maps_503(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gunav', 'v1'
await _setup_api(authed_client, name, ver, retry=2)
rec = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}))
- fake = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNAVAILABLE], gs.grpc)
+ monkeypatch.setattr(
+ gs.importlib,
+ 'import_module',
+ _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}),
+ )
+ fake = _make_fake_grpc_unary(
+ [
+ gs.grpc.StatusCode.UNAVAILABLE,
+ gs.grpc.StatusCode.UNAVAILABLE,
+ gs.grpc.StatusCode.UNAVAILABLE,
+ ],
+ gs.grpc,
+ )
monkeypatch.setattr(gs, 'grpc', fake)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 503
+
@pytest.mark.asyncio
async def test_grpc_alt_method_fallback_succeeds(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'galt', 'v1'
await _setup_api(authed_client, name, ver, retry=0)
rec = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}))
+ monkeypatch.setattr(
+ gs.importlib,
+ 'import_module',
+ _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}),
+ )
fake = _make_fake_grpc_unary([gs.grpc.StatusCode.ABORTED, None], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_grpc_non_retryable_error_returns_500_no_retry(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gnr', 'v1'
await _setup_api(authed_client, name, ver, retry=2)
rec = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}))
- fake = _make_fake_grpc_unary([gs.grpc.StatusCode.INVALID_ARGUMENT, gs.grpc.StatusCode.INVALID_ARGUMENT], gs.grpc)
+ monkeypatch.setattr(
+ gs.importlib,
+ 'import_module',
+ _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}),
+ )
+ fake = _make_fake_grpc_unary(
+ [gs.grpc.StatusCode.INVALID_ARGUMENT, gs.grpc.StatusCode.INVALID_ARGUMENT], gs.grpc
+ )
monkeypatch.setattr(gs, 'grpc', fake)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 400
assert r.json().get('error_code') == 'GTW006'
+
@pytest.mark.asyncio
async def test_grpc_deadline_exceeded_maps_to_504(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gdl', 'v1'
await _setup_api(authed_client, name, ver, retry=1)
rec = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}))
+ monkeypatch.setattr(
+ gs.importlib,
+ 'import_module',
+ _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}),
+ )
fake = _make_fake_grpc_unary([gs.grpc.StatusCode.DEADLINE_EXCEEDED], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 504
assert r.json().get('error_code') == 'GTW006'
+
@pytest.mark.asyncio
async def test_grpc_unavailable_then_unimplemented_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gretry', 'v1'
await _setup_api(authed_client, name, ver, retry=3)
rec = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}))
- fake = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc)
+ monkeypatch.setattr(
+ gs.importlib,
+ 'import_module',
+ _make_import_module_recorder(rec, {default_pkg: (req_cls, rep_cls)}),
+ )
+ fake = _make_fake_grpc_unary(
+ [gs.grpc.StatusCode.UNAVAILABLE, gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc
+ )
monkeypatch.setattr(gs, 'grpc', fake)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 200
diff --git a/backend-services/tests/test_grpc_package_resolution_and_errors.py b/backend-services/tests/test_grpc_package_resolution_and_errors.py
index c53db1b..5192032 100644
--- a/backend-services/tests/test_grpc_package_resolution_and_errors.py
+++ b/backend-services/tests/test_grpc_package_resolution_and_errors.py
@@ -1,5 +1,6 @@
import pytest
+
async def _setup_api(client, name, ver, retry=0, api_pkg=None):
payload = {
'api_name': name,
@@ -15,31 +16,41 @@ async def _setup_api(client, name, ver, retry=0, api_pkg=None):
payload['api_grpc_package'] = api_pkg
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
def _fake_pb2_module(method_name='M'):
class Req:
pass
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
+
def __init__(self, ok=True):
self.ok = ok
+
@staticmethod
def FromString(b):
return Reply(True)
- setattr(Req, '__name__', f'{method_name}Request')
- setattr(Reply, '__name__', f'{method_name}Reply')
+
+ Req.__name__ = f'{method_name}Request'
+ Reply.__name__ = f'{method_name}Reply'
return Req, Reply
+
def _make_import_module_recorder(record, pb2_map):
def _imp(name):
record.append(name)
@@ -48,32 +59,47 @@ def _make_import_module_recorder(record, pb2_map):
mapping = pb2_map.get(name)
if mapping is None:
req_cls, rep_cls = _fake_pb2_module('M')
- setattr(mod, 'MRequest', req_cls)
- setattr(mod, 'MReply', rep_cls)
+ mod.MRequest = req_cls
+ mod.MReply = rep_cls
else:
req_cls, rep_cls = mapping
if req_cls:
- setattr(mod, 'MRequest', req_cls)
+ mod.MRequest = req_cls
if rep_cls:
- setattr(mod, 'MReply', rep_cls)
+ mod.MReply = rep_cls
return mod
if name.endswith('_pb2_grpc'):
# service module with Stub class
class Stub:
def __init__(self, ch):
self._ch = ch
+
async def M(self, req):
- return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
+ return type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type(
+ 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]}
+ )(),
+ 'ok': True,
+ },
+ )()
+
mod = type('SVC', (), {'SvcStub': Stub})
return mod
raise ImportError(name)
+
return _imp
+
def _make_fake_grpc_unary(sequence_codes, grpc_mod):
counter = {'i': 0}
+
class AioChan:
async def channel_ready(self):
return True
+
class Chan(AioChan):
def unary_unary(self, method, request_serializer=None, response_deserializer=None):
async def _call(req):
@@ -81,166 +107,231 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod):
code = sequence_codes[idx]
counter['i'] += 1
if code is None:
- return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
+ return type(
+ 'R',
+ (),
+ {
+ 'DESCRIPTOR': type(
+ 'D', (), {'fields': [type('F', (), {'name': 'ok'})()]}
+ )(),
+ 'ok': True,
+ },
+ )()
+
# Raise RpcError-like
class E(Exception):
def code(self):
return code
+
def details(self):
return 'err'
+
raise E()
+
return _call
+
class aio:
@staticmethod
def insecure_channel(url):
return Chan()
+
fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception})
return fake
+
@pytest.mark.asyncio
async def test_grpc_uses_api_grpc_package_over_request(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gpack1', 'v1'
await _setup_api(authed_client, name, ver, api_pkg='api.pkg')
record = []
req_cls, rep_cls = _fake_pb2_module('M')
- pb2_map = { 'api.pkg_pb2': (req_cls, rep_cls) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {'api.pkg_pb2': (req_cls, rep_cls)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'},
)
assert r.status_code == 200
assert any(n == 'api.pkg_pb2' for n in record)
+
@pytest.mark.asyncio
async def test_grpc_uses_request_package_when_no_api_package(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gpack2', 'v1'
await _setup_api(authed_client, name, ver, api_pkg=None)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
- pb2_map = { 'req.pkg_pb2': (req_cls, rep_cls) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {'req.pkg_pb2': (req_cls, rep_cls)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'},
)
assert r.status_code == 200
assert any(n == 'req.pkg_pb2' for n in record)
+
@pytest.mark.asyncio
async def test_grpc_uses_default_package_when_no_overrides(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gpack3', 'v1'
await _setup_api(authed_client, name, ver, api_pkg=None)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- pb2_map = { default_pkg: (req_cls, rep_cls) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {default_pkg: (req_cls, rep_cls)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 200
assert any(n.endswith(default_pkg) for n in record)
+
@pytest.mark.asyncio
async def test_grpc_unavailable_then_success_with_retry(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gunavail', 'v1'
await _setup_api(authed_client, name, ver, retry=1)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- pb2_map = { default_pkg: (req_cls, rep_cls) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {default_pkg: (req_cls, rep_cls)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, None], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_grpc_unimplemented_then_success_with_retry(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gunimpl', 'v1'
await _setup_api(authed_client, name, ver, retry=1)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- pb2_map = { default_pkg: (req_cls, rep_cls) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {default_pkg: (req_cls, rep_cls)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_grpc_not_found_maps_to_500_error(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gnotfound', 'v1'
await _setup_api(authed_client, name, ver)
record = []
- pb2_map = { f'{name}_{ver}_pb2': (None, None) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {f'{name}_{ver}_pb2': (None, None)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 500
body = r.json()
assert body.get('error_code') == 'GTW006'
+
@pytest.mark.asyncio
async def test_grpc_unknown_maps_to_500_error(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gunk', 'v1'
await _setup_api(authed_client, name, ver)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
- pb2_map = { default_pkg: (req_cls, rep_cls) }
- monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
+ pb2_map = {default_pkg: (req_cls, rep_cls)}
+ monkeypatch.setattr(
+ gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map)
+ )
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNKNOWN], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 500
+
@pytest.mark.asyncio
async def test_grpc_rejects_traversal_in_package(authed_client):
name, ver = 'gtrv', 'v1'
await _setup_api(authed_client, name, ver)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
- json={'method': 'Svc.M', 'message': {}, 'package': '../evil'}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}, 'package': '../evil'},
)
assert r.status_code == 400
body = r.json()
assert body.get('error_code') == 'GTW011'
+
@pytest.mark.asyncio
async def test_grpc_proto_missing_returns_404_gtw012(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gproto404', 'v1'
await _setup_api(authed_client, name, ver)
+
# Make on-demand proto generation fail by raising on import grpc_tools
def _imp_fail(name):
if name.startswith('grpc_tools'):
raise ImportError('no tools')
raise ImportError(name)
+
monkeypatch.setattr(gs.importlib, 'import_module', _imp_fail)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 404
body = r.json()
diff --git a/backend-services/tests/test_grpc_streaming_and_metadata.py b/backend-services/tests/test_grpc_streaming_and_metadata.py
index 43fe266..8d60333 100644
--- a/backend-services/tests/test_grpc_streaming_and_metadata.py
+++ b/backend-services/tests/test_grpc_streaming_and_metadata.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_grpc_server_streaming(monkeypatch, authed_client):
name, ver = 'gstr', 'v1'
@@ -15,35 +16,50 @@ async def test_grpc_server_streaming(monkeypatch, authed_client):
}
r = await authed_client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
# Fake pb2 modules
def _imp(n):
if n.endswith('_pb2'):
mod = type('PB2', (), {})
- setattr(mod, 'MRequest', type('Req', (), {}))
+ mod.MRequest = type('Req', (), {})
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
- def __init__(self, ok=True): self.ok = ok
+
+ def __init__(self, ok=True):
+ self.ok = ok
+
@staticmethod
- def FromString(b): return Reply(True)
- setattr(mod, 'MReply', Reply)
+ def FromString(b):
+ return Reply(True)
+
+ mod.MReply = Reply
return mod
if n.endswith('_pb2_grpc'):
+
class Stub:
- def __init__(self, ch): pass
+ def __init__(self, ch):
+ pass
+
return type('SVC', (), {'SvcStub': Stub})
raise ImportError(n)
+
monkeypatch.setattr(gs.importlib, 'import_module', _imp)
class Chan:
@@ -52,20 +68,34 @@ async def test_grpc_server_streaming(monkeypatch, authed_client):
class Msg:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
ok = True
+
for _ in range(2):
yield Msg()
+
return _aiter
+
class _Aio:
@staticmethod
- def insecure_channel(url): return Chan()
- monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}))
+ def insecure_channel(url):
+ return Chan()
+
+ monkeypatch.setattr(
+ gs,
+ 'grpc',
+ type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}),
+ )
body = {'method': 'Svc.M', 'message': {}, 'stream': 'server', 'max_items': 2}
- resp = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json=body)
+ resp = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json=body,
+ )
assert resp.status_code == 200
data = resp.json().get('response') or resp.json()
assert isinstance(data.get('items'), list) and len(data['items']) == 2
+
@pytest.mark.asyncio
async def test_grpc_metadata_pass_through(monkeypatch, authed_client):
name, ver = 'gmeta', 'v1'
@@ -82,51 +112,77 @@ async def test_grpc_metadata_pass_through(monkeypatch, authed_client):
}
r = await authed_client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
def _imp(n):
if n.endswith('_pb2'):
mod = type('PB2', (), {})
- setattr(mod, 'MRequest', type('Req', (), {}))
+ mod.MRequest = type('Req', (), {})
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
+
@staticmethod
- def FromString(b): return Reply()
- setattr(mod, 'MReply', Reply)
+ def FromString(b):
+ return Reply()
+
+ mod.MReply = Reply
return mod
if n.endswith('_pb2_grpc'):
+
class Stub:
- def __init__(self, ch): pass
+ def __init__(self, ch):
+ pass
+
return type('SVC', (), {'SvcStub': Stub})
raise ImportError(n)
+
monkeypatch.setattr(gs.importlib, 'import_module', _imp)
captured = {'md': None}
+
class Chan:
def unary_unary(self, method, request_serializer=None, response_deserializer=None):
async def _call(req, metadata=None):
captured['md'] = list(metadata or [])
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
ok = True
+
return Reply()
+
return _call
+
class _Aio:
@staticmethod
- def insecure_channel(url): return Chan()
- monkeypatch.setattr(gs, 'grpc', type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}))
+ def insecure_channel(url):
+ return Chan()
+
+ monkeypatch.setattr(
+ gs,
+ 'grpc',
+ type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}),
+ )
headers = {'X-API-Version': ver, 'Content-Type': 'application/json', 'X-Meta-One': 'alpha'}
- r = await authed_client.post(f'/api/grpc/{name}', headers=headers, json={'method': 'Svc.M', 'message': {}})
+ r = await authed_client.post(
+ f'/api/grpc/{name}', headers=headers, json={'method': 'Svc.M', 'message': {}}
+ )
assert r.status_code == 200
assert ('x-meta-one', 'alpha') in [(k.lower(), v) for k, v in (captured['md'] or [])]
diff --git a/backend-services/tests/test_grpc_subscription_and_metrics.py b/backend-services/tests/test_grpc_subscription_and_metrics.py
index c101fbf..ed87c46 100644
--- a/backend-services/tests/test_grpc_subscription_and_metrics.py
+++ b/backend-services/tests/test_grpc_subscription_and_metrics.py
@@ -1,6 +1,8 @@
import json
+
import pytest
+
async def _setup_api(client, name, ver, public=False):
payload = {
'api_name': name,
@@ -15,49 +17,65 @@ async def _setup_api(client, name, ver, public=False):
}
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
+
@pytest.mark.asyncio
async def test_grpc_requires_subscription_when_not_public(monkeypatch, authed_client):
name, ver = 'gsub', 'v1'
await _setup_api(authed_client, name, ver, public=False)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code == 403
+
@pytest.mark.asyncio
async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client):
name, ver = 'gmet', 'v1'
await _setup_api(authed_client, name, ver, public=False)
from conftest import subscribe_self
+
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
def _imp(name):
if name.endswith('_pb2'):
mod = type('PB2', (), {})
- setattr(mod, 'MRequest', type('Req', (), {}) )
+ mod.MRequest = type('Req', (), {})
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
+
@staticmethod
def FromString(b):
return Reply()
- setattr(mod, 'MReply', Reply)
+
+ mod.MReply = Reply
return mod
if name.endswith('_pb2_grpc'):
+
class Stub:
- def __init__(self, ch): pass
+ def __init__(self, ch):
+ pass
+
return type('SVC', (), {'SvcStub': Stub})
raise ImportError(name)
+
monkeypatch.setattr(gs.importlib, 'import_module', _imp)
class Chan:
@@ -67,13 +85,19 @@ async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client):
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
ok = True
+
return Reply()
+
return _call
+
class _Aio:
@staticmethod
def insecure_channel(url):
return Chan()
- fake_grpc = type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})
+
+ fake_grpc = type(
+ 'G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}
+ )
monkeypatch.setattr(gs, 'grpc', fake_grpc)
m0 = await authed_client.get('/platform/monitor/metrics')
@@ -83,7 +107,11 @@ async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client):
body_obj = {'method': 'Svc.M', 'message': {}}
raw = json.dumps(body_obj)
- headers = {'Content-Type': 'application/json', 'X-API-Version': ver, 'Content-Length': str(len(raw))}
+ headers = {
+ 'Content-Type': 'application/json',
+ 'X-API-Version': ver,
+ 'Content-Length': str(len(raw)),
+ }
r = await authed_client.post(f'/api/grpc/{name}', headers=headers, content=raw)
assert r.status_code in (200, 500, 501, 503)
@@ -93,4 +121,3 @@ async def test_grpc_metrics_bytes_in_out(monkeypatch, authed_client):
tout1 = int(j1.get('total_bytes_out', 0))
assert tin1 - tin0 >= len(raw)
assert tout1 >= tout0
-
diff --git a/backend-services/tests/test_grpcs_tls_and_deadlines.py b/backend-services/tests/test_grpcs_tls_and_deadlines.py
index ed1cfd5..38616c8 100644
--- a/backend-services/tests/test_grpcs_tls_and_deadlines.py
+++ b/backend-services/tests/test_grpcs_tls_and_deadlines.py
@@ -1,5 +1,6 @@
import pytest
+
async def _setup_api(client, name, ver, url):
payload = {
'api_name': name,
@@ -13,20 +14,26 @@ async def _setup_api(client, name, ver, url):
}
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
async def test_grpcs_misconfig_returns_500(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'gtls', 'v1'
await _setup_api(authed_client, name, ver, url='grpcs://example:50051')
@@ -34,19 +41,26 @@ async def test_grpcs_misconfig_returns_500(monkeypatch, authed_client):
def _imp(name):
if name.endswith('_pb2'):
mod = type('PB2', (), {})
- setattr(mod, 'MRequest', type('Req', (), {}) )
+ mod.MRequest = type('Req', (), {})
+
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
+
@staticmethod
def FromString(b):
return Reply()
- setattr(mod, 'MReply', Reply)
+
+ mod.MReply = Reply
return mod
if name.endswith('_pb2_grpc'):
+
class Stub:
- def __init__(self, ch): pass
+ def __init__(self, ch):
+ pass
+
return type('SVC', (), {'SvcStub': Stub})
raise ImportError(name)
+
monkeypatch.setattr(gs.importlib, 'import_module', _imp)
# Simulate insecure_channel rejecting grpcs:// URL
@@ -56,11 +70,15 @@ async def test_grpcs_misconfig_returns_500(monkeypatch, authed_client):
if str(url).startswith('grpcs://'):
raise RuntimeError('TLS required')
return object()
- fake_grpc = type('G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception})
+
+ fake_grpc = type(
+ 'G', (), {'aio': _Aio, 'StatusCode': gs.grpc.StatusCode, 'RpcError': Exception}
+ )
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
- f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
)
assert r.status_code in (500, 501, 503)
-
diff --git a/backend-services/tests/test_health_status.py b/backend-services/tests/test_health_status.py
index a14605e..2e477eb 100644
--- a/backend-services/tests/test_health_status.py
+++ b/backend-services/tests/test_health_status.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_public_health_probe_ok(client):
r = await client.get('/api/health')
@@ -7,6 +8,7 @@ async def test_public_health_probe_ok(client):
body = r.json().get('response', r.json())
assert body.get('status') in ('online', 'healthy', 'ready')
+
@pytest.mark.asyncio
async def test_status_requires_auth(client):
try:
diff --git a/backend-services/tests/test_http_circuit_breaker.py b/backend-services/tests/test_http_circuit_breaker.py
index b18d129..735e315 100644
--- a/backend-services/tests/test_http_circuit_breaker.py
+++ b/backend-services/tests/test_http_circuit_breaker.py
@@ -1,15 +1,16 @@
import asyncio
-import os
-from typing import Callable
+from collections.abc import Callable
import httpx
import pytest
-from utils.http_client import request_with_resilience, circuit_manager, CircuitOpenError
+from utils.http_client import CircuitOpenError, circuit_manager, request_with_resilience
+
def _mock_transport(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.MockTransport:
return httpx.MockTransport(lambda req: handler(req))
+
@pytest.mark.asyncio
async def test_retries_on_503_then_success(monkeypatch):
calls = {'n': 0}
@@ -27,17 +28,23 @@ async def test_retries_on_503_then_success(monkeypatch):
monkeypatch.setenv('CIRCUIT_BREAKER_THRESHOLD', '5')
resp = await request_with_resilience(
- client, 'GET', 'http://upstream.test/ok',
- api_key='test-api/v1', retries=2, api_config=None,
+ client,
+ 'GET',
+ 'http://upstream.test/ok',
+ api_key='test-api/v1',
+ retries=2,
+ api_config=None,
)
assert resp.status_code == 200
assert resp.json() == {'ok': True}
assert calls['n'] == 3
+
@pytest.mark.asyncio
async def test_circuit_opens_after_failures_and_half_open(monkeypatch):
calls = {'n': 0}
+
# Always return 503
def handler(req: httpx.Request) -> httpx.Response:
calls['n'] += 1
@@ -53,15 +60,23 @@ async def test_circuit_opens_after_failures_and_half_open(monkeypatch):
api_key = 'breaker-api/v1'
circuit_manager._states.clear()
- resp = await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=1)
+ resp = await request_with_resilience(
+ client, 'GET', 'http://u.test/err', api_key=api_key, retries=1
+ )
assert resp.status_code == 503
with pytest.raises(CircuitOpenError):
- await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=0)
+ await request_with_resilience(
+ client, 'GET', 'http://u.test/err', api_key=api_key, retries=0
+ )
await asyncio.sleep(0.11)
- resp2 = await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=0)
+ resp2 = await request_with_resilience(
+ client, 'GET', 'http://u.test/err', api_key=api_key, retries=0
+ )
assert resp2.status_code == 503
with pytest.raises(CircuitOpenError):
- await request_with_resilience(client, 'GET', 'http://u.test/err', api_key=api_key, retries=0)
+ await request_with_resilience(
+ client, 'GET', 'http://u.test/err', api_key=api_key, retries=0
+ )
diff --git a/backend-services/tests/test_ip_filter_platform.py b/backend-services/tests/test_ip_filter_platform.py
index 76ea34e..be707d4 100644
--- a/backend-services/tests/test_ip_filter_platform.py
+++ b/backend-services/tests/test_ip_filter_platform.py
@@ -1,17 +1,23 @@
import json
+
import pytest
+
async def _ensure_manage_security(authed_client):
await authed_client.put('/platform/role/admin', json={'manage_security': True})
+
async def _update_security(authed_client, settings: dict, headers: dict | None = None):
await _ensure_manage_security(authed_client)
r = await authed_client.put('/platform/security/settings', json=settings, headers=headers or {})
assert r.status_code == 200, r.text
return r
+
@pytest.mark.asyncio
-async def test_global_whitelist_blocks_non_whitelisted_with_trusted_proxy(monkeypatch, authed_client, client):
+async def test_global_whitelist_blocks_non_whitelisted_with_trusted_proxy(
+ monkeypatch, authed_client, client
+):
await _update_security(
authed_client,
settings={
@@ -24,17 +30,25 @@ async def test_global_whitelist_blocks_non_whitelisted_with_trusted_proxy(monkey
)
try:
- r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'})
+ r = await client.get(
+ '/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'}
+ )
assert r.status_code == 403
body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
assert (body.get('error_code') or body.get('response', {}).get('error_code')) == 'SEC010'
finally:
await _update_security(
authed_client,
- settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []},
+ settings={
+ 'ip_whitelist': [],
+ 'ip_blacklist': [],
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ },
headers={'X-Forwarded-For': '198.51.100.10'},
)
+
@pytest.mark.asyncio
async def test_global_blacklist_blocks_with_trusted_proxy(monkeypatch, authed_client, client):
await _update_security(
@@ -49,17 +63,25 @@ async def test_global_blacklist_blocks_with_trusted_proxy(monkeypatch, authed_cl
)
try:
- r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'})
+ r = await client.get(
+ '/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'}
+ )
assert r.status_code == 403
body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
assert (body.get('error_code') or body.get('response', {}).get('error_code')) == 'SEC011'
finally:
await _update_security(
authed_client,
- settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []},
+ settings={
+ 'ip_whitelist': [],
+ 'ip_blacklist': [],
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ },
headers={'X-Forwarded-For': '198.51.100.10'},
)
+
@pytest.mark.asyncio
async def test_xff_ignored_when_proxy_not_trusted(monkeypatch, authed_client, client):
await _update_security(
@@ -74,7 +96,9 @@ async def test_xff_ignored_when_proxy_not_trusted(monkeypatch, authed_client, cl
)
try:
- r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '198.51.100.10'})
+ r = await client.get(
+ '/platform/monitor/liveness', headers={'X-Forwarded-For': '198.51.100.10'}
+ )
assert r.status_code == 403
body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
assert (body.get('error_code') or body.get('response', {}).get('error_code')) == 'SEC010'
@@ -83,13 +107,22 @@ async def test_xff_ignored_when_proxy_not_trusted(monkeypatch, authed_client, cl
try:
await _update_security(
authed_client,
- settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': [], 'allow_localhost_bypass': False},
+ settings={
+ 'ip_whitelist': [],
+ 'ip_blacklist': [],
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ 'allow_localhost_bypass': False,
+ },
)
finally:
monkeypatch.delenv('LOCAL_HOST_IP_BYPASS', raising=False)
+
@pytest.mark.asyncio
-async def test_localhost_bypass_enabled_allows_without_forwarding_headers(monkeypatch, authed_client, client):
+async def test_localhost_bypass_enabled_allows_without_forwarding_headers(
+ monkeypatch, authed_client, client
+):
await _update_security(
authed_client,
settings={
@@ -106,11 +139,20 @@ async def test_localhost_bypass_enabled_allows_without_forwarding_headers(monkey
finally:
await _update_security(
authed_client,
- settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': [], 'allow_localhost_bypass': False},
+ settings={
+ 'ip_whitelist': [],
+ 'ip_blacklist': [],
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ 'allow_localhost_bypass': False,
+ },
)
+
@pytest.mark.asyncio
-async def test_localhost_bypass_disabled_blocks_without_forwarding_headers(monkeypatch, authed_client, client):
+async def test_localhost_bypass_disabled_blocks_without_forwarding_headers(
+ monkeypatch, authed_client, client
+):
monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false')
await _update_security(
@@ -133,20 +175,29 @@ async def test_localhost_bypass_disabled_blocks_without_forwarding_headers(monke
try:
await _update_security(
authed_client,
- settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []},
+ settings={
+ 'ip_whitelist': [],
+ 'ip_blacklist': [],
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ },
)
finally:
monkeypatch.delenv('LOCAL_HOST_IP_BYPASS', raising=False)
+
class _AuditSpy:
def __init__(self):
self.calls = []
+
def info(self, msg):
self.calls.append(msg)
+
@pytest.mark.asyncio
async def test_audit_logged_on_global_deny(monkeypatch, authed_client, client):
import utils.audit_util as au
+
orig = au._logger
spy = _AuditSpy()
au._logger = spy
@@ -162,7 +213,9 @@ async def test_audit_logged_on_global_deny(monkeypatch, authed_client, client):
},
)
- r = await client.get('/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'})
+ r = await client.get(
+ '/platform/monitor/liveness', headers={'X-Forwarded-For': '203.0.113.10'}
+ )
assert r.status_code == 403
assert any('ip.global_deny' in str(c) for c in spy.calls)
parsed = [json.loads(c) for c in spy.calls if isinstance(c, str)]
@@ -171,6 +224,11 @@ async def test_audit_logged_on_global_deny(monkeypatch, authed_client, client):
au._logger = orig
await _update_security(
authed_client,
- settings={'ip_whitelist': [], 'ip_blacklist': [], 'trust_x_forwarded_for': False, 'xff_trusted_proxies': []},
+ settings={
+ 'ip_whitelist': [],
+ 'ip_blacklist': [],
+ 'trust_x_forwarded_for': False,
+ 'xff_trusted_proxies': [],
+ },
headers={'X-Forwarded-For': '198.51.100.10'},
)
diff --git a/backend-services/tests/test_ip_policy_allow_deny_cidr.py b/backend-services/tests/test_ip_policy_allow_deny_cidr.py
index 9bf0005..8c625c3 100644
--- a/backend-services/tests/test_ip_policy_allow_deny_cidr.py
+++ b/backend-services/tests/test_ip_policy_allow_deny_cidr.py
@@ -1,7 +1,7 @@
import pytest
-
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
async def _setup_api_public(client, name, ver, mode='allow_all', wl=None, bl=None):
payload = {
'api_name': name,
@@ -21,70 +21,96 @@ async def _setup_api_public(client, name, ver, mode='allow_all', wl=None, bl=Non
payload['api_ip_blacklist'] = bl
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/res',
- 'endpoint_description': 'res'
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/res',
+ 'endpoint_description': 'res',
+ },
+ )
assert r2.status_code in (200, 201)
return name, ver
+
@pytest.mark.asyncio
async def test_ip_policy_allows_exact_ip(monkeypatch, authed_client):
import services.gateway_service as gs
- name, ver = await _setup_api_public(authed_client, 'ipok1', 'v1', mode='whitelist', wl=['127.0.0.1'])
+
+ name, ver = await _setup_api_public(
+ authed_client, 'ipok1', 'v1', mode='whitelist', wl=['127.0.0.1']
+ )
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res')
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_ip_policy_denies_exact_ip(monkeypatch, authed_client):
import services.gateway_service as gs
+
monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false')
- name, ver = await _setup_api_public(authed_client, 'ipdeny1', 'v1', mode='allow_all', bl=['127.0.0.1'])
+ name, ver = await _setup_api_public(
+ authed_client, 'ipdeny1', 'v1', mode='allow_all', bl=['127.0.0.1']
+ )
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res')
assert r.status_code == 403
body = r.json()
assert body.get('error_code') == 'API011'
+
@pytest.mark.asyncio
async def test_ip_policy_allows_cidr(monkeypatch, authed_client):
import services.gateway_service as gs
- name, ver = await _setup_api_public(authed_client, 'ipok2', 'v1', mode='whitelist', wl=['127.0.0.0/24'])
+
+ name, ver = await _setup_api_public(
+ authed_client, 'ipok2', 'v1', mode='whitelist', wl=['127.0.0.0/24']
+ )
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res')
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_ip_policy_denies_cidr(monkeypatch, authed_client):
import services.gateway_service as gs
+
monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false')
- name, ver = await _setup_api_public(authed_client, 'ipdeny2', 'v1', mode='allow_all', bl=['127.0.0.0/24'])
+ name, ver = await _setup_api_public(
+ authed_client, 'ipdeny2', 'v1', mode='allow_all', bl=['127.0.0.0/24']
+ )
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res')
assert r.status_code == 403
assert r.json().get('error_code') == 'API011'
+
@pytest.mark.asyncio
async def test_ip_policy_denylist_precedence_over_allowlist(monkeypatch, authed_client):
import services.gateway_service as gs
+
monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false')
- name, ver = await _setup_api_public(authed_client, 'ipdeny3', 'v1', mode='whitelist', wl=['127.0.0.1'], bl=['127.0.0.1'])
+ name, ver = await _setup_api_public(
+ authed_client, 'ipdeny3', 'v1', mode='whitelist', wl=['127.0.0.1'], bl=['127.0.0.1']
+ )
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res')
assert r.status_code == 403
assert r.json().get('error_code') == 'API011'
+
@pytest.mark.asyncio
async def test_ip_policy_enforced_early_returns_http_error(monkeypatch, authed_client):
import services.gateway_service as gs
+
monkeypatch.setenv('LOCAL_HOST_IP_BYPASS', 'false')
- name, ver = await _setup_api_public(authed_client, 'ipdeny4', 'v1', mode='whitelist', wl=['203.0.113.5'])
+ name, ver = await _setup_api_public(
+ authed_client, 'ipdeny4', 'v1', mode='whitelist', wl=['203.0.113.5']
+ )
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/res')
assert r.status_code == 403
assert r.json().get('error_code') == 'API010'
-
diff --git a/backend-services/tests/test_jwt_config.py b/backend-services/tests/test_jwt_config.py
index 6290d62..51e321f 100644
--- a/backend-services/tests/test_jwt_config.py
+++ b/backend-services/tests/test_jwt_config.py
@@ -1,12 +1,11 @@
-import os
-
from utils.auth_util import is_jwt_configured
+
def test_is_jwt_configured_true(monkeypatch):
monkeypatch.setenv('JWT_SECRET_KEY', 'abc123')
assert is_jwt_configured() is True
+
def test_is_jwt_configured_false(monkeypatch):
monkeypatch.delenv('JWT_SECRET_KEY', raising=False)
assert is_jwt_configured() is False
-
diff --git a/backend-services/tests/test_lifespan_failures.py b/backend-services/tests/test_lifespan_failures.py
index f3ec265..2193a57 100644
--- a/backend-services/tests/test_lifespan_failures.py
+++ b/backend-services/tests/test_lifespan_failures.py
@@ -1,30 +1,37 @@
import pytest
+
@pytest.mark.asyncio
async def test_production_guard_causes_startup_failure_direct(monkeypatch):
monkeypatch.setenv('ENV', 'production')
monkeypatch.setenv('HTTPS_ONLY', 'false')
- from doorman import app_lifespan, doorman
import pytest as _pytest
+
+ from doorman import app_lifespan, doorman
+
with _pytest.raises(RuntimeError):
async with app_lifespan(doorman):
pass
+
@pytest.mark.asyncio
async def test_lifespan_failure_raises_with_fresh_app_testclient(monkeypatch):
monkeypatch.setenv('ENV', 'production')
monkeypatch.setenv('HTTPS_ONLY', 'false')
from fastapi import FastAPI
+
from doorman import app_lifespan
+
app = FastAPI(lifespan=app_lifespan)
@app.get('/ping')
async def ping():
return {'ok': True}
- from starlette.testclient import TestClient
import pytest as _pytest
+ from starlette.testclient import TestClient
+
with _pytest.raises(RuntimeError):
with TestClient(app) as client:
client.get('/ping')
diff --git a/backend-services/tests/test_lists_and_paging.py b/backend-services/tests/test_lists_and_paging.py
index bf10205..9d6daa9 100644
--- a/backend-services/tests/test_lists_and_paging.py
+++ b/backend-services/tests/test_lists_and_paging.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_list_endpoints_roles_groups_apis(authed_client):
-
await authed_client.post(
'/platform/api',
json={
@@ -17,12 +17,10 @@ async def test_list_endpoints_roles_groups_apis(authed_client):
},
)
await authed_client.post(
- '/platform/group',
- json={'group_name': 'glist', 'group_description': 'gd', 'api_access': []},
+ '/platform/group', json={'group_name': 'glist', 'group_description': 'gd', 'api_access': []}
)
await authed_client.post(
- '/platform/role',
- json={'role_name': 'rlist', 'role_description': 'rd'},
+ '/platform/role', json={'role_name': 'rlist', 'role_description': 'rd'}
)
ra = await authed_client.get('/platform/api/all?page=1&page_size=5')
@@ -31,4 +29,3 @@ async def test_list_endpoints_roles_groups_apis(authed_client):
assert rg.status_code == 200
rr = await authed_client.get('/platform/role/all?page=1&page_size=5')
assert rr.status_code == 200
-
diff --git a/backend-services/tests/test_logging_permissions.py b/backend-services/tests/test_logging_permissions.py
index ba6e32c..3bfe7f0 100644
--- a/backend-services/tests/test_logging_permissions.py
+++ b/backend-services/tests/test_logging_permissions.py
@@ -1,39 +1,98 @@
+import time
+
import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200, r.text
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
@pytest.mark.asyncio
-async def test_logging_requires_permissions(authed_client):
-
- r = await authed_client.put(
- '/platform/role/admin',
- json={'view_logs': False, 'export_logs': False},
+async def test_logging_routes_permissions(authed_client):
+ # Limited user without view_logs/export_logs
+ uname = f'log_limited_{int(time.time())}'
+ pwd = 'LogLimitStrongPass1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
)
- assert r.status_code in (200, 201)
+ assert cu.status_code in (200, 201)
+ limited = await _login(f'{uname}@example.com', pwd)
+ for ep in (
+ '/platform/logging/logs',
+ '/platform/logging/logs/files',
+ '/platform/logging/logs/statistics',
+ ):
+ r = await limited.get(ep)
+ assert r.status_code == 403
+ exp = await limited.get('/platform/logging/logs/export')
+ assert exp.status_code == 403
- logs = await authed_client.get('/platform/logging/logs?limit=5')
- assert logs.status_code == 403
-
- files = await authed_client.get('/platform/logging/logs/files')
- assert files.status_code == 403
-
- stats = await authed_client.get('/platform/logging/logs/statistics')
- assert stats.status_code == 403
-
- export = await authed_client.get('/platform/logging/logs/export?format=json')
- assert export.status_code == 403
-
- download = await authed_client.get('/platform/logging/logs/download?format=csv')
- assert download.status_code == 403
-
- r2 = await authed_client.put(
- '/platform/role/admin',
- json={'view_logs': True, 'export_logs': True},
+ # Role with view_logs only
+ rname = f'view_logs_{int(time.time())}'
+ cr = await authed_client.post('/platform/role', json={'role_name': rname, 'view_logs': True})
+ assert cr.status_code in (200, 201)
+ vuser = f'log_viewer_{int(time.time())}'
+ cu2 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': vuser,
+ 'email': f'{vuser}@example.com',
+ 'password': pwd,
+ 'role': rname,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
)
- assert r2.status_code in (200, 201)
-
- logs2 = await authed_client.get('/platform/logging/logs?limit=1')
- assert logs2.status_code == 200
- files2 = await authed_client.get('/platform/logging/logs/files')
- assert files2.status_code == 200
- export2 = await authed_client.get('/platform/logging/logs/download?format=json')
- assert export2.status_code == 200
+ assert cu2.status_code in (200, 201)
+ viewer = await _login(f'{vuser}@example.com', pwd)
+ for ep in (
+ '/platform/logging/logs',
+ '/platform/logging/logs/files',
+ '/platform/logging/logs/statistics',
+ ):
+ r = await viewer.get(ep)
+ assert r.status_code == 200
+ # But export still forbidden without export_logs
+ exp2 = await viewer.get('/platform/logging/logs/export')
+ assert exp2.status_code == 403
+ # Role with export_logs
+ rname2 = f'export_logs_{int(time.time())}'
+ cr2 = await authed_client.post(
+ '/platform/role', json={'role_name': rname2, 'export_logs': True}
+ )
+ assert cr2.status_code in (200, 201)
+ euser = f'log_exporter_{int(time.time())}'
+ cu3 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': euser,
+ 'email': f'{euser}@example.com',
+ 'password': pwd,
+ 'role': rname2,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu3.status_code in (200, 201)
+ exporter = await _login(f'{euser}@example.com', pwd)
+ exp3 = await exporter.get('/platform/logging/logs/export')
+ assert exp3.status_code == 200
diff --git a/backend-services/tests/test_logging_redaction.py b/backend-services/tests/test_logging_redaction.py
index a699134..f68741d 100644
--- a/backend-services/tests/test_logging_redaction.py
+++ b/backend-services/tests/test_logging_redaction.py
@@ -1,8 +1,7 @@
import logging
-from types import SimpleNamespace
+
def test_logging_redaction_filters_sensitive_values():
-
logger = logging.getLogger('doorman.gateway')
filt = None
for h in logger.handlers:
@@ -13,8 +12,13 @@ def test_logging_redaction_filters_sensitive_values():
secret = 'supersecretvalue'
record = logging.LogRecord(
- name='doorman.gateway', level=logging.INFO, pathname=__file__, lineno=1,
- msg=f'Authorization: Bearer {secret}; password=\"{secret}\"; access_token=\"{secret}\"', args=(), exc_info=None
+ name='doorman.gateway',
+ level=logging.INFO,
+ pathname=__file__,
+ lineno=1,
+ msg=f'Authorization: Bearer {secret}; password="{secret}"; access_token="{secret}"',
+ args=(),
+ exc_info=None,
)
ok = filt.filter(record)
assert ok is True
diff --git a/backend-services/tests/test_logging_redaction_extended.py b/backend-services/tests/test_logging_redaction_extended.py
index 2939400..c076017 100644
--- a/backend-services/tests/test_logging_redaction_extended.py
+++ b/backend-services/tests/test_logging_redaction_extended.py
@@ -1,5 +1,6 @@
import logging
+
def test_redaction_handles_cookies_csrf_and_mixed_cases():
logger = logging.getLogger('doorman.gateway')
filt = None
@@ -12,15 +13,19 @@ def test_redaction_handles_cookies_csrf_and_mixed_cases():
secret = 'S3cr3t!'
msg = (
f'authorization: Bearer {secret}; Authorization: Bearer {secret}; '
- f'cookie: session={secret}; x-csrf-token: {secret}; PASSWORD=\"{secret}\"'
+ f'cookie: session={secret}; x-csrf-token: {secret}; PASSWORD="{secret}"'
)
rec = logging.LogRecord(
- name='doorman.gateway', level=logging.INFO, pathname=__file__, lineno=1,
- msg=msg, args=(), exc_info=None
+ name='doorman.gateway',
+ level=logging.INFO,
+ pathname=__file__,
+ lineno=1,
+ msg=msg,
+ args=(),
+ exc_info=None,
)
assert filt.filter(rec) is True
out = str(rec.msg)
assert secret not in out
assert out.lower().count('[redacted]') >= 3
-
diff --git a/backend-services/tests/test_logging_redaction_new_patterns.py b/backend-services/tests/test_logging_redaction_new_patterns.py
index 4bcd66c..4d0b204 100644
--- a/backend-services/tests/test_logging_redaction_new_patterns.py
+++ b/backend-services/tests/test_logging_redaction_new_patterns.py
@@ -1,6 +1,7 @@
import logging
from io import StringIO
+
def _capture(logger_name: str, message: str) -> str:
logger = logging.getLogger(logger_name)
stream = StringIO()
@@ -15,16 +16,19 @@ def _capture(logger_name: str, message: str) -> str:
logger.removeHandler(h)
return stream.getvalue()
+
def test_redacts_set_cookie_and_x_api_key():
- msg = 'Set-Cookie: access_token_cookie=abc123; Path=/; HttpOnly; Secure; X-API-Key: my-secret-key'
+ msg = (
+ 'Set-Cookie: access_token_cookie=abc123; Path=/; HttpOnly; Secure; X-API-Key: my-secret-key'
+ )
out = _capture('doorman.gateway', msg)
assert 'Set-Cookie: [REDACTED]' in out or 'set-cookie: [REDACTED]' in out.lower()
assert 'X-API-Key: [REDACTED]' in out or 'x-api-key: [REDACTED]' in out.lower()
+
def test_redacts_bearer_and_basic_tokens():
msg = 'Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhIn0.sgn; authorization: basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='
out = _capture('doorman.gateway', msg)
low = out.lower()
assert 'authorization: [redacted]' in low
assert 'basic [redacted]' in low or 'authorization: [redacted]' in low
-
diff --git a/backend-services/tests/test_login_ip_rate_limit_flow.py b/backend-services/tests/test_login_ip_rate_limit_flow.py
index 2d5406f..8aeba22 100644
--- a/backend-services/tests/test_login_ip_rate_limit_flow.py
+++ b/backend-services/tests/test_login_ip_rate_limit_flow.py
@@ -1,6 +1,8 @@
import os
+
import pytest
+
@pytest.mark.asyncio
async def test_login_ip_rate_limit_returns_429_and_headers(monkeypatch, client):
monkeypatch.setenv('LOGIN_IP_RATE_LIMIT', '2')
@@ -9,7 +11,7 @@ async def test_login_ip_rate_limit_returns_429_and_headers(monkeypatch, client):
creds = {
'email': os.environ.get('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD', 'Password123!Password')
+ 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD', 'Password123!Password'),
}
r1 = await client.post('/platform/authorization', json=creds)
diff --git a/backend-services/tests/test_memory_dump_and_sigusr1.py b/backend-services/tests/test_memory_dump_and_sigusr1.py
index d48fe60..d0661f0 100644
--- a/backend-services/tests/test_memory_dump_and_sigusr1.py
+++ b/backend-services/tests/test_memory_dump_and_sigusr1.py
@@ -1,30 +1,37 @@
-import pytest
import os
+import pytest
+
+
@pytest.mark.asyncio
async def test_memory_dump_writes_file_when_memory_mode(monkeypatch, tmp_path):
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'test-secret-123')
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'd' / 'dump.bin'))
from utils.memory_dump_util import dump_memory_to_file, find_latest_dump_path
+
path = dump_memory_to_file(None)
assert os.path.exists(path)
latest = find_latest_dump_path(str(tmp_path / 'd' / ''))
assert latest == path
+
def test_dump_requires_encryption_key_logs_error(tmp_path, monkeypatch):
monkeypatch.delenv('MEM_ENCRYPTION_KEY', raising=False)
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'x' / 'memory_dump.bin'))
from utils import memory_dump_util as md
+
with pytest.raises(ValueError):
md.dump_memory_to_file(None)
+
def test_sigusr1_handler_registered_on_unix(monkeypatch, capsys):
- import importlib
import doorman as appmod
+
if hasattr(appmod.signal, 'SIGUSR1'):
assert hasattr(appmod.signal, 'SIGUSR1')
+
def test_sigusr1_ignored_when_not_memory_mode(monkeypatch):
import doorman as appmod
- assert hasattr(appmod.signal, 'SIGUSR1')
+ assert hasattr(appmod.signal, 'SIGUSR1')
diff --git a/backend-services/tests/test_memory_dump_restore_permissions.py b/backend-services/tests/test_memory_dump_restore_permissions.py
new file mode 100644
index 0000000..6408ce0
--- /dev/null
+++ b/backend-services/tests/test_memory_dump_restore_permissions.py
@@ -0,0 +1,77 @@
+import os
+import time
+
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200, r.text
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_memory_dump_restore_permissions(authed_client, monkeypatch, tmp_path):
+ monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'test-encryption-key-32-characters-min')
+
+ # Limited user cannot dump/restore
+ uname = f'mem_limited_{int(time.time())}'
+ pwd = 'MemoryUserStrong1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ limited = await _login(f'{uname}@example.com', pwd)
+ d = await limited.post('/platform/memory/dump', json={})
+ assert d.status_code == 403
+ r = await limited.post('/platform/memory/restore', json={})
+ assert r.status_code == 403
+
+ # Role with manage_security can dump/restore
+ rname = f'sec_manager_{int(time.time())}'
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': rname, 'manage_security': True}
+ )
+ assert cr.status_code in (200, 201)
+ uname2 = f'mem_mgr_{int(time.time())}'
+ cu2 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname2,
+ 'email': f'{uname2}@example.com',
+ 'password': pwd,
+ 'role': rname,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu2.status_code in (200, 201)
+ mgr = await _login(f'{uname2}@example.com', pwd)
+
+ # Dump to temp
+ target = str(tmp_path / 'memdump.bin')
+ dump = await mgr.post('/platform/memory/dump', json={'path': target})
+ assert dump.status_code == 200
+    body = dump.json()
+    path = body.get('response', {}).get('path') or body.get('path')
+ assert path and os.path.exists(path)
+
+ # Restore
+ res = await mgr.post('/platform/memory/restore', json={'path': path})
+ assert res.status_code == 200
diff --git a/backend-services/tests/test_memory_dump_util_extended.py b/backend-services/tests/test_memory_dump_util_extended.py
index ffae850..08f66af 100644
--- a/backend-services/tests/test_memory_dump_util_extended.py
+++ b/backend-services/tests/test_memory_dump_util_extended.py
@@ -1,12 +1,12 @@
import os
import time
-import json
from pathlib import Path
+
import pytest
+
@pytest.mark.asyncio
async def test_dump_file_naming_and_dir_creation(monkeypatch, tmp_path):
-
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-key-12345')
@@ -16,16 +16,14 @@ async def test_dump_file_naming_and_dir_creation(monkeypatch, tmp_path):
dump_path = dump_memory_to_file(str(hint_file))
assert Path(dump_path).exists()
- assert Path(dump_path).name.startswith('mydump-') and dump_path.endswith(
- '.bin'
- )
+ assert Path(dump_path).name.startswith('mydump-') and dump_path.endswith('.bin')
latest = find_latest_dump_path(str(tmp_path / 'custom' / 'mydump.bin'))
assert latest == dump_path
+
@pytest.mark.asyncio
async def test_dump_with_directory_hint_uses_default_stem(monkeypatch, tmp_path):
-
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-key-99999')
@@ -37,6 +35,7 @@ async def test_dump_with_directory_hint_uses_default_stem(monkeypatch, tmp_path)
assert Path(dump_path).exists()
assert Path(dump_path).name.startswith('memory_dump-')
+
def test_find_latest_prefers_newest_by_stem(tmp_path):
from utils.memory_dump_util import find_latest_dump_path
@@ -54,6 +53,7 @@ def test_find_latest_prefers_newest_by_stem(tmp_path):
latest = find_latest_dump_path(str(d / 'memory_dump.bin'))
assert latest and latest.endswith(b.name)
+
def test_find_latest_ignores_other_stems_when_dir_hint(tmp_path):
from utils.memory_dump_util import find_latest_dump_path
@@ -70,8 +70,8 @@ def test_find_latest_ignores_other_stems_when_dir_hint(tmp_path):
latest = find_latest_dump_path(str(d))
assert latest and Path(latest).name.startswith('memory_dump-')
-def test_find_latest_uses_default_when_no_hint(monkeypatch, tmp_path):
+def test_find_latest_uses_default_when_no_hint(monkeypatch, tmp_path):
import utils.memory_dump_util as md
base = tmp_path / 'default'
@@ -89,6 +89,7 @@ def test_find_latest_uses_default_when_no_hint(monkeypatch, tmp_path):
latest = md.find_latest_dump_path(None)
assert latest and latest.endswith(b.name)
+
def test_encrypt_decrypt_roundtrip(monkeypatch):
import utils.memory_dump_util as md
@@ -99,15 +100,16 @@ def test_encrypt_decrypt_roundtrip(monkeypatch):
out = md._decrypt_blob(blob, key)
assert out == pt
+
def test_encrypt_requires_sufficient_key(monkeypatch):
import utils.memory_dump_util as md
with pytest.raises(ValueError):
md._encrypt_blob(b'data', 'short')
+
@pytest.mark.asyncio
async def test_dump_and_restore_roundtrip_with_bytes(monkeypatch, tmp_path):
-
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-key-abcde')
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'mem' / 'memory_dump.bin'))
@@ -115,12 +117,9 @@ async def test_dump_and_restore_roundtrip_with_bytes(monkeypatch, tmp_path):
import utils.memory_dump_util as md
from utils.database import database
- database.db.settings.insert_one({
- '_id': 'cfg',
- 'blob': b'\x00\x01',
- 'tuple': (1, 2, 3),
- 'aset': {'a', 'b'},
- })
+ database.db.settings.insert_one(
+ {'_id': 'cfg', 'blob': b'\x00\x01', 'tuple': (1, 2, 3), 'aset': {'a', 'b'}}
+ )
dump_path = md.dump_memory_to_file(None)
assert Path(dump_path).exists()
@@ -134,15 +133,16 @@ async def test_dump_and_restore_roundtrip_with_bytes(monkeypatch, tmp_path):
assert set(restored.get('aset')) == {'a', 'b'}
assert list(restored.get('tuple')) == [1, 2, 3]
+
def test_restore_nonexistent_file_raises(tmp_path):
import utils.memory_dump_util as md
with pytest.raises(FileNotFoundError):
md.restore_memory_from_file(str(tmp_path / 'nope.bin'))
+
@pytest.mark.asyncio
async def test_dump_fails_with_short_key(monkeypatch, tmp_path):
-
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'short')
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'd' / 'dump.bin'))
diff --git a/backend-services/tests/test_memory_routes_and_tools_e2e.py b/backend-services/tests/test_memory_routes_and_tools_e2e.py
index ea4badc..a95e771 100644
--- a/backend-services/tests/test_memory_routes_and_tools_e2e.py
+++ b/backend-services/tests/test_memory_routes_and_tools_e2e.py
@@ -1,13 +1,17 @@
import os
+
import pytest
+
@pytest.mark.asyncio
async def test_memory_dump_requires_key_then_succeeds(monkeypatch, authed_client, tmp_path):
-
monkeypatch.delenv('MEM_ENCRYPTION_KEY', raising=False)
r1 = await authed_client.post('/platform/memory/dump')
assert r1.status_code == 400
- assert (r1.json().get('error_code') or r1.json().get('response', {}).get('error_code')) in ('MEM002', 'MEM002')
+    # error_code may be top-level or nested inside the response envelope
+    body1 = r1.json()
+    err = body1.get('error_code') or body1.get('response', {}).get('error_code')
+    assert err == 'MEM002'
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'route-key-123456')
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'r' / 'dump.bin'))
@@ -17,17 +21,22 @@ async def test_memory_dump_requires_key_then_succeeds(monkeypatch, authed_client
path = body.get('response', {}).get('path') or body.get('path')
assert path and path.endswith('.bin')
+
@pytest.mark.asyncio
async def test_memory_restore_404_missing(monkeypatch, authed_client, tmp_path):
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'route-key-987654')
- r = await authed_client.post('/platform/memory/restore', json={'path': str(tmp_path / 'nope.bin')})
+ r = await authed_client.post(
+ '/platform/memory/restore', json={'path': str(tmp_path / 'nope.bin')}
+ )
assert r.status_code == 404
data = r.json()
- assert data.get('error_code') == 'MEM003' or data.get('response', {}).get('error_code') == 'MEM003'
+ assert (
+ data.get('error_code') == 'MEM003' or data.get('response', {}).get('error_code') == 'MEM003'
+ )
+
@pytest.mark.asyncio
async def test_memory_dump_then_restore_flow(monkeypatch, authed_client, tmp_path):
-
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'route-key-abcdef')
monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'd' / 'dump.bin'))
@@ -43,7 +52,9 @@ async def test_memory_dump_then_restore_flow(monkeypatch, authed_client, tmp_pat
)
assert create.status_code in (200, 201), create.text
- d = await authed_client.post('/platform/memory/dump', json={'path': str(tmp_path / 'd' / 'dump.bin')})
+ d = await authed_client.post(
+ '/platform/memory/dump', json={'path': str(tmp_path / 'd' / 'dump.bin')}
+ )
assert d.status_code == 200
dump_body = d.json()
dump_path = dump_body.get('response', {}).get('path') or dump_body.get('path')
@@ -58,6 +69,7 @@ async def test_memory_dump_then_restore_flow(monkeypatch, authed_client, tmp_pat
check = await authed_client.get('/platform/user/e2euser')
assert check.status_code == 200
+
@pytest.mark.asyncio
async def test_cors_wildcard_without_credentials_allows(monkeypatch, authed_client):
monkeypatch.setenv('ALLOWED_ORIGINS', '*')
@@ -69,6 +81,7 @@ async def test_cors_wildcard_without_credentials_allows(monkeypatch, authed_clie
data = r.json()
assert data.get('actual', {}).get('allowed') is True
+
@pytest.mark.asyncio
async def test_cors_wildcard_with_credentials_strict_blocks(monkeypatch, authed_client):
monkeypatch.setenv('ALLOWED_ORIGINS', '*')
@@ -80,23 +93,35 @@ async def test_cors_wildcard_with_credentials_strict_blocks(monkeypatch, authed_
data = r.json()
assert data.get('actual', {}).get('allowed') is False
+
@pytest.mark.asyncio
async def test_cors_checker_implicitly_allows_options(monkeypatch, authed_client):
monkeypatch.setenv('ALLOW_METHODS', 'GET,POST')
body = {'origin': 'http://localhost:3000', 'method': 'OPTIONS'}
r = await authed_client.post('/platform/tools/cors/check', json=body)
assert r.status_code == 200
- assert r.json().get('preflight', {}).get('response_headers', {}).get('Access-Control-Allow-Methods')
+ assert (
+ r.json()
+ .get('preflight', {})
+ .get('response_headers', {})
+ .get('Access-Control-Allow-Methods')
+ )
+
@pytest.mark.asyncio
async def test_cors_headers_case_insensitive(monkeypatch, authed_client):
monkeypatch.setenv('ALLOW_HEADERS', 'Content-Type,Authorization')
- body = {'origin': 'http://localhost:3000', 'method': 'GET', 'request_headers': ['content-type', 'authorization']}
+ body = {
+ 'origin': 'http://localhost:3000',
+ 'method': 'GET',
+ 'request_headers': ['content-type', 'authorization'],
+ }
r = await authed_client.post('/platform/tools/cors/check', json=body)
assert r.status_code == 200
data = r.json()
assert data.get('preflight', {}).get('allowed') is True
+
@pytest.mark.asyncio
async def test_cors_checker_vary_origin_present(monkeypatch, authed_client):
body = {'origin': 'http://localhost:3000', 'method': 'GET'}
@@ -107,6 +132,7 @@ async def test_cors_checker_vary_origin_present(monkeypatch, authed_client):
actual = data.get('actual', {}).get('response_headers', {})
assert pre.get('Vary') == 'Origin' and actual.get('Vary') == 'Origin'
+
@pytest.mark.asyncio
async def test_cors_checker_disallows_unknown_origin(monkeypatch, authed_client):
monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000')
diff --git a/backend-services/tests/test_metrics_persistence.py b/backend-services/tests/test_metrics_persistence.py
index ab82cfb..366dd00 100644
--- a/backend-services/tests/test_metrics_persistence.py
+++ b/backend-services/tests/test_metrics_persistence.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_metrics_persist_and_restore(tmp_path, authed_client):
r1 = await authed_client.get('/api/status')
@@ -7,6 +8,7 @@ async def test_metrics_persist_and_restore(tmp_path, authed_client):
assert r1.status_code == 200 and r2.status_code == 200
from utils.metrics_util import metrics_store
+
before = metrics_store.to_dict()
assert before.get('total_requests', 0) >= 1
@@ -29,6 +31,7 @@ async def test_metrics_persist_and_restore(tmp_path, authed_client):
if isinstance(b2.get('status_counts'), dict):
b2['status_counts'] = {str(k): v for k, v in b2['status_counts'].items()}
return b2
+
out = dict(d)
if isinstance(out.get('status_counts'), dict):
out['status_counts'] = {str(k): v for k, v in out['status_counts'].items()}
diff --git a/backend-services/tests/test_metrics_ranges_extended.py b/backend-services/tests/test_metrics_ranges_extended.py
index f2594ef..4f3437c 100644
--- a/backend-services/tests/test_metrics_ranges_extended.py
+++ b/backend-services/tests/test_metrics_ranges_extended.py
@@ -1,25 +1,37 @@
import pytest
+
@pytest.mark.asyncio
async def test_metrics_range_parameters(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'mrange', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/p')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
class _FakeHTTPResponse:
def __init__(self):
self.status_code = 200
self.headers = {'Content-Type': 'application/json'}
self.text = '{}'
self.content = b'{}'
- def json(self): return {'ok': True}
+
+ def json(self):
+ return {'ok': True}
+
class _FakeAsyncClient:
- def __init__(self, timeout=None, limits=None, http2=False): pass
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ def __init__(self, timeout=None, limits=None, http2=False):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -37,10 +49,19 @@ async def test_metrics_range_parameters(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405)
- async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse()
- async def post(self, url, **kwargs): return _FakeHTTPResponse()
- async def put(self, url, **kwargs): return _FakeHTTPResponse()
- async def delete(self, url, **kwargs): return _FakeHTTPResponse()
+
+ async def get(self, url, params=None, headers=None, **kwargs):
+ return _FakeHTTPResponse()
+
+ async def post(self, url, **kwargs):
+ return _FakeHTTPResponse()
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse()
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse()
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
await authed_client.get(f'/api/rest/{name}/{ver}/p')
diff --git a/backend-services/tests/test_metrics_symmetry_envelope_ids.py b/backend-services/tests/test_metrics_symmetry_envelope_ids.py
index d2ded75..551d98f 100644
--- a/backend-services/tests/test_metrics_symmetry_envelope_ids.py
+++ b/backend-services/tests/test_metrics_symmetry_envelope_ids.py
@@ -1,16 +1,20 @@
import json
import re
+
import pytest
+
@pytest.mark.asyncio
async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'msym', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'POST', '/echo')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
resp_body = b'{"ok":true,"pad":"' + b'Z' * 15 + b'"}'
class _FakeHTTPResponse:
@@ -19,13 +23,20 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client):
self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))}
self.text = body.decode('utf-8')
self.content = body
+
def json(self):
return json.loads(self.text)
class _FakeAsyncClient:
- def __init__(self, timeout=None, limits=None, http2=False): pass
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ def __init__(self, timeout=None, limits=None, http2=False):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -43,10 +54,18 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405)
- async def get(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs): return _FakeHTTPResponse(200)
- async def put(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def delete(self, url, **kwargs): return _FakeHTTPResponse(200)
+
+ async def get(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def post(self, url, data=None, json=None, headers=None, params=None, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
@@ -68,12 +87,15 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client):
assert tin1 - tin0 >= len(payload)
assert tout1 - tout0 >= len(resp_body)
+
@pytest.mark.asyncio
async def test_response_envelope_for_non_json_error(monkeypatch, client):
monkeypatch.setenv('MAX_BODY_SIZE_BYTES', '10')
payload = 'x' * 100
- r = await client.post('/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'})
+ r = await client.post(
+ '/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'}
+ )
assert r.status_code == 413
assert r.headers.get('content-type', '').lower().startswith('application/json')
body = r.json()
@@ -82,11 +104,14 @@ async def test_response_envelope_for_non_json_error(monkeypatch, client):
msg = body.get('error_message') or (body.get('response') or {}).get('error_message')
assert isinstance(msg, str) and msg
+
def _get_operation_id(spec: dict, path: str, method: str) -> str:
return spec['paths'][path][method.lower()]['operationId']
+
def test_unique_route_ids_are_stable():
from doorman import doorman as app
+
spec1 = app.openapi()
spec2 = app.openapi()
diff --git a/backend-services/tests/test_monitor_dashboard.py b/backend-services/tests/test_monitor_dashboard.py
index 04320fe..a3ca33a 100644
--- a/backend-services/tests/test_monitor_dashboard.py
+++ b/backend-services/tests/test_monitor_dashboard.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_monitor_metrics_and_dashboard(authed_client):
-
d = await authed_client.get('/platform/dashboard')
assert d.status_code == 200
dj = d.json()
@@ -13,6 +13,7 @@ async def test_monitor_metrics_and_dashboard(authed_client):
mj = m.json()
assert isinstance(mj, dict)
+
@pytest.mark.asyncio
async def test_liveness_and_readiness(client):
l = await client.get('/platform/monitor/liveness')
diff --git a/backend-services/tests/test_monitor_endpoints.py b/backend-services/tests/test_monitor_endpoints.py
new file mode 100644
index 0000000..ae8c31a
--- /dev/null
+++ b/backend-services/tests/test_monitor_endpoints.py
@@ -0,0 +1,14 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_monitor_liveness_readiness_metrics(authed_client):
+ # Liveness
+    live = await authed_client.get('/platform/monitor/liveness')
+    assert live.status_code in (200, 204)
+    # Readiness
+    ready = await authed_client.get('/platform/monitor/readiness')
+    assert ready.status_code in (200, 204)
+    # Metrics
+    metrics = await authed_client.get('/platform/monitor/metrics')
+    assert metrics.status_code in (200, 204)
diff --git a/backend-services/tests/test_monitor_metrics_extended.py b/backend-services/tests/test_monitor_metrics_extended.py
index 7ca3796..13e64d1 100644
--- a/backend-services/tests/test_monitor_metrics_extended.py
+++ b/backend-services/tests/test_monitor_metrics_extended.py
@@ -1,7 +1,6 @@
-import asyncio
-import time
import pytest
+
@pytest.mark.asyncio
async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
@@ -19,16 +18,20 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client)
self.headers = {'Content-Type': 'application/json'}
self.text = '{}'
self.content = b'{}'
+
def json(self):
return {'ok': True}
class _FakeAsyncClient:
def __init__(self, timeout=None, limits=None, http2=False):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -46,12 +49,16 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client)
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405)
+
async def get(self, url, params=None, headers=None, **kwargs):
return _FakeHTTPResponse(200)
+
async def post(self, url, **kwargs):
return _FakeHTTPResponse(200)
+
async def put(self, url, **kwargs):
return _FakeHTTPResponse(200)
+
async def delete(self, url, **kwargs):
return _FakeHTTPResponse(200)
@@ -70,9 +77,11 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client)
series = body.get('series') or []
assert isinstance(series, list)
+
@pytest.mark.asyncio
async def test_metrics_top_apis_aggregate(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'mapi3', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/x')
@@ -86,12 +95,20 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client):
self.headers = {'Content-Type': 'application/json'}
self.text = '{}'
self.content = b'{}'
- def json(self): return {'ok': True}
+
+ def json(self):
+ return {'ok': True}
class _FakeAsyncClient:
- def __init__(self, timeout=None, limits=None, http2=False): pass
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ def __init__(self, timeout=None, limits=None, http2=False):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -109,10 +126,18 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405)
- async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200)
- async def post(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def put(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def delete(self, url, **kwargs): return _FakeHTTPResponse(200)
+
+ async def get(self, url, params=None, headers=None, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def post(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
@@ -126,6 +151,7 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client):
assert any(isinstance(a, list) and a[0].startswith('rest:') for a in top_apis)
+
@pytest.mark.asyncio
async def test_monitor_liveness_and_readiness(authed_client):
live = await authed_client.get('/platform/monitor/liveness')
@@ -137,28 +163,38 @@ async def test_monitor_liveness_and_readiness(authed_client):
status = (ready.json() or {}).get('status')
assert status in ('ready', 'degraded')
+
@pytest.mark.asyncio
async def test_monitor_report_csv(monkeypatch, authed_client):
-
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'mapi4', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/r')
await subscribe_self(authed_client, name, ver)
import services.gateway_service as gs
+
class _FakeHTTPResponse:
def __init__(self):
self.status_code = 200
self.headers = {'Content-Type': 'application/json'}
self.text = '{}'
self.content = b'{}'
- def json(self): return {'ok': True}
+
+ def json(self):
+ return {'ok': True}
class _FakeAsyncClient:
- def __init__(self, timeout=None, limits=None, http2=False): pass
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ def __init__(self, timeout=None, limits=None, http2=False):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -176,15 +212,24 @@ async def test_monitor_report_csv(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse()
- async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse()
- async def post(self, url, **kwargs): return _FakeHTTPResponse()
- async def put(self, url, **kwargs): return _FakeHTTPResponse()
- async def delete(self, url, **kwargs): return _FakeHTTPResponse()
+
+ async def get(self, url, params=None, headers=None, **kwargs):
+ return _FakeHTTPResponse()
+
+ async def post(self, url, **kwargs):
+ return _FakeHTTPResponse()
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse()
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse()
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
await authed_client.get(f'/api/rest/{name}/{ver}/r')
from datetime import datetime
+
now = datetime.utcnow()
start = now.strftime('%Y-%m-%dT%H:%M')
end = start
@@ -192,5 +237,6 @@ async def test_monitor_report_csv(monkeypatch, authed_client):
assert csvr.status_code == 200
text = csvr.text
- assert 'Report' in text and 'Overview' in text and 'Status Codes' in text and 'API Usage' in text
-
+ assert (
+ 'Report' in text and 'Overview' in text and 'Status Codes' in text and 'API Usage' in text
+ )
diff --git a/backend-services/tests/test_multi_onboarding.py b/backend-services/tests/test_multi_onboarding.py
index 9ffa588..27ec723 100644
--- a/backend-services/tests/test_multi_onboarding.py
+++ b/backend-services/tests/test_multi_onboarding.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_multi_endpoints_per_api_and_listing(authed_client):
-
c = await authed_client.post(
'/platform/api',
json={
@@ -18,11 +18,7 @@ async def test_multi_endpoints_per_api_and_listing(authed_client):
)
assert c.status_code in (200, 201)
- endpoints = [
- ('GET', '/a'),
- ('POST', '/b'),
- ('PUT', '/c'),
- ]
+ endpoints = [('GET', '/a'), ('POST', '/b'), ('PUT', '/c')]
for method, uri in endpoints:
ep = await authed_client.post(
'/platform/endpoint',
diff --git a/backend-services/tests/test_multi_worker_semantics.py b/backend-services/tests/test_multi_worker_semantics.py
index 15037cb..8d7756f 100644
--- a/backend-services/tests/test_multi_worker_semantics.py
+++ b/backend-services/tests/test_multi_worker_semantics.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.asyncio
async def test_mem_multi_worker_guard_raises(monkeypatch):
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
@@ -9,6 +10,7 @@ async def test_mem_multi_worker_guard_raises(monkeypatch):
with pytest.raises(RuntimeError):
validate_token_revocation_config()
+
@pytest.mark.asyncio
async def test_mem_single_worker_allowed(monkeypatch):
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
@@ -17,6 +19,7 @@ async def test_mem_single_worker_allowed(monkeypatch):
validate_token_revocation_config()
+
@pytest.mark.asyncio
async def test_redis_multi_worker_allowed(monkeypatch):
monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS')
@@ -24,4 +27,3 @@ async def test_redis_multi_worker_allowed(monkeypatch):
from doorman import validate_token_revocation_config
validate_token_revocation_config()
-
diff --git a/backend-services/tests/test_pagination_caps.py b/backend-services/tests/test_pagination_caps.py
index 48cc33a..78f39f0 100644
--- a/backend-services/tests/test_pagination_caps.py
+++ b/backend-services/tests/test_pagination_caps.py
@@ -1,6 +1,6 @@
-import os
import pytest
+
@pytest.mark.asyncio
async def test_max_page_size_boundary_api_list(authed_client, monkeypatch):
monkeypatch.setenv('MAX_PAGE_SIZE', '5')
@@ -13,6 +13,7 @@ async def test_max_page_size_boundary_api_list(authed_client, monkeypatch):
body = r_bad.json()
assert 'error_message' in body
+
@pytest.mark.asyncio
async def test_max_page_size_boundary_users_list(authed_client, monkeypatch):
monkeypatch.setenv('MAX_PAGE_SIZE', '3')
@@ -23,6 +24,7 @@ async def test_max_page_size_boundary_users_list(authed_client, monkeypatch):
r_bad = await authed_client.get('/platform/user/all?page=1&page_size=4')
assert r_bad.status_code == 400, r_bad.text
+
@pytest.mark.asyncio
async def test_invalid_page_values(authed_client, monkeypatch):
monkeypatch.setenv('MAX_PAGE_SIZE', '10')
@@ -32,4 +34,3 @@ async def test_invalid_page_values(authed_client, monkeypatch):
r2 = await authed_client.get('/platform/group/all?page=1&page_size=0')
assert r2.status_code == 400
-
diff --git a/backend-services/tests/test_permissions_extended.py b/backend-services/tests/test_permissions_extended.py
index 0c01acf..d7150e1 100644
--- a/backend-services/tests/test_permissions_extended.py
+++ b/backend-services/tests/test_permissions_extended.py
@@ -1,13 +1,12 @@
import pytest
from httpx import AsyncClient
+
async def _login(email: str, password: str) -> AsyncClient:
from doorman import doorman
+
client = AsyncClient(app=doorman, base_url='http://testserver')
- r = await client.post(
- '/platform/authorization',
- json={'email': email, 'password': password},
- )
+ r = await client.post('/platform/authorization', json={'email': email, 'password': password})
assert r.status_code == 200, r.text
body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
@@ -16,9 +15,9 @@ async def _login(email: str, password: str) -> AsyncClient:
client.cookies.set('access_token_cookie', token, domain='testserver', path='/')
return client
+
@pytest.mark.asyncio
async def test_non_admin_role_cannot_access_monitor_or_credits(monkeypatch, authed_client):
-
rrole = await authed_client.post(
'/platform/role',
json={
@@ -64,7 +63,13 @@ async def test_non_admin_role_cannot_access_monitor_or_credits(monkeypatch, auth
'api_key': 'x',
'api_key_header': 'x-api-key',
'credit_tiers': [
- {'tier_name': 'basic', 'credits': 10, 'input_limit': 0, 'output_limit': 0, 'reset_frequency': 'monthly'}
+ {
+ 'tier_name': 'basic',
+ 'credits': 10,
+ 'input_limit': 0,
+ 'output_limit': 0,
+ 'reset_frequency': 'monthly',
+ }
],
},
)
@@ -73,16 +78,12 @@ async def test_non_admin_role_cannot_access_monitor_or_credits(monkeypatch, auth
cc = await viewer.delete('/api/caches')
assert cc.status_code == 403
+
@pytest.mark.asyncio
async def test_endpoint_validation_management_requires_permission(authed_client):
-
await authed_client.post(
'/platform/role',
- json={
- 'role_name': 'noend',
- 'role_description': 'No endpoints',
- 'manage_endpoints': False,
- },
+ json={'role_name': 'noend', 'role_description': 'No endpoints', 'manage_endpoints': False},
)
await authed_client.post(
'/platform/user',
@@ -96,6 +97,7 @@ async def test_endpoint_validation_management_requires_permission(authed_client)
)
from conftest import create_api, create_endpoint
+
api_name, ver = 'permapi', 'v1'
await create_api(authed_client, api_name, ver)
await create_endpoint(authed_client, api_name, ver, 'POST', '/foo')
@@ -106,6 +108,10 @@ async def test_endpoint_validation_management_requires_permission(authed_client)
ev = await client.post(
'/platform/endpoint/endpoint/validation',
- json={'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': {'validation_schema': {}}},
+ json={
+ 'endpoint_id': eid,
+ 'validation_enabled': True,
+ 'validation_schema': {'validation_schema': {}},
+ },
)
assert ev.status_code == 403
diff --git a/backend-services/tests/test_platform_admin_visibility.py b/backend-services/tests/test_platform_admin_visibility.py
index 431287a..a7173b1 100644
--- a/backend-services/tests/test_platform_admin_visibility.py
+++ b/backend-services/tests/test_platform_admin_visibility.py
@@ -1,22 +1,25 @@
-import os
import uuid
+
import pytest
import pytest_asyncio
from httpx import AsyncClient
+
@pytest_asyncio.fixture
async def login_client():
async def _login(username: str, password: str, email: str = None) -> AsyncClient:
from doorman import doorman
+
client = AsyncClient(app=doorman, base_url='http://testserver')
cred = {'email': email or f'{username}@example.com', 'password': password}
r = await client.post('/platform/authorization', json=cred)
assert r.status_code == 200, r.text
return client
+
return _login
-async def _ensure_manager_role(authed_client: AsyncClient):
+async def _ensure_manager_role(authed_client: AsyncClient):
payload = {
'role_name': 'manager',
'role_description': 'Manager role',
@@ -31,11 +34,12 @@ async def _ensure_manager_role(authed_client: AsyncClient):
'manage_credits': True,
'manage_auth': True,
'view_logs': True,
- 'export_logs': True
+ 'export_logs': True,
}
r = await authed_client.post('/platform/role', json=payload)
assert r.status_code in (200, 201, 400), r.text
+
async def _create_manager_user(authed_client: AsyncClient) -> dict:
await _ensure_manager_role(authed_client)
uname = f'mgr_{uuid.uuid4().hex[:8]}'
@@ -46,12 +50,13 @@ async def _create_manager_user(authed_client: AsyncClient) -> dict:
'role': 'manager',
'groups': ['ALL'],
'active': True,
- 'ui_access': True
+ 'ui_access': True,
}
r = await authed_client.post('/platform/user', json=payload)
assert r.status_code in (200, 201), r.text
return payload
+
@pytest.mark.asyncio
async def test_non_admin_cannot_see_admin_role(authed_client, login_client):
mgr = await _create_manager_user(authed_client)
@@ -68,9 +73,9 @@ async def test_non_admin_cannot_see_admin_role(authed_client, login_client):
await client.aclose()
+
@pytest.mark.asyncio
async def test_non_admin_cannot_see_or_modify_admin_users(authed_client, login_client):
-
mgr = await _create_manager_user(authed_client)
client = await login_client(mgr['username'], 'StrongManagerPwd!1234', mgr['email'])
@@ -89,28 +94,33 @@ async def test_non_admin_cannot_see_or_modify_admin_users(authed_client, login_c
await client.aclose()
+
@pytest.mark.asyncio
async def test_non_admin_cannot_assign_admin_role(authed_client, login_client):
mgr = await _create_manager_user(authed_client)
client = await login_client(mgr['username'], 'StrongManagerPwd!1234', mgr['email'])
newu = f'np_{uuid.uuid4().hex[:8]}'
- r_create = await client.post('/platform/user', json={
- 'username': newu,
- 'email': f'{newu}@example.com',
- 'password': 'StrongPwd!1234XYZ',
- 'role': 'admin',
- 'groups': ['ALL'],
- 'active': True,
- 'ui_access': True
- })
+ r_create = await client.post(
+ '/platform/user',
+ json={
+ 'username': newu,
+ 'email': f'{newu}@example.com',
+ 'password': 'StrongPwd!1234XYZ',
+ 'role': 'admin',
+ 'groups': ['ALL'],
+ 'active': True,
+ 'ui_access': True,
+ },
+ )
assert r_create.status_code in (403, 404)
- r_update = await client.put(f"/platform/user/{mgr['username']}", json={'role': 'admin'})
+ r_update = await client.put(f'/platform/user/{mgr["username"]}', json={'role': 'admin'})
assert r_update.status_code in (403, 404)
await client.aclose()
+
@pytest.mark.asyncio
async def test_non_admin_auth_admin_ops_hidden(authed_client, login_client):
mgr = await _create_manager_user(authed_client)
@@ -127,6 +137,8 @@ async def test_non_admin_auth_admin_ops_hidden(authed_client, login_client):
r = await client.get(path)
else:
r = await client.post(path)
- assert r.status_code in (404, 403), f'Expected 404/403 for {path}, got {r.status_code}: {r.text}'
+ assert r.status_code in (404, 403), (
+ f'Expected 404/403 for {path}, got {r.status_code}: {r.text}'
+ )
await client.aclose()
diff --git a/backend-services/tests/test_platform_cors_env_edges.py b/backend-services/tests/test_platform_cors_env_edges.py
index 6aaec94..0c43891 100644
--- a/backend-services/tests/test_platform_cors_env_edges.py
+++ b/backend-services/tests/test_platform_cors_env_edges.py
@@ -1,60 +1,78 @@
import pytest
+
@pytest.mark.asyncio
async def test_platform_cors_wildcard_origin_with_credentials_strict_false(monkeypatch, client):
monkeypatch.setenv('ALLOWED_ORIGINS', '*')
monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
monkeypatch.setenv('CORS_STRICT', 'false')
- r = await client.options('/platform/api', headers={
- 'Origin': 'http://evil.example',
- 'Access-Control-Request-Method': 'GET'
- })
+ r = await client.options(
+ '/platform/api',
+ headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'},
+ )
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') == 'http://evil.example'
assert r.headers.get('Access-Control-Allow-Credentials') == 'true'
assert r.headers.get('Vary') == 'Origin'
+
@pytest.mark.asyncio
-async def test_platform_cors_wildcard_origin_with_credentials_strict_true_restricts(monkeypatch, client):
+async def test_platform_cors_wildcard_origin_with_credentials_strict_true_restricts(
+ monkeypatch, client
+):
monkeypatch.setenv('ALLOWED_ORIGINS', '*')
monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
monkeypatch.setenv('CORS_STRICT', 'true')
- r = await client.options('/platform/api', headers={
- 'Origin': 'http://evil.example',
- 'Access-Control-Request-Method': 'GET'
- })
+ r = await client.options(
+ '/platform/api',
+ headers={'Origin': 'http://evil.example', 'Access-Control-Request-Method': 'GET'},
+ )
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') is None
+
@pytest.mark.asyncio
async def test_platform_cors_methods_empty_env_falls_back_default(monkeypatch, client):
monkeypatch.setenv('ALLOW_METHODS', '')
monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000')
- r = await client.options('/platform/api', headers={
- 'Origin': 'http://localhost:3000',
- 'Access-Control-Request-Method': 'GET'
- })
+ r = await client.options(
+ '/platform/api',
+ headers={'Origin': 'http://localhost:3000', 'Access-Control-Request-Method': 'GET'},
+ )
assert r.status_code == 204
- methods = [m.strip() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()]
- expected = {'GET','POST','PUT','DELETE','OPTIONS','PATCH','HEAD'}
+ methods = [
+ m.strip()
+ for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',')
+ if m.strip()
+ ]
+ expected = {'GET', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH', 'HEAD'}
assert set(methods) == expected
+
@pytest.mark.asyncio
async def test_platform_cors_headers_asterisk_defaults_to_known_list(monkeypatch, client):
monkeypatch.setenv('ALLOW_HEADERS', '*')
monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000')
- r = await client.options('/platform/api', headers={
- 'Origin': 'http://localhost:3000',
- 'Access-Control-Request-Method': 'GET',
- 'Access-Control-Request-Headers': 'X-Anything'
- })
+ r = await client.options(
+ '/platform/api',
+ headers={
+ 'Origin': 'http://localhost:3000',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'X-Anything',
+ },
+ )
assert r.status_code == 204
- headers = [h.strip() for h in (r.headers.get('Access-Control-Allow-Headers') or '').split(',') if h.strip()]
- assert set(headers) == {'Accept','Content-Type','X-CSRF-Token','Authorization'}
+ headers = [
+ h.strip()
+ for h in (r.headers.get('Access-Control-Allow-Headers') or '').split(',')
+ if h.strip()
+ ]
+ assert set(headers) == {'Accept', 'Content-Type', 'X-CSRF-Token', 'Authorization'}
+
@pytest.mark.asyncio
async def test_platform_cors_sets_vary_origin(monkeypatch, authed_client):
@@ -63,4 +81,3 @@ async def test_platform_cors_sets_vary_origin(monkeypatch, authed_client):
assert r.status_code == 200
assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example'
assert r.headers.get('Vary') == 'Origin'
-
diff --git a/backend-services/tests/test_platform_cors_no_duplicate_acao.py b/backend-services/tests/test_platform_cors_no_duplicate_acao.py
new file mode 100644
index 0000000..1601604
--- /dev/null
+++ b/backend-services/tests/test_platform_cors_no_duplicate_acao.py
@@ -0,0 +1,13 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_platform_cors_no_duplicate_access_control_allow_origin(monkeypatch, authed_client):
+ monkeypatch.setenv('ALLOWED_ORIGINS', 'http://ok.example')
+ r = await authed_client.get('/platform/user/me', headers={'Origin': 'http://ok.example'})
+ assert r.status_code == 200
+ acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get(
+ 'access-control-allow-origin'
+ )
+ assert acao == 'http://ok.example'
+ assert ',' not in acao
diff --git a/backend-services/tests/test_platform_expanded.py b/backend-services/tests/test_platform_expanded.py
index 6d5891c..8117300 100644
--- a/backend-services/tests/test_platform_expanded.py
+++ b/backend-services/tests/test_platform_expanded.py
@@ -1,7 +1,8 @@
-import os
import json
+
import pytest
+
@pytest.mark.asyncio
async def test_routing_crud(authed_client):
create = await authed_client.post(
@@ -30,13 +31,10 @@ async def test_routing_crud(authed_client):
delete = await authed_client.delete('/platform/routing/client-A')
assert delete.status_code == 200
+
@pytest.mark.asyncio
async def test_security_and_memory_dump_restore(authed_client):
-
- ur = await authed_client.put(
- '/platform/role/admin',
- json={'manage_security': True},
- )
+ ur = await authed_client.put('/platform/role/admin', json={'manage_security': True})
assert ur.status_code == 200
gs = await authed_client.get('/platform/security/settings')
@@ -54,12 +52,11 @@ async def test_security_and_memory_dump_restore(authed_client):
restore = await authed_client.post('/platform/memory/restore', json={'path': path})
assert restore.status_code == 200
+
@pytest.mark.asyncio
async def test_logging_endpoints(authed_client):
-
r1 = await authed_client.put(
- '/platform/role/admin',
- json={'view_logs': True, 'export_logs': True},
+ '/platform/role/admin', json={'view_logs': True, 'export_logs': True}
)
assert r1.status_code == 200
@@ -78,16 +75,12 @@ async def test_logging_endpoints(authed_client):
download = await authed_client.get('/platform/logging/logs/download?format=json')
assert download.status_code == 200
+
@pytest.mark.asyncio
async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_client):
-
rest_apis = [
- ('jsonplaceholder', 'v1', ['https://jsonplaceholder.typicode.com'], [
- ('GET', '/posts/1')
- ]),
- ('httpbin', 'v1', ['https://httpbin.org'], [
- ('GET', '/get')
- ]),
+ ('jsonplaceholder', 'v1', ['https://jsonplaceholder.typicode.com'], [('GET', '/posts/1')]),
+ ('httpbin', 'v1', ['https://httpbin.org'], [('GET', '/get')]),
]
for name, ver, servers, endpoints in rest_apis:
c = await authed_client.post(
@@ -149,7 +142,11 @@ async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_cli
assert s.status_code in (200, 201)
soap_apis = [
- ('soap-number', 'v1', ['https://www.dataaccess.com/webservicesserver/NumberConversion.wso']),
+ (
+ 'soap-number',
+ 'v1',
+ ['https://www.dataaccess.com/webservicesserver/NumberConversion.wso'],
+ ),
('soap-tempconvert', 'v1', ['https://www.w3schools.com/xml/tempconvert.asmx']),
]
for name, ver, servers in soap_apis:
@@ -199,21 +196,30 @@ async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_cli
assert s.status_code in (200, 201)
import services.gateway_service as gs
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
self.status_code = status_code
self._json_body = json_body
- self.text = text_body if text_body is not None else ('' if json_body is not None else 'OK')
- self.headers = headers or {'Content-Type': 'application/json' if json_body is not None else 'text/plain'}
+ self.text = (
+ text_body if text_body is not None else ('' if json_body is not None else 'OK')
+ )
+ self.headers = headers or {
+ 'Content-Type': 'application/json' if json_body is not None else 'text/plain'
+ }
+
def json(self):
if self._json_body is None:
return json.loads(self.text or '{}')
return self._json_body
+
class _FakeAsyncClient:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -231,23 +237,32 @@ async def test_onboard_public_apis_for_all_gateway_types(monkeypatch, authed_cli
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
+
async def get(self, url, params=None, headers=None, **kwargs):
return _FakeHTTPResponse(200, json_body={'ping': 'pong'})
+
async def post(self, url, **kwargs):
return _FakeHTTPResponse(200, json_body={'ping': 'pong'})
+
async def put(self, url, **kwargs):
return _FakeHTTPResponse(200, json_body={'ping': 'pong'})
+
async def delete(self, url, **kwargs):
return _FakeHTTPResponse(200, json_body={'ping': 'pong'})
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
import routes.gateway_routes as gr
+
async def _no_limit(req):
return None
+
async def _pass_sub(req):
return {'sub': 'admin'}
+
async def _pass_group(req, full_path: str = None, user_to_subscribe=None):
return {'sub': 'admin'}
+
monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit)
monkeypatch.setattr(gr, 'subscription_required', _pass_sub)
monkeypatch.setattr(gr, 'group_required', _pass_group)
diff --git a/backend-services/tests/test_production_https_guard.py b/backend-services/tests/test_production_https_guard.py
index 100c933..13a9111 100644
--- a/backend-services/tests/test_production_https_guard.py
+++ b/backend-services/tests/test_production_https_guard.py
@@ -1,23 +1,29 @@
import pytest
+
@pytest.mark.asyncio
async def test_production_without_https_flags_fails_startup(monkeypatch):
monkeypatch.setenv('ENV', 'production')
monkeypatch.setenv('HTTPS_ONLY', 'false')
- from doorman import app_lifespan, doorman
import pytest as _pytest
+
+ from doorman import app_lifespan, doorman
+
with _pytest.raises(RuntimeError):
async with app_lifespan(doorman):
pass
+
@pytest.mark.asyncio
async def test_production_with_https_only_succeeds(monkeypatch):
monkeypatch.setenv('ENV', 'production')
monkeypatch.setenv('HTTPS_ONLY', 'true')
from httpx import AsyncClient
+
from doorman import doorman
+
client = AsyncClient(app=doorman, base_url='http://testserver')
r = await client.get('/platform/monitor/liveness')
assert r.status_code == 200
diff --git a/backend-services/tests/test_proto_extension.py b/backend-services/tests/test_proto_extension.py
index d3ee079..d4f4d5d 100644
--- a/backend-services/tests/test_proto_extension.py
+++ b/backend-services/tests/test_proto_extension.py
@@ -1,21 +1,21 @@
-import io
import pytest
+
@pytest.mark.asyncio
async def test_proto_upload_rejects_non_proto(authed_client):
- files = {'file': ('bad.txt', b'syntax = \"proto3\";', 'text/plain')}
+ files = {'file': ('bad.txt', b'syntax = "proto3";', 'text/plain')}
r = await authed_client.post('/platform/proto/sample/v1', files=files)
assert r.status_code == 400
body = r.json()
assert body.get('error_code') == 'REQ003'
+
@pytest.mark.asyncio
async def test_proto_upload_accepts_proto(authed_client):
- content = b'syntax = \"proto3\";\npackage sample_v1;\nmessage Ping { string msg = 1; }'
+ content = b'syntax = "proto3";\npackage sample_v1;\nmessage Ping { string msg = 1; }'
files = {'file': ('ok.proto', content, 'application/octet-stream')}
r = await authed_client.post('/platform/proto/sample/v1', files=files)
assert r.status_code in (200, 500)
if r.status_code == 200:
body = r.json()
assert body.get('message', '').lower().startswith('proto file uploaded')
-
diff --git a/backend-services/tests/test_proto_routes.py b/backend-services/tests/test_proto_routes.py
index e46ca7e..7e41bac 100644
--- a/backend-services/tests/test_proto_routes.py
+++ b/backend-services/tests/test_proto_routes.py
@@ -1,14 +1,18 @@
import io
+
import pytest
+
@pytest.mark.asyncio
async def test_proto_upload_and_get(monkeypatch, authed_client):
import routes.proto_routes as pr
class _FakeCompleted:
pass
+
def _fake_run(*args, **kwargs):
return _FakeCompleted()
+
monkeypatch.setattr(pr.subprocess, 'run', _fake_run)
proto_content = b"""
@@ -22,5 +26,4 @@ async def test_proto_upload_and_get(monkeypatch, authed_client):
gp = await authed_client.get('/platform/proto/myapi/v1')
assert gp.status_code == 200
content = gp.json().get('content') or gp.json().get('response', {}).get('content')
- assert 'syntax = \"proto3\";' in content
-
+ assert 'syntax = "proto3";' in content
diff --git a/backend-services/tests/test_proto_routes_extended.py b/backend-services/tests/test_proto_routes_extended.py
index 427d184..33d7193 100644
--- a/backend-services/tests/test_proto_routes_extended.py
+++ b/backend-services/tests/test_proto_routes_extended.py
@@ -1,28 +1,33 @@
import pytest
+
@pytest.mark.asyncio
async def test_proto_update_and_delete_flow(monkeypatch, authed_client):
-
import routes.proto_routes as pr
- class _FakeCompleted: pass
+
+ class _FakeCompleted:
+ pass
+
def _fake_run(*args, **kwargs):
return _FakeCompleted()
+
monkeypatch.setattr(pr.subprocess, 'run', _fake_run)
- r = await authed_client.put(
- '/platform/role/admin',
- json={'manage_apis': True},
- )
+ r = await authed_client.put('/platform/role/admin', json={'manage_apis': True})
assert r.status_code in (200, 201)
files = {
- 'file': ('sample.proto', b'syntax = \"proto3\"; message Ping { string x = 1; }', 'text/plain'),
+ 'file': ('sample.proto', b'syntax = "proto3"; message Ping { string x = 1; }', 'text/plain')
}
up = await authed_client.post('/platform/proto/myapi2/v1', files=files)
assert up.status_code in (200, 201), up.text
files2 = {
- 'proto_file': ('sample.proto', b'syntax = \"proto3\"; message Pong { string y = 1; }', 'text/plain'),
+ 'proto_file': (
+ 'sample.proto',
+ b'syntax = "proto3"; message Pong { string y = 1; }',
+ 'text/plain',
+ )
}
put = await authed_client.put('/platform/proto/myapi2/v1', files=files2)
assert put.status_code in (200, 201), put.text
@@ -39,6 +44,7 @@ async def test_proto_update_and_delete_flow(monkeypatch, authed_client):
gp2 = await authed_client.get('/platform/proto/myapi2/v1')
assert gp2.status_code == 404
+
@pytest.mark.asyncio
async def test_proto_get_nonexistent_returns_404(authed_client):
resp = await authed_client.get('/platform/proto/doesnotexist/v9')
diff --git a/backend-services/tests/test_proto_upload_security_and_import.py b/backend-services/tests/test_proto_upload_security_and_import.py
index 2ec02e8..6819c21 100644
--- a/backend-services/tests/test_proto_upload_security_and_import.py
+++ b/backend-services/tests/test_proto_upload_security_and_import.py
@@ -1,9 +1,10 @@
import pytest
-from pathlib import Path
+
@pytest.mark.asyncio
async def test_proto_upload_rejects_invalid_filename(monkeypatch, authed_client):
import routes.proto_routes as pr
+
monkeypatch.setattr(pr, 'sanitize_filename', lambda s: (_ for _ in ()).throw(ValueError('bad')))
files = {'file': ('svc.proto', b'syntax = "proto3"; package x;')}
r = await authed_client.post('/platform/proto/bad/v1', files=files)
@@ -11,69 +12,82 @@ async def test_proto_upload_rejects_invalid_filename(monkeypatch, authed_client)
body = r.json()
assert body.get('error_code')
+
@pytest.mark.asyncio
async def test_proto_upload_validates_within_base_path():
import routes.proto_routes as pr
+
base = (pr.PROJECT_ROOT / 'proto').resolve()
good = (base / 'ok.proto').resolve()
bad = (pr.PROJECT_ROOT.parent / 'outside.proto').resolve()
assert pr.validate_path(pr.PROJECT_ROOT, good) is True
assert pr.validate_path(pr.PROJECT_ROOT, bad) is False
+
@pytest.mark.asyncio
async def test_proto_upload_generates_stubs_success(monkeypatch, authed_client):
name, ver = 'psvc1', 'v1'
proto = b'syntax = "proto3"; package foo; service S { rpc M (R) returns (Q) {} } message R { string a = 1; } message Q { string b = 1; }'
files = {'file': ('svc.proto', proto)}
import routes.proto_routes as pr
+
safe = f'{name}_{ver}'
gen = (pr.PROJECT_ROOT / 'generated').resolve()
+
def _fake_run(cmd, check):
(gen / f'{safe}_pb2.py').write_text('# pb2')
(gen / f'{safe}_pb2_grpc.py').write_text(
- f'import {safe}_pb2 as {name}__{ver}__pb2\n'
- 'class S: pass\n'
+ f'import {safe}_pb2 as {name}__{ver}__pb2\nclass S: pass\n'
)
return 0
+
monkeypatch.setattr(pr.subprocess, 'run', _fake_run)
r = await authed_client.post(f'/platform/proto/{name}/{ver}', files=files)
assert r.status_code == 200
import routes.proto_routes as pr
+
safe = f'{name}_{ver}'
gen = (pr.PROJECT_ROOT / 'generated').resolve()
assert (gen / f'{safe}_pb2.py').exists()
assert (gen / f'{safe}_pb2_grpc.py').exists()
+
@pytest.mark.asyncio
async def test_proto_upload_rewrite_pb2_imports_for_generated_namespace(monkeypatch, authed_client):
name, ver = 'psvc2', 'v1'
proto = b'syntax = "proto3"; package foo; service S { rpc M (R) returns (Q) {} } message R { string a = 1; } message Q { string b = 1; }'
files = {'file': ('svc.proto', proto)}
import routes.proto_routes as pr
+
safe = f'{name}_{ver}'
gen = (pr.PROJECT_ROOT / 'generated').resolve()
+
def _fake_run(cmd, check):
(gen / f'{safe}_pb2.py').write_text('# pb2')
(gen / f'{safe}_pb2_grpc.py').write_text(
- f'import {safe}_pb2 as {name}__{ver}__pb2\n'
- 'class S: pass\n'
+ f'import {safe}_pb2 as {name}__{ver}__pb2\nclass S: pass\n'
)
return 0
+
monkeypatch.setattr(pr.subprocess, 'run', _fake_run)
r = await authed_client.post(f'/platform/proto/{name}/{ver}', files=files)
assert r.status_code == 200
import routes.proto_routes as pr
+
safe = f'{name}_{ver}'
gen = (pr.PROJECT_ROOT / 'generated').resolve()
- pb2g = (gen / f'{safe}_pb2_grpc.py')
+ pb2g = gen / f'{safe}_pb2_grpc.py'
txt = pb2g.read_text()
assert f'from generated import {safe}_pb2 as {name}__{ver}__pb2' in txt
+
@pytest.mark.asyncio
async def test_proto_get_requires_permission(monkeypatch, authed_client):
import routes.proto_routes as pr
+
async def _no_perm(*args, **kwargs):
return False
+
monkeypatch.setattr(pr, 'platform_role_required_bool', _no_perm)
r = await authed_client.get('/platform/proto/x/v1')
assert r.status_code == 403
diff --git a/backend-services/tests/test_public_apis.py b/backend-services/tests/test_public_apis.py
index 7e03c2e..52f137c 100644
--- a/backend-services/tests/test_public_apis.py
+++ b/backend-services/tests/test_public_apis.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_rest_public_api_allows_unauthenticated(client, authed_client):
-
name, ver = 'pubrest', 'v1'
cr = await authed_client.post(
'/platform/api',
@@ -32,6 +32,7 @@ async def test_rest_public_api_allows_unauthenticated(client, authed_client):
assert r.status_code in (200, 400, 404, 429, 500)
+
@pytest.mark.asyncio
async def test_graphql_public_api_allows_unauthenticated(client, authed_client):
name, ver = 'pubgql', 'v1'
@@ -55,6 +56,7 @@ async def test_graphql_public_api_allows_unauthenticated(client, authed_client):
)
assert r.status_code in (200, 400, 404, 429, 500)
+
@pytest.mark.asyncio
async def test_public_api_bypasses_credits_check(client, authed_client):
name, ver = 'pubcredits', 'v1'
@@ -102,6 +104,7 @@ async def test_public_api_bypasses_credits_check(client, authed_client):
r = await client.get(f'/api/rest/{name}/{ver}/ping')
assert r.status_code != 401
+
@pytest.mark.asyncio
async def test_auth_not_required_but_not_public(client, authed_client):
name, ver = 'noauthsub', 'v1'
diff --git a/backend-services/tests/test_rate_limit_and_throttle_combination.py b/backend-services/tests/test_rate_limit_and_throttle_combination.py
index 48d92fd..04aee46 100644
--- a/backend-services/tests/test_rate_limit_and_throttle_combination.py
+++ b/backend-services/tests/test_rate_limit_and_throttle_combination.py
@@ -1,21 +1,28 @@
-import pytest
import time
+import pytest
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
@pytest.mark.asyncio
async def test_rate_limit_blocks_second_request_in_window(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'rlblock', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/p')
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}},
+ )
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
ok = await authed_client.get(f'/api/rest/{name}/{ver}/p')
@@ -23,55 +30,77 @@ async def test_rate_limit_blocks_second_request_in_window(monkeypatch, authed_cl
blocked = await authed_client.get(f'/api/rest/{name}/{ver}/p')
assert blocked.status_code == 429
+
@pytest.mark.asyncio
async def test_throttle_queue_limit_exceeded_returns_429(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'tqex', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/t')
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {'throttle_duration': 1, 'throttle_queue_limit': 1}})
+
+ user_collection.update_one(
+ {'username': 'admin'}, {'$set': {'throttle_duration': 1, 'throttle_queue_limit': 1}}
+ )
await authed_client.delete('/api/caches')
from utils.limit_throttle_util import reset_counters
+
reset_counters()
now_ms = int(time.time() * 1000)
wait_ms = 1000 - (now_ms % 1000) + 350
time.sleep(wait_ms / 1000.0)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r1 = await authed_client.get(f'/api/rest/{name}/{ver}/t')
- assert r1.status_code == 200, f"First request failed with {r1.status_code}"
+ assert r1.status_code == 200, f'First request failed with {r1.status_code}'
r2 = await authed_client.get(f'/api/rest/{name}/{ver}/t')
- assert r2.status_code == 429, f"Second request should have been throttled but got {r2.status_code}"
+ assert r2.status_code == 429, (
+ f'Second request should have been throttled but got {r2.status_code}'
+ )
+
@pytest.mark.asyncio
async def test_throttle_dynamic_wait_increases_latency(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'twait', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/w')
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {
- 'throttle_duration': 1, 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 10,
- 'throttle_wait_duration': 0.1, 'throttle_wait_duration_type': 'second',
- 'rate_limit_duration': 1000, 'rate_limit_duration_type': 'second'
- }})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {
+ '$set': {
+ 'throttle_duration': 1,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 10,
+ 'throttle_wait_duration': 0.1,
+ 'throttle_wait_duration_type': 'second',
+ 'rate_limit_duration': 1000,
+ 'rate_limit_duration_type': 'second',
+ }
+ },
+ )
await authed_client.delete('/api/caches')
from utils.limit_throttle_util import reset_counters
+
reset_counters()
now_ms = int(time.time() * 1000)
wait_ms = 1000 - (now_ms % 1000) + 350
time.sleep(wait_ms / 1000.0)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
t0 = time.perf_counter()
@@ -88,19 +117,26 @@ async def test_throttle_dynamic_wait_increases_latency(monkeypatch, authed_clien
assert dur2 >= dur1 + 0.08
+
@pytest.mark.asyncio
async def test_rate_limit_window_rollover_allows_requests(monkeypatch, authed_client):
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'rlroll', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/x')
await subscribe_self(authed_client, name, ver)
from utils.database import user_collection
- user_collection.update_one({'username': 'admin'}, {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}})
+
+ user_collection.update_one(
+ {'username': 'admin'},
+ {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second'}},
+ )
await authed_client.delete('/api/caches')
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r1 = await authed_client.get(f'/api/rest/{name}/{ver}/x')
diff --git a/backend-services/tests/test_rate_limit_fallback.py b/backend-services/tests/test_rate_limit_fallback.py
index 98fc581..b73d2db 100644
--- a/backend-services/tests/test_rate_limit_fallback.py
+++ b/backend-services/tests/test_rate_limit_fallback.py
@@ -2,6 +2,7 @@ import asyncio
from utils.limit_throttle_util import InMemoryWindowCounter
+
async def _inc(counter: InMemoryWindowCounter, key: str, times: int, ttl: int):
counts = []
for _ in range(times):
@@ -10,6 +11,7 @@ async def _inc(counter: InMemoryWindowCounter, key: str, times: int, ttl: int):
await counter.expire(key, ttl)
return counts
+
def test_inmemory_counter_increments_and_expires(event_loop):
c = InMemoryWindowCounter()
counts = event_loop.run_until_complete(_inc(c, 'k1', 3, 1))
@@ -19,4 +21,3 @@ def test_inmemory_counter_increments_and_expires(event_loop):
counts2 = event_loop.run_until_complete(_inc(c, 'k1', 2, 1))
assert counts2[0] == 1
assert counts2[1] == 2
-
diff --git a/backend-services/tests/test_rate_limiter.py b/backend-services/tests/test_rate_limiter.py
index c8dc670..5a279d5 100644
--- a/backend-services/tests/test_rate_limiter.py
+++ b/backend-services/tests/test_rate_limiter.py
@@ -8,24 +8,22 @@ Comprehensive tests for rate limiting functionality including:
- Load tests for distributed scenarios
"""
-import pytest
import asyncio
import time
-from datetime import datetime, timedelta
-from unittest.mock import Mock, patch, MagicMock
+from unittest.mock import MagicMock, Mock
-from utils.rate_limiter import RateLimiter, RateLimitResult
-from utils.quota_tracker import QuotaTracker
+import pytest
+
+from models.rate_limit_models import QuotaType, RateLimitRule, RuleType, TimeWindow
from utils.ip_rate_limiter import IPRateLimiter
-from models.rate_limit_models import (
- RateLimitRule, RuleType, TimeWindow, QuotaType
-)
-
+from utils.quota_tracker import QuotaTracker
+from utils.rate_limiter import RateLimiter
# ============================================================================
# FIXTURES
# ============================================================================
+
@pytest.fixture
def mock_redis():
"""Mock Redis client for testing"""
@@ -60,12 +58,12 @@ def ip_limiter(mock_redis):
def sample_rule():
"""Sample rate limit rule"""
return RateLimitRule(
- rule_id="test_rule",
+ rule_id='test_rule',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
limit=100,
burst_allowance=20,
- target_identifier="test_user"
+ target_identifier='test_user',
)
@@ -73,60 +71,61 @@ def sample_rule():
# UNIT TESTS - RATE LIMITER
# ============================================================================
+
class TestRateLimiter:
"""Unit tests for RateLimiter class"""
-
+
def test_sliding_window_within_limit(self, rate_limiter, sample_rule, mock_redis):
"""Test sliding window allows requests within limit"""
- mock_redis.get.return_value = "50" # Current count
-
- result = rate_limiter._check_sliding_window(sample_rule, "test_user")
-
+ mock_redis.get.return_value = '50' # Current count
+
+ result = rate_limiter._check_sliding_window(sample_rule, 'test_user')
+
assert result.allowed is True
assert result.limit == 100
assert result.remaining == 50
-
+
def test_sliding_window_exceeds_limit(self, rate_limiter, sample_rule, mock_redis):
"""Test sliding window blocks requests exceeding limit"""
- mock_redis.get.return_value = "100" # At limit
-
- result = rate_limiter._check_sliding_window(sample_rule, "test_user")
-
+ mock_redis.get.return_value = '100' # At limit
+
+ result = rate_limiter._check_sliding_window(sample_rule, 'test_user')
+
assert result.allowed is False
assert result.remaining == 0
-
+
def test_token_bucket_allows_burst(self, rate_limiter, sample_rule, mock_redis):
"""Test token bucket allows burst requests"""
- mock_redis.get.side_effect = ["100", "5"] # Normal at limit, burst available
-
- result = rate_limiter.check_token_bucket(sample_rule, "test_user")
-
+ mock_redis.get.side_effect = ['100', '5'] # Normal at limit, burst available
+
+ result = rate_limiter.check_token_bucket(sample_rule, 'test_user')
+
assert result.allowed is True
-
+
def test_burst_tokens_exhausted(self, rate_limiter, sample_rule, mock_redis):
"""Test burst tokens can be exhausted"""
- mock_redis.get.side_effect = ["100", "20"] # Normal at limit, burst exhausted
-
- result = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock())
-
+ mock_redis.get.side_effect = ['100', '20'] # Normal at limit, burst exhausted
+
+ result = rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock())
+
assert result.allowed is False
-
+
def test_hybrid_check_normal_flow(self, rate_limiter, sample_rule, mock_redis):
"""Test hybrid check in normal flow"""
- mock_redis.get.return_value = "50"
-
- result = rate_limiter.check_hybrid(sample_rule, "test_user")
-
+ mock_redis.get.return_value = '50'
+
+ result = rate_limiter.check_hybrid(sample_rule, 'test_user')
+
assert result.allowed is True
-
+
def test_hybrid_check_uses_burst(self, rate_limiter, sample_rule, mock_redis):
"""Test hybrid check falls back to burst tokens"""
# First call: normal limit reached
# Second call: burst available
- mock_redis.get.side_effect = ["100", "5"]
-
- result = rate_limiter.check_hybrid(sample_rule, "test_user")
-
+ mock_redis.get.side_effect = ['100', '5']
+
+ rate_limiter.check_hybrid(sample_rule, 'test_user')
+
# Should attempt to use burst tokens
assert mock_redis.get.call_count >= 1
@@ -135,75 +134,52 @@ class TestRateLimiter:
# UNIT TESTS - QUOTA TRACKER
# ============================================================================
+
class TestQuotaTracker:
"""Unit tests for QuotaTracker class"""
-
+
def test_quota_within_limit(self, quota_tracker, mock_redis):
"""Test quota check within limit"""
- mock_redis.get.return_value = "5000"
-
- result = quota_tracker.check_quota(
- "test_user",
- QuotaType.REQUESTS,
- 10000,
- "month"
- )
-
+ mock_redis.get.return_value = '5000'
+
+ result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month')
+
assert result.is_exhausted is False
assert result.is_warning is False
assert result.remaining == 5000
-
+
def test_quota_warning_threshold(self, quota_tracker, mock_redis):
"""Test quota warning at 80%"""
- mock_redis.get.return_value = "8500" # 85% used
-
- result = quota_tracker.check_quota(
- "test_user",
- QuotaType.REQUESTS,
- 10000,
- "month"
- )
-
+ mock_redis.get.return_value = '8500' # 85% used
+
+ result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month')
+
assert result.is_warning is True
assert result.is_critical is False
-
+
def test_quota_critical_threshold(self, quota_tracker, mock_redis):
"""Test quota critical at 95%"""
- mock_redis.get.return_value = "9600" # 96% used
-
- result = quota_tracker.check_quota(
- "test_user",
- QuotaType.REQUESTS,
- 10000,
- "month"
- )
-
+ mock_redis.get.return_value = '9600' # 96% used
+
+ result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month')
+
assert result.is_critical is True
-
+
def test_quota_exhausted(self, quota_tracker, mock_redis):
"""Test quota exhausted at 100%"""
- mock_redis.get.return_value = "10000"
-
- result = quota_tracker.check_quota(
- "test_user",
- QuotaType.REQUESTS,
- 10000,
- "month"
- )
-
+ mock_redis.get.return_value = '10000'
+
+ result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month')
+
assert result.is_exhausted is True
assert result.remaining == 0
-
+
def test_quota_increment(self, quota_tracker, mock_redis):
"""Test quota increment"""
mock_redis.incr.return_value = 101
-
- new_usage = quota_tracker.increment_quota(
- "test_user",
- QuotaType.REQUESTS,
- "month"
- )
-
+
+ new_usage = quota_tracker.increment_quota('test_user', QuotaType.REQUESTS, 'month')
+
assert new_usage == 101
mock_redis.incr.assert_called_once()
@@ -212,60 +188,61 @@ class TestQuotaTracker:
# UNIT TESTS - IP RATE LIMITER
# ============================================================================
+
class TestIPRateLimiter:
"""Unit tests for IPRateLimiter class"""
-
+
def test_extract_ip_from_forwarded_for(self, ip_limiter):
"""Test IP extraction from X-Forwarded-For"""
request = Mock()
- request.headers.get.return_value = "192.168.1.1, 10.0.0.1"
-
+ request.headers.get.return_value = '192.168.1.1, 10.0.0.1'
+
ip = ip_limiter.extract_client_ip(request)
-
- assert ip == "192.168.1.1"
-
+
+ assert ip == '192.168.1.1'
+
def test_extract_ip_from_real_ip(self, ip_limiter):
"""Test IP extraction from X-Real-IP"""
request = Mock()
- request.headers.get.side_effect = [None, "192.168.1.1"]
-
+ request.headers.get.side_effect = [None, '192.168.1.1']
+
ip = ip_limiter.extract_client_ip(request)
-
- assert ip == "192.168.1.1"
-
+
+ assert ip == '192.168.1.1'
+
def test_whitelist_bypasses_limit(self, ip_limiter, mock_redis):
"""Test whitelisted IP bypasses rate limit"""
mock_redis.sismember.return_value = True
-
- result = ip_limiter.check_ip_rate_limit("192.168.1.1")
-
+
+ result = ip_limiter.check_ip_rate_limit('192.168.1.1')
+
assert result.allowed is True
assert result.limit == 999999
-
+
def test_blacklist_blocks_request(self, ip_limiter, mock_redis):
"""Test blacklisted IP is blocked"""
mock_redis.sismember.side_effect = [False, True] # Not whitelisted, is blacklisted
-
- result = ip_limiter.check_ip_rate_limit("10.0.0.1")
-
+
+ result = ip_limiter.check_ip_rate_limit('10.0.0.1')
+
assert result.allowed is False
-
+
def test_reputation_reduces_limits(self, ip_limiter, mock_redis):
"""Test low reputation reduces rate limits"""
mock_redis.sismember.return_value = False
- mock_redis.get.side_effect = ["30", "0", "0"] # Low reputation, no requests yet
-
- result = ip_limiter.check_ip_rate_limit("10.0.0.1")
-
+ mock_redis.get.side_effect = ['30', '0', '0'] # Low reputation, no requests yet
+
+ result = ip_limiter.check_ip_rate_limit('10.0.0.1')
+
# Limits should be reduced due to low reputation
assert result.allowed is True
-
+
def test_reputation_score_update(self, ip_limiter, mock_redis):
"""Test reputation score update"""
- mock_redis.get.return_value = "100"
-
- new_score = ip_limiter.update_reputation_score("192.168.1.1", -10)
-
+ mock_redis.get.return_value = '100'
+
+ new_score = ip_limiter.update_reputation_score('192.168.1.1', -10)
+
assert new_score == 90
mock_redis.setex.assert_called_once()
@@ -274,70 +251,65 @@ class TestIPRateLimiter:
# INTEGRATION TESTS
# ============================================================================
+
class TestRateLimitIntegration:
"""Integration tests for rate limiting system"""
-
+
@pytest.mark.asyncio
async def test_concurrent_requests(self, rate_limiter, sample_rule, mock_redis):
"""Test concurrent requests handling"""
- mock_redis.get.return_value = "0"
-
+ mock_redis.get.return_value = '0'
+
# Simulate 10 concurrent requests
tasks = []
for i in range(10):
task = asyncio.create_task(
- asyncio.to_thread(
- rate_limiter.check_hybrid,
- sample_rule,
- f"user_{i}"
- )
+ asyncio.to_thread(rate_limiter.check_hybrid, sample_rule, f'user_{i}')
)
tasks.append(task)
-
+
results = await asyncio.gather(*tasks)
-
+
# All should be allowed (different users)
assert all(r.allowed for r in results)
-
+
def test_burst_handling_sequence(self, rate_limiter, sample_rule, mock_redis):
"""Test burst handling in sequence"""
# Simulate reaching normal limit, then using burst
mock_redis.get.side_effect = [
- "99", "100", "0", # Normal flow
- "100", "5" # Burst flow
+ '99',
+ '100',
+ '0', # Normal flow
+ '100',
+ '5', # Burst flow
]
-
+
# First request: within normal limit
- result1 = rate_limiter.check_hybrid(sample_rule, "test_user")
+ result1 = rate_limiter.check_hybrid(sample_rule, 'test_user')
assert result1.allowed is True
-
+
# Second request: at normal limit, use burst
- result2 = rate_limiter.check_hybrid(sample_rule, "test_user")
+ rate_limiter.check_hybrid(sample_rule, 'test_user')
# Should attempt burst tokens
-
+
def test_quota_and_rate_limit_interaction(self, quota_tracker, rate_limiter, mock_redis):
"""Test interaction between quota and rate limits"""
- mock_redis.get.return_value = "50"
-
+ mock_redis.get.return_value = '50'
+
# Check rate limit
rate_result = rate_limiter._check_sliding_window(
RateLimitRule(
- rule_id="test",
+ rule_id='test',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
- limit=100
+ limit=100,
),
- "test_user"
+ 'test_user',
)
-
+
# Check quota
- quota_result = quota_tracker.check_quota(
- "test_user",
- QuotaType.REQUESTS,
- 10000,
- "month"
- )
-
+ quota_result = quota_tracker.check_quota('test_user', QuotaType.REQUESTS, 10000, 'month')
+
# Both should allow
assert rate_result.allowed is True
assert quota_result.is_exhausted is False
@@ -347,45 +319,43 @@ class TestRateLimitIntegration:
# LOAD TESTS
# ============================================================================
+
class TestRateLimitLoad:
"""Load tests for rate limiting system"""
-
+
def test_high_volume_requests(self, rate_limiter, sample_rule, mock_redis):
"""Test handling high volume of requests"""
- mock_redis.get.return_value = "0"
-
+ mock_redis.get.return_value = '0'
+
start_time = time.time()
-
+
# Simulate 1000 requests
for i in range(1000):
- result = rate_limiter.check_hybrid(sample_rule, f"user_{i % 100}")
-
+ rate_limiter.check_hybrid(sample_rule, f'user_{i % 100}')
+
elapsed = time.time() - start_time
-
+
# Should complete in reasonable time (< 1 second for 1000 requests)
assert elapsed < 1.0
-
+
def test_distributed_scenario(self, rate_limiter, mock_redis):
"""Test distributed rate limiting scenario"""
# Simulate multiple servers checking same user
rules = [
RateLimitRule(
- rule_id=f"rule_{i}",
+ rule_id=f'rule_{i}',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
- limit=100
+ limit=100,
)
for i in range(5)
]
-
- mock_redis.get.return_value = "50"
-
+
+ mock_redis.get.return_value = '50'
+
# All servers should get consistent results
- results = [
- rate_limiter.check_hybrid(rule, "test_user")
- for rule in rules
- ]
-
+ results = [rate_limiter.check_hybrid(rule, 'test_user') for rule in rules]
+
# All should have same limit
assert all(r.limit == 100 for r in results)
@@ -394,38 +364,39 @@ class TestRateLimitLoad:
# BURST HANDLING TESTS
# ============================================================================
+
class TestBurstHandling:
"""Specific tests for burst handling"""
-
+
def test_burst_allows_spike(self, rate_limiter, sample_rule, mock_redis):
"""Test burst allows temporary spike"""
# Normal limit reached
- mock_redis.get.side_effect = ["100", "0"]
-
- result = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock())
-
+ mock_redis.get.side_effect = ['100', '0']
+
+ result = rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock())
+
assert result.allowed is True
-
+
def test_burst_refills_over_time(self, rate_limiter, sample_rule, mock_redis):
"""Test burst tokens refill over time"""
# Burst used, then check again after time window
- mock_redis.get.side_effect = ["20", "0"] # Burst exhausted, then refilled
-
+ mock_redis.get.side_effect = ['20', '0'] # Burst exhausted, then refilled
+
# First check: exhausted
- result1 = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock())
-
+ rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock())
+
# Simulate time passing (would reset in Redis)
- mock_redis.get.side_effect = ["0"]
-
+ mock_redis.get.side_effect = ['0']
+
# Second check: refilled
- result2 = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock())
-
+ rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock())
+
def test_burst_tracking(self, rate_limiter, sample_rule, mock_redis):
"""Test burst usage is tracked separately"""
- mock_redis.get.return_value = "5"
-
- result = rate_limiter._use_burst_tokens(sample_rule, "test_user", Mock())
-
+ mock_redis.get.return_value = '5'
+
+ rate_limiter._use_burst_tokens(sample_rule, 'test_user', Mock())
+
# Should track burst usage
mock_redis.incr.assert_called()
@@ -434,39 +405,33 @@ class TestBurstHandling:
# PERFORMANCE TESTS
# ============================================================================
+
class TestPerformance:
"""Performance optimization tests"""
-
+
def test_redis_pipeline_usage(self, rate_limiter, mock_redis):
"""Test Redis pipeline is used for batch operations"""
rule = RateLimitRule(
- rule_id="test",
- rule_type=RuleType.PER_USER,
- time_window=TimeWindow.MINUTE,
- limit=100
+ rule_id='test', rule_type=RuleType.PER_USER, time_window=TimeWindow.MINUTE, limit=100
)
-
- mock_redis.get.return_value = "50"
-
- rate_limiter.check_hybrid(rule, "test_user")
-
+
+ mock_redis.get.return_value = '50'
+
+ rate_limiter.check_hybrid(rule, 'test_user')
+
# Pipeline should be used for atomic operations
# (Implementation dependent)
-
+
def test_counter_increment_efficiency(self, quota_tracker, mock_redis):
"""Test counter increment is efficient"""
start_time = time.time()
-
+
# Increment 100 times
for i in range(100):
- quota_tracker.increment_quota(
- f"user_{i}",
- QuotaType.REQUESTS,
- "month"
- )
-
+ quota_tracker.increment_quota(f'user_{i}', QuotaType.REQUESTS, 'month')
+
elapsed = time.time() - start_time
-
+
# Should be fast (< 0.1 seconds)
assert elapsed < 0.1
@@ -475,33 +440,34 @@ class TestPerformance:
# ERROR HANDLING TESTS
# ============================================================================
+
class TestErrorHandling:
"""Test error handling and graceful degradation"""
-
+
def test_redis_connection_failure(self, rate_limiter, sample_rule):
"""Test graceful handling of Redis connection failure"""
redis = Mock()
- redis.get.side_effect = Exception("Connection failed")
-
+ redis.get.side_effect = Exception('Connection failed')
+
limiter = RateLimiter(redis_client=redis)
-
+
# Should not crash, should allow by default (fail open)
- result = limiter.check_hybrid(sample_rule, "test_user")
-
+ result = limiter.check_hybrid(sample_rule, 'test_user')
+
# Graceful degradation: allow request
assert result is not None
-
+
def test_invalid_rule_parameters(self, rate_limiter):
"""Test handling of invalid rule parameters"""
invalid_rule = RateLimitRule(
- rule_id="invalid",
+ rule_id='invalid',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
- limit=0 # Invalid limit
+ limit=0, # Invalid limit
)
-
+
# Should handle gracefully
- result = rate_limiter.check_hybrid(invalid_rule, "test_user")
+ result = rate_limiter.check_hybrid(invalid_rule, 'test_user')
assert result is not None
@@ -509,5 +475,5 @@ class TestErrorHandling:
# RUN TESTS
# ============================================================================
-if __name__ == "__main__":
- pytest.main([__file__, "-v", "--tb=short"])
+if __name__ == '__main__':
+ pytest.main([__file__, '-v', '--tb=short'])
diff --git a/backend-services/tests/test_real_compression_benchmarks.py b/backend-services/tests/test_real_compression_benchmarks.py
index 3b2d151..41cc0eb 100644
--- a/backend-services/tests/test_real_compression_benchmarks.py
+++ b/backend-services/tests/test_real_compression_benchmarks.py
@@ -8,13 +8,14 @@ This test creates realistic API payloads and measures:
4. Realistic throughput estimates
"""
-import pytest
+import asyncio
import gzip
-import json
import io
+import json
import os
import time
-import asyncio
+
+import pytest
@pytest.mark.asyncio
@@ -23,7 +24,7 @@ async def test_realistic_rest_api_response_compression(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r = await client.post('/platform/authorization', json=login_payload)
@@ -40,7 +41,7 @@ async def test_realistic_rest_api_response_compression(client):
'api_servers': [
'https://api-primary.example.com',
'https://api-secondary.example.com',
- 'https://api-backup.example.com'
+ 'https://api-backup.example.com',
],
'api_type': 'REST',
'api_allowed_retry_count': 3,
@@ -49,7 +50,7 @@ async def test_realistic_rest_api_response_compression(client):
'api_credits_enabled': False,
'api_rate_limit_enabled': True,
'api_rate_limit_requests': 1000,
- 'api_rate_limit_window': 60
+ 'api_rate_limit_window': 60,
}
r = await client.post('/platform/api', json=api_payload)
@@ -74,24 +75,26 @@ async def test_realistic_rest_api_response_compression(client):
results[level] = {
'compressed_size': compressed_size,
'ratio': ratio,
- 'time_ms': compression_time
+ 'time_ms': compression_time,
}
- print(f"\n{'='*70}")
- print(f"REALISTIC REST API RESPONSE COMPRESSION BENCHMARK")
- print(f"{'='*70}")
- print(f"Uncompressed size: {uncompressed_size:,} bytes")
- print(f"\nCompression Results:")
+ print(f'\n{"=" * 70}')
+ print('REALISTIC REST API RESPONSE COMPRESSION BENCHMARK')
+ print(f'{"=" * 70}')
+ print(f'Uncompressed size: {uncompressed_size:,} bytes')
+ print('\nCompression Results:')
for level, result in results.items():
- print(f" Level {level}: {result['compressed_size']:,} bytes "
- f"({result['ratio']:.1f}% reduction) "
- f"in {result['time_ms']:.3f}ms")
+ print(
+ f' Level {level}: {result["compressed_size"]:,} bytes '
+ f'({result["ratio"]:.1f}% reduction) '
+ f'in {result["time_ms"]:.3f}ms'
+ )
# Cleanup
await client.delete(f'/platform/api/{api_name}/v1')
# Assertions
- assert results[6]['ratio'] > 0, "Should achieve some compression"
+ assert results[6]['ratio'] > 0, 'Should achieve some compression'
@pytest.mark.asyncio
@@ -100,7 +103,7 @@ async def test_typical_api_gateway_request_flow(client):
# Login
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
# Measure login response
@@ -127,20 +130,22 @@ async def test_typical_api_gateway_request_flow(client):
gz.write(health_json.encode('utf-8'))
health_compressed = len(compressed_buffer.getvalue())
- print(f"\n{'='*70}")
- print(f"TYPICAL API GATEWAY REQUEST FLOW")
- print(f"{'='*70}")
- print(f"\nLogin (POST /platform/authorization):")
- print(f" Uncompressed: {login_uncompressed:,} bytes")
- print(f" Compressed: {login_compressed:,} bytes")
- print(f" Ratio: {(1 - login_compressed/login_uncompressed)*100:.1f}% reduction")
- print(f"\nHealth Check (GET /api/health):")
- print(f" Uncompressed: {health_uncompressed:,} bytes")
- print(f" Compressed: {health_compressed:,} bytes")
+ print(f'\n{"=" * 70}')
+ print('TYPICAL API GATEWAY REQUEST FLOW')
+ print(f'{"=" * 70}')
+ print('\nLogin (POST /platform/authorization):')
+ print(f' Uncompressed: {login_uncompressed:,} bytes')
+ print(f' Compressed: {login_compressed:,} bytes')
+ print(f' Ratio: {(1 - login_compressed / login_uncompressed) * 100:.1f}% reduction')
+ print('\nHealth Check (GET /api/health):')
+ print(f' Uncompressed: {health_uncompressed:,} bytes')
+ print(f' Compressed: {health_compressed:,} bytes')
if health_uncompressed > 500:
- print(f" Ratio: {(1 - health_compressed/health_uncompressed)*100:.1f}% reduction")
+ print(
+ f' Ratio: {(1 - health_compressed / health_uncompressed) * 100:.1f}% reduction'
+ )
else:
- print(f" Note: Below 500 byte minimum - not compressed in production")
+ print(' Note: Below 500 byte minimum - not compressed in production')
@pytest.mark.asyncio
@@ -148,7 +153,7 @@ async def test_large_list_response_compression(client):
"""Test compression on large list responses (multiple APIs)"""
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r = await client.post('/platform/authorization', json=login_payload)
@@ -166,11 +171,11 @@ async def test_large_list_response_compression(client):
'api_allowed_groups': ['ALL', 'TEAM_A', 'TEAM_B'],
'api_servers': [
f'https://api-{i}-primary.example.com',
- f'https://api-{i}-secondary.example.com'
+ f'https://api-{i}-secondary.example.com',
],
'api_type': 'REST',
'api_allowed_retry_count': 3,
- 'active': True
+ 'active': True,
}
r = await client.post('/platform/api', json=api_payload)
if r.status_code in (200, 201):
@@ -193,10 +198,10 @@ async def test_large_list_response_compression(client):
'https://api-prod-1.example.com',
'https://api-prod-2.example.com',
'https://api-prod-3.example.com',
- 'https://api-dr.example.com'
+ 'https://api-dr.example.com',
],
'api_type': 'REST',
- 'active': True
+ 'active': True,
}
r = await client.post('/platform/api', json=api_payload)
@@ -209,14 +214,14 @@ async def test_large_list_response_compression(client):
gz.write(json_str.encode('utf-8'))
compressed = len(compressed_buffer.getvalue())
- ratio = (1 - compressed/uncompressed) * 100
+ ratio = (1 - compressed / uncompressed) * 100
- print(f"\n{'='*70}")
- print(f"LARGE API CONFIGURATION RESPONSE")
- print(f"{'='*70}")
- print(f"Uncompressed: {uncompressed:,} bytes")
- print(f"Compressed: {compressed:,} bytes")
- print(f"Ratio: {ratio:.1f}% reduction")
+ print(f'\n{"=" * 70}')
+ print('LARGE API CONFIGURATION RESPONSE')
+ print(f'{"=" * 70}')
+ print(f'Uncompressed: {uncompressed:,} bytes')
+ print(f'Compressed: {compressed:,} bytes')
+ print(f'Ratio: {ratio:.1f}% reduction')
api_names.append(api_payload['api_name'])
@@ -236,14 +241,11 @@ async def test_worst_case_already_compressed_data(client):
# Simulate a JWT token (random-looking base64)
import base64
+
random_bytes = os.urandom(256)
token_like = base64.b64encode(random_bytes).decode('utf-8')
- response_data = {
- 'status': 'success',
- 'token': token_like,
- 'expires_in': 3600
- }
+ response_data = {'status': 'success', 'token': token_like, 'expires_in': 3600}
json_str = json.dumps(response_data, separators=(',', ':'))
uncompressed = len(json_str.encode('utf-8'))
@@ -253,15 +255,15 @@ async def test_worst_case_already_compressed_data(client):
gz.write(json_str.encode('utf-8'))
compressed = len(compressed_buffer.getvalue())
- ratio = (1 - compressed/uncompressed) * 100
+ ratio = (1 - compressed / uncompressed) * 100
- print(f"\n{'='*70}")
- print(f"WORST CASE: RANDOM DATA (JWT-like tokens)")
- print(f"{'='*70}")
- print(f"Uncompressed: {uncompressed:,} bytes")
- print(f"Compressed: {compressed:,} bytes")
- print(f"Ratio: {ratio:.1f}% reduction")
- print(f"\nNote: Even random data achieves some compression due to JSON structure")
+ print(f'\n{"=" * 70}')
+ print('WORST CASE: RANDOM DATA (JWT-like tokens)')
+ print(f'{"=" * 70}')
+ print(f'Uncompressed: {uncompressed:,} bytes')
+ print(f'Compressed: {compressed:,} bytes')
+ print(f'Ratio: {ratio:.1f}% reduction')
+ print('\nNote: Even random data achieves some compression due to JSON structure')
@pytest.mark.asyncio
@@ -282,27 +284,22 @@ async def test_compression_cpu_overhead_estimate(client):
'created_at': '2025-01-15T12:00:00Z',
'updated_at': '2025-01-18T15:30:00Z',
'views': 1234,
- 'likes': 567
- }
+ 'likes': 567,
+ },
}
for i in range(50) # 50 products
],
- 'pagination': {
- 'page': 1,
- 'per_page': 50,
- 'total': 500,
- 'total_pages': 10
- }
+ 'pagination': {'page': 1, 'per_page': 50, 'total': 500, 'total_pages': 10},
}
json_str = json.dumps(large_response, separators=(',', ':'))
data = json_str.encode('utf-8')
- print(f"\n{'='*70}")
- print(f"COMPRESSION CPU OVERHEAD BENCHMARK")
- print(f"{'='*70}")
- print(f"Test payload: {len(data):,} bytes")
- print(f"\nCompression performance:")
+ print(f'\n{"=" * 70}')
+ print('COMPRESSION CPU OVERHEAD BENCHMARK')
+ print(f'{"=" * 70}')
+ print(f'Test payload: {len(data):,} bytes')
+ print('\nCompression performance:')
for level in [1, 4, 6, 9]:
# Warm-up
@@ -324,7 +321,7 @@ async def test_compression_cpu_overhead_estimate(client):
elapsed = time.perf_counter() - start
avg_time_ms = (elapsed / iterations) * 1000
avg_size = total_compressed // iterations
- ratio = (1 - avg_size/len(data)) * 100
+ ratio = (1 - avg_size / len(data)) * 100
# Estimate RPS capacity (assuming 50ms total request time)
# Compression adds overhead, reducing available CPU time
@@ -332,11 +329,11 @@ async def test_compression_cpu_overhead_estimate(client):
with_compression_time = base_request_time + avg_time_ms
rps_impact = (avg_time_ms / with_compression_time) * 100
- print(f"\n Level {level}:")
- print(f" Time: {avg_time_ms:.3f} ms/request")
- print(f" Compressed size: {avg_size:,} bytes ({ratio:.1f}% reduction)")
- print(f" CPU overhead: {rps_impact:.1f}% of total request time")
- print(f" Throughput: ~{1000/avg_time_ms:.0f} compressions/sec (single core)")
+ print(f'\n Level {level}:')
+ print(f' Time: {avg_time_ms:.3f} ms/request')
+ print(f' Compressed size: {avg_size:,} bytes ({ratio:.1f}% reduction)')
+ print(f' CPU overhead: {rps_impact:.1f}% of total request time')
+ print(f' Throughput: ~{1000 / avg_time_ms:.0f} compressions/sec (single core)')
pytestmark = [pytest.mark.benchmark]
diff --git a/backend-services/tests/test_redis_token_revocation_ha.py b/backend-services/tests/test_redis_token_revocation_ha.py
index 7be1bdd..4ad1465 100644
--- a/backend-services/tests/test_redis_token_revocation_ha.py
+++ b/backend-services/tests/test_redis_token_revocation_ha.py
@@ -7,9 +7,11 @@ Simulates multi-node scenario:
- Token validation on "Node B" (different process) should fail
"""
-import pytest
import os
+import pytest
+
+
@pytest.mark.asyncio
async def test_redis_token_revocation_shared_across_processes(monkeypatch, authed_client):
"""Test that token revocation via Redis is visible across simulated nodes.
@@ -22,6 +24,7 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe
monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS')
from utils import auth_blacklist
+
auth_blacklist._redis_client = None
auth_blacklist._redis_enabled = False
auth_blacklist._init_redis_if_possible()
@@ -31,7 +34,10 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe
login_response = await authed_client.post(
'/platform/authorization',
- json={'email': os.environ.get('DOORMAN_ADMIN_EMAIL'), 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD')}
+ json={
+ 'email': os.environ.get('DOORMAN_ADMIN_EMAIL'),
+ 'password': os.environ.get('DOORMAN_ADMIN_PASSWORD'),
+ },
)
assert login_response.status_code == 200
token_data = login_response.json()
@@ -39,11 +45,8 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe
assert access_token is not None
from jose import jwt
- payload = jwt.decode(
- access_token,
- os.environ.get('JWT_SECRET_KEY'),
- algorithms=['HS256']
- )
+
+ payload = jwt.decode(access_token, os.environ.get('JWT_SECRET_KEY'), algorithms=['HS256'])
jti = payload.get('jti')
username = payload.get('sub')
exp = payload.get('exp')
@@ -52,6 +55,7 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe
assert username is not None
import time
+
ttl = max(1, int(exp - time.time())) if exp else 3600
auth_blacklist.add_revoked_jti(username, jti, ttl)
@@ -65,12 +69,14 @@ async def test_redis_token_revocation_shared_across_processes(monkeypatch, authe
if auth_blacklist._redis_client:
auth_blacklist._redis_client.delete(auth_blacklist._revoked_jti_key(username, jti))
+
@pytest.mark.asyncio
async def test_redis_revoke_all_for_user_shared_across_processes(monkeypatch):
"""Test that user-level revocation via Redis is visible across nodes."""
monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS')
from utils import auth_blacklist
+
auth_blacklist._redis_client = None
auth_blacklist._redis_enabled = False
auth_blacklist._init_redis_if_possible()
@@ -93,14 +99,16 @@ async def test_redis_revoke_all_for_user_shared_across_processes(monkeypatch):
is_revoked_after_cleanup = auth_blacklist.is_user_revoked(test_username)
assert is_revoked_after_cleanup is False
+
@pytest.mark.asyncio
async def test_redis_token_revocation_ttl_expiry(monkeypatch):
"""Test that revoked tokens auto-expire in Redis based on TTL."""
monkeypatch.setenv('MEM_OR_EXTERNAL', 'REDIS')
- from utils import auth_blacklist
import time
+ from utils import auth_blacklist
+
auth_blacklist._redis_client = None
auth_blacklist._redis_enabled = False
auth_blacklist._init_redis_if_possible()
@@ -119,6 +127,7 @@ async def test_redis_token_revocation_ttl_expiry(monkeypatch):
assert auth_blacklist.is_jti_revoked(test_username, test_jti) is False
+
@pytest.mark.asyncio
async def test_memory_fallback_when_redis_unavailable(monkeypatch):
"""Test that system falls back to in-memory revocation when Redis is unavailable."""
@@ -138,4 +147,3 @@ async def test_memory_fallback_when_redis_unavailable(monkeypatch):
auth_blacklist.add_revoked_jti(test_username, test_jti, ttl_seconds=60)
assert auth_blacklist.is_jti_revoked(test_username, test_jti) is True
-
diff --git a/backend-services/tests/test_request_id_and_logging_redaction.py b/backend-services/tests/test_request_id_and_logging_redaction.py
index 77daa2f..5b3554f 100644
--- a/backend-services/tests/test_request_id_and_logging_redaction.py
+++ b/backend-services/tests/test_request_id_and_logging_redaction.py
@@ -1,19 +1,23 @@
-import pytest
import logging
from io import StringIO
+import pytest
+
+
@pytest.mark.asyncio
async def test_request_id_middleware_injects_header_when_missing(authed_client):
r = await authed_client.get('/api/status')
assert r.status_code == 200
assert r.headers.get('X-Request-ID')
+
@pytest.mark.asyncio
async def test_request_id_middleware_preserves_existing_header(authed_client):
r = await authed_client.get('/api/status', headers={'X-Request-ID': 'req-123'})
assert r.status_code == 200
assert r.headers.get('X-Request-ID') == 'req-123'
+
def _capture_logs(logger_name: str, message: str) -> str:
logger = logging.getLogger(logger_name)
stream = StringIO()
@@ -26,18 +30,20 @@ def _capture_logs(logger_name: str, message: str) -> str:
logger.removeHandler(handler)
return stream.getvalue()
+
def test_logging_redacts_authorization_headers():
msg = 'Authorization: Bearer secret-token'
out = _capture_logs('doorman.gateway', msg)
assert 'Authorization: [REDACTED]' in out
+
def test_logging_redacts_access_refresh_tokens():
msg = 'access_token="abc123" refresh_token="def456"'
out = _capture_logs('doorman.gateway', msg)
assert 'access_token' in out and '[REDACTED]' in out
+
def test_logging_redacts_cookie_values():
msg = 'cookie: sessionid=abcdef; csrftoken=xyz'
out = _capture_logs('doorman.gateway', msg)
assert 'cookie: [REDACTED]' in out
-
diff --git a/backend-services/tests/test_request_id_propagation.py b/backend-services/tests/test_request_id_propagation.py
index a575231..d3160ab 100644
--- a/backend-services/tests/test_request_id_propagation.py
+++ b/backend-services/tests/test_request_id_propagation.py
@@ -1,13 +1,16 @@
import httpx
import pytest
+
@pytest.mark.asyncio
async def test_request_id_propagates_to_upstream_and_response(monkeypatch, authed_client):
captured = {'xrid': None}
def handler(req: httpx.Request) -> httpx.Response:
captured['xrid'] = req.headers.get('X-Request-ID')
- return httpx.Response(200, json={'ok': True}, headers={'X-Upstream-Request-ID': captured['xrid'] or ''})
+ return httpx.Response(
+ 200, json={'ok': True}, headers={'X-Upstream-Request-ID': captured['xrid'] or ''}
+ )
transport = httpx.MockTransport(handler)
mock_client = httpx.AsyncClient(transport=transport)
@@ -17,7 +20,9 @@ async def test_request_id_propagates_to_upstream_and_response(monkeypatch, authe
async def _get_client():
return mock_client
- monkeypatch.setattr(gateway_service.GatewayService, 'get_http_client', classmethod(lambda cls: mock_client))
+ monkeypatch.setattr(
+ gateway_service.GatewayService, 'get_http_client', classmethod(lambda cls: mock_client)
+ )
api_name, api_version = 'ridtest', 'v1'
payload = {
@@ -34,16 +39,22 @@ async def test_request_id_propagates_to_upstream_and_response(monkeypatch, authe
r = await authed_client.post('/platform/api', json=payload)
assert r.status_code in (200, 201), r.text
- r2 = await authed_client.post('/platform/endpoint', json={
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/echo',
- 'endpoint_description': 'echo'
- })
+ r2 = await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/echo',
+ 'endpoint_description': 'echo',
+ },
+ )
assert r2.status_code in (200, 201), r2.text
- sub = await authed_client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': api_name, 'api_version': api_version})
+ sub = await authed_client.post(
+ '/platform/subscription/subscribe',
+ json={'username': 'admin', 'api_name': api_name, 'api_version': api_version},
+ )
assert sub.status_code in (200, 201), sub.text
resp = await authed_client.get(f'/api/rest/{api_name}/{api_version}/echo')
diff --git a/backend-services/tests/test_response_compression.py b/backend-services/tests/test_response_compression.py
index d39662c..d6c3892 100644
--- a/backend-services/tests/test_response_compression.py
+++ b/backend-services/tests/test_response_compression.py
@@ -10,21 +10,17 @@ Verifies:
- Compression level affects size and performance
"""
-import pytest
-import gzip
import json
import os
-from httpx import AsyncClient
+
+import pytest
@pytest.mark.asyncio
async def test_compression_enabled_for_json_response(client):
"""Verify JSON responses are compressed when client accepts gzip"""
# Request with Accept-Encoding: gzip header
- r = await client.get(
- '/api/health',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'})
assert r.status_code == 200
# httpx automatically decompresses responses, but the middleware
@@ -38,15 +34,12 @@ async def test_compression_reduces_response_size(client):
# Get a response without compression
r_uncompressed = await client.get(
'/api/health',
- headers={'Accept-Encoding': 'identity'} # No compression
+ headers={'Accept-Encoding': 'identity'}, # No compression
)
assert r_uncompressed.status_code == 200
# Get same response with compression
- r_compressed = await client.get(
- '/api/health',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r_compressed = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'})
assert r_compressed.status_code == 200
# Both should have same decompressed content
@@ -62,17 +55,14 @@ async def test_compression_with_large_json_list(client):
# First authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
assert r_auth.status_code == 200
# Get list of APIs (potentially large response)
- r = await client.get(
- '/platform/api',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/platform/api', headers={'Accept-Encoding': 'gzip'})
assert r.status_code == 200
# Should have content-encoding header if response is large enough
@@ -89,10 +79,7 @@ async def test_compression_with_large_json_list(client):
async def test_small_response_not_compressed(client):
"""Verify small responses below minimum_size are not compressed"""
# Health endpoint returns small response
- r = await client.get(
- '/api/health',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'})
assert r.status_code == 200
response_size = len(r.content)
@@ -100,7 +87,7 @@ async def test_small_response_not_compressed(client):
# If response is smaller than minimum_size (500 bytes default),
# it should not be compressed
if response_size < 500:
- headers_lower = {k.lower(): v for k, v in r.headers.items()}
+ {k.lower(): v for k, v in r.headers.items()}
# May or may not have content-encoding based on actual size
# This is expected behavior - small responses aren't worth compressing
@@ -111,7 +98,7 @@ async def test_compression_with_different_content_types(client):
# Authenticate first
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -123,10 +110,7 @@ async def test_compression_with_different_content_types(client):
]
for endpoint, expected_content_type in endpoints:
- r = await client.get(
- endpoint,
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get(endpoint, headers={'Accept-Encoding': 'gzip'})
# Should succeed
assert r.status_code == 200
@@ -139,10 +123,7 @@ async def test_compression_with_different_content_types(client):
@pytest.mark.asyncio
async def test_no_compression_when_not_requested(client):
"""Verify compression is not applied when client doesn't accept it"""
- r = await client.get(
- '/api/health',
- headers={'Accept-Encoding': 'identity'}
- )
+ r = await client.get('/api/health', headers={'Accept-Encoding': 'identity'})
assert r.status_code == 200
# Should not have gzip encoding
@@ -157,23 +138,17 @@ async def test_compression_preserves_response_body(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
assert r_auth.status_code == 200
# Get response without compression
- r1 = await client.get(
- '/platform/user',
- headers={'Accept-Encoding': 'identity'}
- )
+ r1 = await client.get('/platform/user', headers={'Accept-Encoding': 'identity'})
# Get response with compression
- r2 = await client.get(
- '/platform/user',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r2 = await client.get('/platform/user', headers={'Accept-Encoding': 'gzip'})
# Both should succeed
assert r1.status_code == 200
@@ -189,13 +164,11 @@ async def test_compression_with_post_request(client):
# Login with compression
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r = await client.post(
- '/platform/authorization',
- json=login_payload,
- headers={'Accept-Encoding': 'gzip'}
+ '/platform/authorization', json=login_payload, headers={'Accept-Encoding': 'gzip'}
)
assert r.status_code == 200
@@ -216,10 +189,7 @@ async def test_compression_works_with_errors(client):
except Exception:
pass
- r = await client.get(
- '/platform/user',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/platform/user', headers={'Accept-Encoding': 'gzip'})
# Should be unauthorized
assert r.status_code in (401, 403)
@@ -236,10 +206,7 @@ async def test_compression_works_with_errors(client):
@pytest.mark.asyncio
async def test_compression_with_cache_headers(client):
"""Verify compression doesn't interfere with cache headers"""
- r = await client.get(
- '/api/health',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/api/health', headers={'Accept-Encoding': 'gzip'})
assert r.status_code == 200
# Response should still have normal headers
@@ -265,7 +232,7 @@ async def test_compression_with_large_payload(client):
# Authenticate
login_payload = {
'email': os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev'),
- 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')
+ 'password': os.getenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars'),
}
r_auth = await client.post('/platform/authorization', json=login_payload)
@@ -275,10 +242,7 @@ async def test_compression_with_large_payload(client):
# (This is mainly to test that compression handles large payloads)
# Get list of APIs (can be large)
- r = await client.get(
- '/platform/api',
- headers={'Accept-Encoding': 'gzip'}
- )
+ r = await client.get('/platform/api', headers={'Accept-Encoding': 'gzip'})
# Should succeed regardless of size
assert r.status_code == 200
diff --git a/backend-services/tests/test_response_compression_edges.py b/backend-services/tests/test_response_compression_edges.py
new file mode 100644
index 0000000..357e0a3
--- /dev/null
+++ b/backend-services/tests/test_response_compression_edges.py
@@ -0,0 +1,10 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_large_response_is_compressed(authed_client):
+ # Export all config typically returns a payload above compression threshold.
+ r = await authed_client.get('/platform/config/export/all')
+ assert r.status_code == 200
+ ce = (r.headers.get('content-encoding') or '').lower()
+ assert ce == 'gzip'
diff --git a/backend-services/tests/test_response_envelope_and_headers.py b/backend-services/tests/test_response_envelope_and_headers.py
index 5557b0c..55e69e2 100644
--- a/backend-services/tests/test_response_envelope_and_headers.py
+++ b/backend-services/tests/test_response_envelope_and_headers.py
@@ -1,28 +1,37 @@
import pytest
+
@pytest.mark.asyncio
async def test_rest_loose_envelope_returns_raw_message(monkeypatch, authed_client):
name, ver = 'envrest', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/e',
- 'endpoint_description': 'e'
- })
- import services.gateway_service as gs
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/e',
+ 'endpoint_description': 'e',
+ },
+ )
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
+ import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False)
r = await authed_client.get(f'/api/rest/{name}/{ver}/e')
@@ -30,30 +39,39 @@ async def test_rest_loose_envelope_returns_raw_message(monkeypatch, authed_clien
body = r.json()
assert 'status_code' not in body and body.get('method') == 'GET'
+
@pytest.mark.asyncio
async def test_rest_strict_envelope_wraps_message(monkeypatch, authed_client):
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
name, ver = 'envrest2', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/e',
- 'endpoint_description': 'e'
- })
- import services.gateway_service as gs
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/e',
+ 'endpoint_description': 'e',
+ },
+ )
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
+ import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/e')
assert r.status_code == 200
@@ -62,6 +80,7 @@ async def test_rest_strict_envelope_wraps_message(monkeypatch, authed_client):
assert isinstance(body.get('response'), dict)
assert body['response'].get('method') == 'GET'
+
async def _setup_graphql(client, name='envgql', ver='v1'):
payload = {
'api_name': name,
@@ -74,160 +93,228 @@ async def _setup_graphql(client, name='envgql', ver='v1'):
'api_allowed_retry_count': 0,
}
await client.post('/platform/api', json=payload)
- await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/graphql',
- 'endpoint_description': 'gql',
- })
+ await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'gql',
+ },
+ )
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
return name, ver
+
@pytest.mark.asyncio
async def test_graphql_strict_and_loose_envelopes(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'envgql', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://gql.up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/graphql',
- 'endpoint_description': 'gql'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://gql.up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/graphql',
+ 'endpoint_description': 'gql',
+ },
+ )
class FakeHTTPResp:
def __init__(self, payload):
self._p = payload
+
def json(self):
return self._p
+
class FakeHTTPX:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def post(self, url, json=None, headers=None):
return FakeHTTPResp({'data': {'pong': True}})
+
# Force HTTPX path
class Dummy:
pass
+
monkeypatch.setattr(gs, 'Client', Dummy)
monkeypatch.setattr(gs.httpx, 'AsyncClient', FakeHTTPX)
monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False)
- r1 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}})
+ r1 = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ q }', 'variables': {}},
+ )
assert r1.status_code == 200
assert 'status_code' not in r1.json()
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
- r2 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ q }', 'variables': {}})
+ r2 = await authed_client.post(
+ f'/api/graphql/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'query': '{ q }', 'variables': {}},
+ )
assert r2.status_code == 200
assert r2.json().get('status_code') == 200 and isinstance(r2.json().get('response'), dict)
+
@pytest.mark.asyncio
async def test_grpc_strict_and_loose_envelopes(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'envgrpc', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': 'g',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['grpc://127.0.0.1:9'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/grpc',
- 'endpoint_description': 'grpc'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'g',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['grpc://127.0.0.1:9'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/grpc',
+ 'endpoint_description': 'grpc',
+ },
+ )
from conftest import subscribe_self
+
await subscribe_self(authed_client, name, ver)
- async def fake_grpc_gateway(username, request, request_id, start_time, path, api_name=None, url=None, retry=0):
+ async def fake_grpc_gateway(
+ username, request, request_id, start_time, path, api_name=None, url=None, retry=0
+ ):
from models.response_model import ResponseModel
- return ResponseModel(status_code=200, response_headers={'request_id': request_id}, response={'ok': True}).dict()
+
+ return ResponseModel(
+ status_code=200, response_headers={'request_id': request_id}, response={'ok': True}
+ ).dict()
+
monkeypatch.setattr(gs.GatewayService, 'grpc_gateway', staticmethod(fake_grpc_gateway))
monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False)
- r1 = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}})
+ r1 = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
+ )
assert r1.status_code == 200
assert r1.json() == {'ok': True}
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
- r2 = await authed_client.post(f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}})
+ r2 = await authed_client.post(
+ f'/api/grpc/{name}',
+ headers={'X-API-Version': ver, 'Content-Type': 'application/json'},
+ json={'method': 'Svc.M', 'message': {}},
+ )
assert r2.status_code == 200
assert r2.json().get('status_code') == 200 and r2.json().get('response', {}).get('ok') is True
+
@pytest.mark.asyncio
async def test_header_normalization_sets_x_request_id_from_request_id(monkeypatch, authed_client):
name, ver = 'hdrnorm', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/h',
- 'endpoint_description': 'h'
- })
- import services.gateway_service as gs
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/h',
+ 'endpoint_description': 'h',
+ },
+ )
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
+ import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/h')
assert r.status_code == 200
assert r.headers.get('X-Request-ID')
+
@pytest.mark.asyncio
async def test_header_normalization_preserves_existing_x_request_id(monkeypatch, authed_client):
name, ver = 'hdrnorm2', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- 'api_public': True,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/h',
- 'endpoint_description': 'h'
- })
- import services.gateway_service as gs
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ 'api_public': True,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/h',
+ 'endpoint_description': 'h',
+ },
+ )
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
+ import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.get(f'/api/rest/{name}/{ver}/h', headers={'X-Request-ID': 'my-req-id'})
assert r.status_code == 200
diff --git a/backend-services/tests/test_rest_authorization_field_swap.py b/backend-services/tests/test_rest_authorization_field_swap.py
index 10637e6..d3e19a9 100644
--- a/backend-services/tests/test_rest_authorization_field_swap.py
+++ b/backend-services/tests/test_rest_authorization_field_swap.py
@@ -1,5 +1,6 @@
import pytest
+
class _Resp:
def __init__(self, status_code=200, json_body=None, headers=None):
self.status_code = status_code
@@ -9,17 +10,22 @@ class _Resp:
base_headers.update(headers)
self.headers = base_headers
self.text = ''
+
def json(self):
return self._json_body
+
def _mk_client_capture(seen):
class _Client:
def __init__(self, timeout=None, limits=None, http2=False):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -37,24 +43,48 @@ def _mk_client_capture(seen):
return await self.put(url, **kwargs)
else:
return _Resp(405)
+
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- payload = {'method': 'POST', 'url': url, 'params': dict(params or {}), 'body': json, 'headers': headers or {}}
- seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'json': json})
+ payload = {
+ 'method': 'POST',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'body': json,
+ 'headers': headers or {},
+ }
+ seen.append(
+ {
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': dict(headers or {}),
+ 'json': json,
+ }
+ )
return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'})
+
async def get(self, url, params=None, headers=None, **kwargs):
- payload = {'method': 'GET', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}
+ payload = {
+ 'method': 'GET',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': headers or {},
+ }
seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {})})
return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'})
+
async def put(self, url, **kwargs):
payload = {'method': 'PUT', 'url': url, 'params': {}, 'headers': {}}
seen.append({'url': url, 'params': {}, 'headers': {}})
return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'})
+
async def delete(self, url, **kwargs):
payload = {'method': 'DELETE', 'url': url, 'params': {}, 'headers': {}}
seen.append({'url': url, 'params': {}, 'headers': {}})
return _Resp(200, json_body=payload, headers={'X-Upstream': 'yes'})
+
return _Client
+
async def _setup_api(client, name, ver, swap_header, allowed_headers=None):
payload = {
'api_name': name,
@@ -71,28 +101,35 @@ async def _setup_api(client, name, ver, swap_header, allowed_headers=None):
payload['api_allowed_headers'] = allowed_headers
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/p',
- 'endpoint_description': 'p'
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/p',
+ 'endpoint_description': 'p',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
async def test_auth_swap_injects_authorization_from_custom_header(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'authswap1', 'v1'
swap_from = 'x-token'
- await _setup_api(authed_client, name, ver, swap_from, allowed_headers=['authorization', swap_from])
+ await _setup_api(
+ authed_client, name, ver, swap_from, allowed_headers=['authorization', swap_from]
+ )
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_client_capture(seen))
r = await authed_client.get(
- f'/api/rest/{name}/{ver}/p',
- headers={swap_from: 'Bearer backend-token'}
+ f'/api/rest/{name}/{ver}/p', headers={swap_from: 'Bearer backend-token'}
)
assert r.status_code == 200
assert len(seen) == 1
@@ -100,33 +137,41 @@ async def test_auth_swap_injects_authorization_from_custom_header(monkeypatch, a
auth_val = forwarded.get('Authorization') or forwarded.get('authorization')
assert auth_val == 'Bearer backend-token'
+
@pytest.mark.asyncio
async def test_auth_swap_missing_source_header_no_crash(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'authswap2', 'v1'
swap_from = 'X-Backend-Auth'
- await _setup_api(authed_client, name, ver, swap_from, allowed_headers=['Content-Type', swap_from])
+ await _setup_api(
+ authed_client, name, ver, swap_from, allowed_headers=['Content-Type', swap_from]
+ )
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_client_capture(seen))
- r = await authed_client.get(
- f'/api/rest/{name}/{ver}/p',
- headers={}
- )
+ r = await authed_client.get(f'/api/rest/{name}/{ver}/p', headers={})
assert r.status_code == 200
fwd = (r.json() or {}).get('headers') or {}
assert not (('Authorization' in fwd) or ('authorization' in fwd))
+
@pytest.mark.asyncio
async def test_auth_swap_with_empty_value_does_not_override(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'authswap3', 'v1'
swap_from = 'X-Backend-Auth'
- await _setup_api(authed_client, name, ver, swap_from, allowed_headers=['Content-Type', swap_from, 'Authorization'])
+ await _setup_api(
+ authed_client,
+ name,
+ ver,
+ swap_from,
+ allowed_headers=['Content-Type', swap_from, 'Authorization'],
+ )
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_client_capture(seen))
r = await authed_client.get(
- f'/api/rest/{name}/{ver}/p',
- headers={swap_from: '', 'Authorization': 'Bearer existing'}
+ f'/api/rest/{name}/{ver}/p', headers={swap_from: '', 'Authorization': 'Bearer existing'}
)
assert r.status_code == 200
fwd = (r.json() or {}).get('headers') or {}
diff --git a/backend-services/tests/test_rest_gateway_retries.py b/backend-services/tests/test_rest_gateway_retries.py
index a75fb12..07dbcf1 100644
--- a/backend-services/tests/test_rest_gateway_retries.py
+++ b/backend-services/tests/test_rest_gateway_retries.py
@@ -1,5 +1,6 @@
import pytest
+
class _Resp:
def __init__(self, status_code=200, body=b'{"ok":true}', headers=None):
self.status_code = status_code
@@ -12,8 +13,10 @@ class _Resp:
def json(self):
import json
+
return json.loads(self.text)
+
def _mk_retry_client(sequence, seen):
"""Factory for a fake AsyncClient that returns statuses from `sequence`.
Records each call's (url, headers, params) into `seen` list.
@@ -50,7 +53,14 @@ def _mk_retry_client(sequence, seen):
return _Resp(405)
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'json': json})
+ seen.append(
+ {
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': dict(headers or {}),
+ 'json': json,
+ }
+ )
idx = min(counter['i'], len(sequence) - 1)
code = sequence[idx]
counter['i'] = counter['i'] + 1
@@ -79,6 +89,7 @@ def _mk_retry_client(sequence, seen):
return _Client
+
async def _setup_api(client, name, ver, retry_count=0, allowed_headers=None):
payload = {
'api_name': name,
@@ -94,20 +105,26 @@ async def _setup_api(client, name, ver, retry_count=0, allowed_headers=None):
payload['api_allowed_headers'] = allowed_headers
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/p',
- 'endpoint_description': 'p'
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/p',
+ 'endpoint_description': 'p',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
async def test_rest_retry_on_500_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retry500', 'v1'
await _setup_api(authed_client, name, ver, retry_count=2)
seen = []
@@ -116,9 +133,11 @@ async def test_rest_retry_on_500_then_success(monkeypatch, authed_client):
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_rest_retry_on_502_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retry502', 'v1'
await _setup_api(authed_client, name, ver, retry_count=2)
seen = []
@@ -127,9 +146,11 @@ async def test_rest_retry_on_502_then_success(monkeypatch, authed_client):
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_rest_retry_on_503_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retry503', 'v1'
await _setup_api(authed_client, name, ver, retry_count=2)
seen = []
@@ -138,9 +159,11 @@ async def test_rest_retry_on_503_then_success(monkeypatch, authed_client):
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_rest_retry_on_504_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retry504', 'v1'
await _setup_api(authed_client, name, ver, retry_count=2)
seen = []
@@ -149,9 +172,11 @@ async def test_rest_retry_on_504_then_success(monkeypatch, authed_client):
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retry0', 'v1'
await _setup_api(authed_client, name, ver, retry_count=0)
seen = []
@@ -160,9 +185,11 @@ async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client):
assert r.status_code == 500
assert len(seen) == 1
+
@pytest.mark.asyncio
async def test_rest_retry_stops_after_limit(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retryLimit', 'v1'
await _setup_api(authed_client, name, ver, retry_count=1)
seen = []
@@ -171,9 +198,11 @@ async def test_rest_retry_stops_after_limit(monkeypatch, authed_client):
assert r.status_code == 500
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_rest_retry_preserves_headers_and_params(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'retryHdr', 'v1'
await _setup_api(authed_client, name, ver, retry_count=1, allowed_headers=['X-Custom'])
seen = []
@@ -181,11 +210,13 @@ async def test_rest_retry_preserves_headers_and_params(monkeypatch, authed_clien
r = await authed_client.post(
f'/api/rest/{name}/{ver}/p?foo=bar',
headers={'X-Custom': 'abc', 'Content-Type': 'application/json'},
- json={'a': 1}
+ json={'a': 1},
)
assert r.status_code == 200
assert len(seen) == 2
assert all(call['params'].get('foo') == 'bar' for call in seen)
+
def _hdr(call):
return call['headers'].get('X-Custom') or call['headers'].get('x-custom')
+
assert all(_hdr(call) == 'abc' for call in seen)
diff --git a/backend-services/tests/test_rest_header_and_response_parsing.py b/backend-services/tests/test_rest_header_and_response_parsing.py
index 18cf795..a1945a2 100644
--- a/backend-services/tests/test_rest_header_and_response_parsing.py
+++ b/backend-services/tests/test_rest_header_and_response_parsing.py
@@ -1,5 +1,6 @@
import pytest
+
class _Resp:
def __init__(self, status_code=200, body=b'{"ok":true}', headers=None):
self.status_code = status_code
@@ -22,16 +23,21 @@ class _Resp:
def json(self):
import json
+
return json.loads(self.text)
+
def _mk_client_capture(seen, resp_status=200, resp_headers=None, resp_body=b'{"ok":true}'):
class _Client:
def __init__(self, timeout=None, limits=None, http2=False):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -49,20 +55,33 @@ def _mk_client_capture(seen, resp_status=200, resp_headers=None, resp_body=b'{"o
return await self.put(url, **kwargs)
else:
return _Resp(405)
+
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'json': json})
+ seen.append(
+ {
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': dict(headers or {}),
+ 'json': json,
+ }
+ )
return _Resp(resp_status, body=resp_body, headers=resp_headers)
+
async def get(self, url, **kwargs):
seen.append({'url': url, 'params': {}, 'headers': {}})
return _Resp(resp_status, body=resp_body, headers=resp_headers)
+
async def put(self, url, **kwargs):
seen.append({'url': url, 'params': {}, 'headers': {}})
return _Resp(resp_status, body=resp_body, headers=resp_headers)
+
async def delete(self, url, **kwargs):
seen.append({'url': url, 'params': {}, 'headers': {}})
return _Resp(resp_status, body=resp_body, headers=resp_headers)
+
return _Client
+
async def _setup_api(client, name, ver, allowed_headers=None):
payload = {
'api_name': name,
@@ -78,20 +97,28 @@ async def _setup_api(client, name, ver, allowed_headers=None):
payload['api_allowed_headers'] = allowed_headers
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/p',
- 'endpoint_description': 'p'
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/p',
+ 'endpoint_description': 'p',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
-async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(monkeypatch, authed_client):
+async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(
+ monkeypatch, authed_client
+):
import services.gateway_service as gs
+
name, ver = 'hdrallow', 'v1'
await _setup_api(authed_client, name, ver, allowed_headers=['X-Custom', 'Content-Type'])
seen = []
@@ -99,7 +126,7 @@ async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(m
r = await authed_client.post(
f'/api/rest/{name}/{ver}/p?foo=bar',
headers={'x-custom': 'abc', 'X-Blocked': 'nope', 'Content-Type': 'application/json'},
- json={'a': 1}
+ json={'a': 1},
)
assert r.status_code == 200
assert len(seen) == 1
@@ -108,9 +135,11 @@ async def test_header_allowlist_forwards_only_allowed_headers_case_insensitive(m
assert keys_lower.get('x-custom') == 'abc'
assert 'x-blocked' not in keys_lower
+
@pytest.mark.asyncio
async def test_header_block_non_allowlisted_headers(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'hdrblock', 'v1'
await _setup_api(authed_client, name, ver, allowed_headers=['Content-Type'])
seen = []
@@ -118,59 +147,75 @@ async def test_header_block_non_allowlisted_headers(monkeypatch, authed_client):
r = await authed_client.post(
f'/api/rest/{name}/{ver}/p',
headers={'X-NotAllowed': '123', 'Content-Type': 'application/json'},
- json={'a': 1}
+ json={'a': 1},
)
assert r.status_code == 200
forwarded = seen[0]['headers']
assert 'X-NotAllowed' not in forwarded and 'x-notallowed' not in {k.lower() for k in forwarded}
+
def test_response_parse_application_json():
import services.gateway_service as gs
+
body = b'{"x": 1}'
resp = _Resp(headers={'Content-Type': 'application/json'}, body=body)
out = gs.GatewayService.parse_response(resp)
assert isinstance(out, dict) and out.get('x') == 1
+
def test_response_parse_text_plain_fallback():
import services.gateway_service as gs
+
body = b'hello world'
resp = _Resp(headers={'Content-Type': 'text/plain'}, body=body)
out = gs.GatewayService.parse_response(resp)
assert out == body
+
def test_response_parse_application_xml():
import services.gateway_service as gs
+
body = b'1'
resp = _Resp(headers={'Content-Type': 'application/xml'}, body=body)
out = gs.GatewayService.parse_response(resp)
from xml.etree.ElementTree import Element
+
assert isinstance(out, Element) and out.tag == 'root'
+
def test_response_parse_malformed_json_as_text():
import services.gateway_service as gs
+
body = b'{"x": 1'
resp = _Resp(headers={'Content-Type': 'text/plain'}, body=body)
out = gs.GatewayService.parse_response(resp)
assert out == body
+
def test_response_binary_passthrough_no_decode():
import services.gateway_service as gs
- binary = b'\x00\xFF\x10\x80'
+
+ binary = b'\x00\xff\x10\x80'
resp = _Resp(headers={'Content-Type': 'application/octet-stream'}, body=binary)
out = gs.GatewayService.parse_response(resp)
assert out == binary
+
def test_response_malformed_json_with_application_json_raises():
import services.gateway_service as gs
+
body = b'{"x": 1'
resp = _Resp(headers={'Content-Type': 'application/json'}, body=body)
import pytest
+
with pytest.raises(Exception):
gs.GatewayService.parse_response(resp)
+
@pytest.mark.asyncio
async def test_rest_gateway_returns_500_on_malformed_json_upstream(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'jsonfail', 'v1'
await _setup_api(authed_client, name, ver)
@@ -182,14 +227,22 @@ async def test_rest_gateway_returns_500_on_malformed_json_upstream(monkeypatch,
self.headers = {'Content-Type': 'application/json'}
self.content = bad_body
self.text = bad_body.decode('utf-8', errors='ignore')
+
def json(self):
import json
+
return json.loads(self.text)
class _Client2:
- def __init__(self, *a, **k): pass
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ def __init__(self, *a, **k):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -207,16 +260,28 @@ async def test_rest_gateway_returns_500_on_malformed_json_upstream(monkeypatch,
return await self.put(url, **kwargs)
else:
return _Resp2()
- async def get(self, url, params=None, headers=None, **kwargs): return _Resp2()
- async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs): return _Resp2()
- async def head(self, url, params=None, headers=None, **kwargs): return _Resp2()
- async def put(self, url, **kwargs): return _Resp2()
- async def delete(self, url, **kwargs): return _Resp2()
+
+ async def get(self, url, params=None, headers=None, **kwargs):
+ return _Resp2()
+
+ async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
+ return _Resp2()
+
+ async def head(self, url, params=None, headers=None, **kwargs):
+ return _Resp2()
+
+ async def put(self, url, **kwargs):
+ return _Resp2()
+
+ async def delete(self, url, **kwargs):
+ return _Resp2()
monkeypatch.setattr(gs.httpx, 'AsyncClient', _Client2)
- r = await authed_client.post(f'/api/rest/{name}/{ver}/p', headers={'Content-Type': 'application/json'}, json={'k': 'v'})
+ r = await authed_client.post(
+ f'/api/rest/{name}/{ver}/p', headers={'Content-Type': 'application/json'}, json={'k': 'v'}
+ )
assert r.status_code == 500
body = r.json()
payload = body.get('response', body)
- assert (payload.get('error_code') or payload.get('error_message'))
+ assert payload.get('error_code') or payload.get('error_message')
diff --git a/backend-services/tests/test_rest_methods_and_405.py b/backend-services/tests/test_rest_methods_and_405.py
index 99634d4..2c3092c 100644
--- a/backend-services/tests/test_rest_methods_and_405.py
+++ b/backend-services/tests/test_rest_methods_and_405.py
@@ -1,5 +1,6 @@
import pytest
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
self.status_code = status_code
@@ -12,10 +13,12 @@ class _FakeHTTPResponse:
def json(self):
import json as _json
+
if self._json_body is None:
return _json.loads(self.text or '{}')
return self._json_body
+
class _FakeAsyncClient:
def __init__(self, *args, **kwargs):
pass
@@ -45,63 +48,114 @@ class _FakeAsyncClient:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
async def get(self, url, params=None, headers=None, **kwargs):
- return _FakeHTTPResponse(200, json_body={'method': 'GET', 'url': url, 'params': dict(params or {}), 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={
+ 'method': 'GET',
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': headers or {},
+ },
+ headers={'X-Upstream': 'yes'},
+ )
async def post(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'POST', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200,
+ json_body={'method': 'POST', 'url': url, 'body': body, 'headers': headers or {}},
+ headers={'X-Upstream': 'yes'},
+ )
async def put(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'PUT', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200,
+ json_body={'method': 'PUT', 'url': url, 'body': body, 'headers': headers or {}},
+ headers={'X-Upstream': 'yes'},
+ )
async def delete(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- return _FakeHTTPResponse(200, json_body={'method': 'DELETE', 'url': url, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ return _FakeHTTPResponse(
+ 200,
+ json_body={'method': 'DELETE', 'url': url, 'headers': headers or {}},
+ headers={'X-Upstream': 'yes'},
+ )
async def patch(self, url, json=None, params=None, headers=None, content=None, **kwargs):
- body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
- return _FakeHTTPResponse(200, json_body={'method': 'PATCH', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
+ body = (
+ json
+ if json is not None
+ else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
+ )
+ return _FakeHTTPResponse(
+ 200,
+ json_body={'method': 'PATCH', 'url': url, 'body': body, 'headers': headers or {}},
+ headers={'X-Upstream': 'yes'},
+ )
async def head(self, url, params=None, headers=None, **kwargs):
return _FakeHTTPResponse(200, json_body=None, headers={'X-Upstream': 'yes'})
+
async def _setup_api(client, name, ver, endpoint_method='GET', endpoint_uri='/p'):
- r = await client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up.methods'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
+ r = await client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.methods'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': endpoint_method,
- 'endpoint_uri': endpoint_uri,
- 'endpoint_description': endpoint_method.lower(),
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': endpoint_method,
+ 'endpoint_uri': endpoint_uri,
+ 'endpoint_description': endpoint_method.lower(),
+ },
+ )
assert r2.status_code in (200, 201)
rme = await client.get('/platform/user/me')
- username = (rme.json().get('username') if rme.status_code == 200 else 'admin')
- sr = await client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': name, 'api_version': ver})
+ username = rme.json().get('username') if rme.status_code == 200 else 'admin'
+ sr = await client.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': name, 'api_version': ver},
+ )
assert sr.status_code in (200, 201)
+
@pytest.mark.asyncio
async def test_rest_head_supported_when_upstream_allows(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'headok', 'v1'
await _setup_api(authed_client, name, ver, endpoint_method='GET', endpoint_uri='/p')
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.request('HEAD', f'/api/rest/{name}/{ver}/p')
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_rest_patch_supported_when_registered(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'patchok', 'v1'
await _setup_api(authed_client, name, ver, endpoint_method='PATCH', endpoint_uri='/edit')
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
@@ -110,24 +164,29 @@ async def test_rest_patch_supported_when_registered(monkeypatch, authed_client):
j = r.json().get('response', r.json())
assert j.get('method') == 'PATCH'
+
@pytest.mark.asyncio
async def test_rest_options_unregistered_endpoint_returns_405(monkeypatch, authed_client):
monkeypatch.setenv('STRICT_OPTIONS_405', 'true')
name, ver = 'optunreg', 'v1'
- r = await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://up.methods'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
+ r = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.methods'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
assert r.status_code in (200, 201)
resp = await authed_client.options(f'/api/rest/{name}/{ver}/not-made')
assert resp.status_code == 405
+
@pytest.mark.asyncio
async def test_rest_unsupported_method_returns_405(authed_client):
name, ver = 'unsup', 'v1'
diff --git a/backend-services/tests/test_rest_preflight_positive.py b/backend-services/tests/test_rest_preflight_positive.py
new file mode 100644
index 0000000..065ce67
--- /dev/null
+++ b/backend-services/tests/test_rest_preflight_positive.py
@@ -0,0 +1,54 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_rest_preflight_positive_allows(authed_client):
+ name, ver = 'restpos', 'v1'
+ c = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'REST preflight positive',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'REST',
+ 'api_cors_allow_origins': ['http://ok.example'],
+ 'api_cors_allow_methods': ['GET'],
+ 'api_cors_allow_headers': ['Content-Type'],
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ assert c.status_code in (200, 201)
+ ce = await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/p',
+ 'endpoint_description': 'p',
+ },
+ )
+ assert ce.status_code in (200, 201)
+
+ r = await authed_client.options(
+ f'/api/rest/{name}/{ver}/p',
+ headers={
+ 'Origin': 'http://ok.example',
+ 'Access-Control-Request-Method': 'GET',
+ 'Access-Control-Request-Headers': 'Content-Type',
+ },
+ )
+ assert r.status_code == 204
+ acao = r.headers.get('Access-Control-Allow-Origin') or r.headers.get(
+ 'access-control-allow-origin'
+ )
+ assert acao == 'http://ok.example'
+ ach = (
+ r.headers.get('Access-Control-Allow-Headers')
+ or r.headers.get('access-control-allow-headers')
+ or ''
+ )
+ assert 'Content-Type' in ach
diff --git a/backend-services/tests/test_role_admin_edge_cases.py b/backend-services/tests/test_role_admin_edge_cases.py
new file mode 100644
index 0000000..c7855a8
--- /dev/null
+++ b/backend-services/tests/test_role_admin_edge_cases.py
@@ -0,0 +1,45 @@
+import time
+
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_non_admin_cannot_create_admin_role(authed_client):
+ # Create a user with manage_roles but not admin
+ rname = f'mgr_roles_{int(time.time())}'
+ cr = await authed_client.post('/platform/role', json={'role_name': rname, 'manage_roles': True})
+ assert cr.status_code in (200, 201)
+ uname = f'role_mgr_{int(time.time())}'
+ pwd = 'RoleMgrStrongPass1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': rname,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ mgr = await _login(f'{uname}@example.com', pwd)
+ # Attempt to create 'admin' role (should be forbidden for non-admin)
+ admin_create = await mgr.post(
+ '/platform/role', json={'role_name': 'admin', 'manage_roles': True}
+ )
+ assert admin_create.status_code == 403
diff --git a/backend-services/tests/test_role_update.py b/backend-services/tests/test_role_update.py
index 982c23f..aa7adb1 100644
--- a/backend-services/tests/test_role_update.py
+++ b/backend-services/tests/test_role_update.py
@@ -2,35 +2,35 @@
"""Quick test to validate role model"""
from models.update_role_model import UpdateRoleModel
-import json
# Test data with new permissions
test_data = {
- "role_name": "admin",
- "role_description": "Administrator role",
- "manage_users": True,
- "manage_apis": True,
- "manage_endpoints": True,
- "manage_groups": True,
- "manage_roles": True,
- "manage_routings": True,
- "manage_gateway": True,
- "manage_subscriptions": True,
- "manage_security": True,
- "manage_tiers": True,
- "manage_rate_limits": True,
- "manage_credits": True,
- "manage_auth": True,
- "view_analytics": True,
- "view_logs": True,
- "export_logs": True
+ 'role_name': 'admin',
+ 'role_description': 'Administrator role',
+ 'manage_users': True,
+ 'manage_apis': True,
+ 'manage_endpoints': True,
+ 'manage_groups': True,
+ 'manage_roles': True,
+ 'manage_routings': True,
+ 'manage_gateway': True,
+ 'manage_subscriptions': True,
+ 'manage_security': True,
+ 'manage_tiers': True,
+ 'manage_rate_limits': True,
+ 'manage_credits': True,
+ 'manage_auth': True,
+ 'view_analytics': True,
+ 'view_logs': True,
+ 'export_logs': True,
}
try:
model = UpdateRoleModel(**test_data)
- print("✅ Model validation successful!")
- print(f"Model: {model.model_dump()}")
+ print('✅ Model validation successful!')
+ print(f'Model: {model.model_dump()}')
except Exception as e:
- print(f"❌ Model validation failed: {e}")
+ print(f'❌ Model validation failed: {e}')
import traceback
+
traceback.print_exc()
diff --git a/backend-services/tests/test_roles_groups_subscriptions.py b/backend-services/tests/test_roles_groups_subscriptions.py
index 9318e4d..afc6608 100644
--- a/backend-services/tests/test_roles_groups_subscriptions.py
+++ b/backend-services/tests/test_roles_groups_subscriptions.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_roles_crud(authed_client):
-
r = await authed_client.post(
'/platform/role',
json={
@@ -35,9 +35,9 @@ async def test_roles_crud(authed_client):
d = await authed_client.delete('/platform/role/qa')
assert d.status_code == 200
+
@pytest.mark.asyncio
async def test_groups_crud(authed_client):
-
cg = await authed_client.post(
'/platform/group',
json={'group_name': 'qa-group', 'group_description': 'QA', 'api_access': []},
@@ -58,9 +58,9 @@ async def test_groups_crud(authed_client):
dg = await authed_client.delete('/platform/group/qa-group')
assert dg.status_code == 200
+
@pytest.mark.asyncio
async def test_subscriptions_flow(authed_client):
-
api_payload = {
'api_name': 'orders',
'api_version': 'v1',
@@ -92,9 +92,9 @@ async def test_subscriptions_flow(authed_client):
)
assert us.status_code in (200, 400)
+
@pytest.mark.asyncio
async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client):
-
credit_group = 'ai-group'
cd = await authed_client.post(
'/platform/credit',
@@ -103,7 +103,13 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client):
'api_key': 'sk-test-123',
'api_key_header': 'x-api-key',
'credit_tiers': [
- {'tier_name': 'basic', 'credits': 100, 'input_limit': 150, 'output_limit': 150, 'reset_frequency': 'monthly'}
+ {
+ 'tier_name': 'basic',
+ 'credits': 100,
+ 'input_limit': 150,
+ 'output_limit': 150,
+ 'reset_frequency': 'monthly',
+ }
],
},
)
@@ -144,12 +150,10 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client):
assert s.status_code in (200, 201)
uc = await authed_client.post(
- f'/platform/credit/admin',
+ '/platform/credit/admin',
json={
'username': 'admin',
- 'users_credits': {
- credit_group: {'tier_name': 'basic', 'available_credits': 2}
- },
+ 'users_credits': {credit_group: {'tier_name': 'basic', 'available_credits': 2}},
},
)
assert uc.status_code in (200, 201), uc.text
@@ -158,7 +162,9 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client):
r = await authed_client.get('/platform/credit/admin')
assert r.status_code == 200, r.text
body = r.json()
- users_credits = body.get('users_credits') or body.get('response', {}).get('users_credits', {})
+ users_credits = body.get('users_credits') or body.get('response', {}).get(
+ 'users_credits', {}
+ )
return int(users_credits.get(credit_group, {}).get('available_credits', 0))
import services.gateway_service as gs
@@ -177,6 +183,7 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client):
class _FakeAsyncClient:
def __init__(self, timeout=None, limits=None, http2=False):
self._timeout = timeout
+
async def __aenter__(self):
return self
diff --git a/backend-services/tests/test_routing_crud_permissions.py b/backend-services/tests/test_routing_crud_permissions.py
new file mode 100644
index 0000000..51eb7d2
--- /dev/null
+++ b/backend-services/tests/test_routing_crud_permissions.py
@@ -0,0 +1,81 @@
+import time
+
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_routing_crud_requires_manage_routings(authed_client):
+ client_key = f'client_{int(time.time())}'
+ # Limited user
+ uname = f'route_limited_{int(time.time())}'
+ pwd = 'RouteLimitStrong1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ limited = await _login(f'{uname}@example.com', pwd)
+ # Create denied
+ cr = await limited.post(
+ '/platform/routing',
+ json={'client_key': client_key, 'routing_name': 'r1', 'routing_servers': ['http://up']},
+ )
+ assert cr.status_code == 403
+
+ # Grant manage_routings
+ rname = f'route_mgr_{int(time.time())}'
+ rr = await authed_client.post(
+ '/platform/role', json={'role_name': rname, 'manage_routings': True}
+ )
+ assert rr.status_code in (200, 201)
+ uname2 = f'route_mgr_user_{int(time.time())}'
+ cu2 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname2,
+ 'email': f'{uname2}@example.com',
+ 'password': pwd,
+ 'role': rname,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu2.status_code in (200, 201)
+ mgr = await _login(f'{uname2}@example.com', pwd)
+
+ # Create allowed
+ ok = await mgr.post(
+ '/platform/routing',
+ json={'client_key': client_key, 'routing_name': 'r1', 'routing_servers': ['http://up']},
+ )
+ assert ok.status_code in (200, 201)
+ # Get allowed
+ gl = await mgr.get('/platform/routing/all')
+ assert gl.status_code == 200
+ # Update allowed
+ up = await mgr.put(f'/platform/routing/{client_key}', json={'routing_description': 'updated'})
+ assert up.status_code == 200
+ # Delete allowed
+ de = await mgr.delete(f'/platform/routing/{client_key}')
+ assert de.status_code == 200
diff --git a/backend-services/tests/test_routing_precedence_and_round_robin.py b/backend-services/tests/test_routing_precedence_and_round_robin.py
index 6531940..98c47c7 100644
--- a/backend-services/tests/test_routing_precedence_and_round_robin.py
+++ b/backend-services/tests/test_routing_precedence_and_round_robin.py
@@ -1,31 +1,38 @@
import pytest
+
@pytest.mark.asyncio
async def test_routing_endpoint_servers_take_precedence_over_api_servers(authed_client):
- from utils.database import api_collection, endpoint_collection
- from utils.doorman_cache_util import doorman_cache
from utils import routing_util
+ from utils.database import api_collection
+ from utils.doorman_cache_util import doorman_cache
name, ver = 'route1', 'v1'
- r = await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://api1', 'http://api2'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
+ r = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://api1', 'http://api2'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
assert r.status_code in (200, 201)
- r2 = await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/echo',
- 'endpoint_description': 'echo',
- 'endpoint_servers': ['http://ep1', 'http://ep2']
- })
+ r2 = await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/echo',
+ 'endpoint_description': 'echo',
+ 'endpoint_servers': ['http://ep1', 'http://ep2'],
+ },
+ )
assert r2.status_code in (200, 201)
api = api_collection.find_one({'api_name': name, 'api_version': ver})
@@ -36,39 +43,44 @@ async def test_routing_endpoint_servers_take_precedence_over_api_servers(authed_
picked = await routing_util.pick_upstream_server(api, 'POST', '/echo', client_key=None)
assert picked == 'http://ep1'
+
@pytest.mark.asyncio
async def test_routing_client_specific_routing_over_endpoint_and_api(authed_client):
+ from utils import routing_util
from utils.database import api_collection, routing_collection
from utils.doorman_cache_util import doorman_cache
- from utils import routing_util
name, ver = 'route2', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://api1', 'http://api2'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/echo',
- 'endpoint_description': 'echo',
- 'endpoint_servers': ['http://ep1', 'http://ep2']
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://api1', 'http://api2'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/echo',
+ 'endpoint_description': 'echo',
+ 'endpoint_servers': ['http://ep1', 'http://ep2'],
+ },
+ )
api = api_collection.find_one({'api_name': name, 'api_version': ver})
api.pop('_id', None)
- routing_collection.insert_one({
- 'client_key': 'ck1',
- 'routing_servers': ['http://r1', 'http://r2'],
- 'server_index': 0,
- })
+ routing_collection.insert_one(
+ {'client_key': 'ck1', 'routing_servers': ['http://r1', 'http://r2'], 'server_index': 0}
+ )
doorman_cache.clear_cache('client_routing_cache')
doorman_cache.clear_cache('endpoint_server_cache')
@@ -77,30 +89,37 @@ async def test_routing_client_specific_routing_over_endpoint_and_api(authed_clie
assert s1 == 'http://r1'
assert s2 == 'http://r2'
+
@pytest.mark.asyncio
async def test_routing_round_robin_api_servers_rotates(authed_client):
+ from utils import routing_util
from utils.database import api_collection
from utils.doorman_cache_util import doorman_cache
- from utils import routing_util
name, ver = 'route3', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://a1', 'http://a2'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/status',
- 'endpoint_description': 'status'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://a1', 'http://a2'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/status',
+ 'endpoint_description': 'status',
+ },
+ )
api = api_collection.find_one({'api_name': name, 'api_version': ver})
api.pop('_id', None)
@@ -110,31 +129,38 @@ async def test_routing_round_robin_api_servers_rotates(authed_client):
s3 = await routing_util.pick_upstream_server(api, 'GET', '/status', client_key=None)
assert [s1, s2, s3] == ['http://a1', 'http://a2', 'http://a1']
+
@pytest.mark.asyncio
async def test_routing_round_robin_endpoint_servers_rotates(authed_client):
+ from utils import routing_util
from utils.database import api_collection
from utils.doorman_cache_util import doorman_cache
- from utils import routing_util
name, ver = 'route4', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://a1', 'http://a2'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/echo',
- 'endpoint_description': 'echo',
- 'endpoint_servers': ['http://e1', 'http://e2']
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://a1', 'http://a2'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/echo',
+ 'endpoint_description': 'echo',
+ 'endpoint_servers': ['http://e1', 'http://e2'],
+ },
+ )
api = api_collection.find_one({'api_name': name, 'api_version': ver})
api.pop('_id', None)
@@ -144,30 +170,37 @@ async def test_routing_round_robin_endpoint_servers_rotates(authed_client):
s3 = await routing_util.pick_upstream_server(api, 'POST', '/echo', client_key=None)
assert [s1, s2, s3] == ['http://e1', 'http://e2', 'http://e1']
+
@pytest.mark.asyncio
async def test_routing_round_robin_index_persists_in_cache(authed_client):
+ from utils import routing_util
from utils.database import api_collection
from utils.doorman_cache_util import doorman_cache
- from utils import routing_util
name, ver = 'route5', 'v1'
- await authed_client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://a1', 'http://a2', 'http://a3'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
- await authed_client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'GET',
- 'endpoint_uri': '/status',
- 'endpoint_description': 'status'
- })
+ await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://a1', 'http://a2', 'http://a3'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
+ await authed_client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'GET',
+ 'endpoint_uri': '/status',
+ 'endpoint_description': 'status',
+ },
+ )
api = api_collection.find_one({'api_name': name, 'api_version': ver})
api.pop('_id', None)
@@ -181,4 +214,3 @@ async def test_routing_round_robin_index_persists_in_cache(authed_client):
assert s3 == 'http://a3'
idx_after = doorman_cache.get_cache('endpoint_server_cache', api['api_id'])
assert idx_after == 0
-
diff --git a/backend-services/tests/test_security.py b/backend-services/tests/test_security.py
index 70b2345..6771b03 100644
--- a/backend-services/tests/test_security.py
+++ b/backend-services/tests/test_security.py
@@ -1,25 +1,24 @@
-import os
import pytest
from jose import jwt
+
@pytest.mark.asyncio
async def test_jwt_tamper_rejected(client):
-
bad_token = jwt.encode({'sub': 'admin'}, 'wrong-secret', algorithm='HS256')
client.cookies.set('access_token_cookie', bad_token)
r = await client.get('/platform/user/me')
assert r.status_code in (401, 500)
+
@pytest.mark.asyncio
async def test_csrf_required_when_https(monkeypatch, authed_client):
-
monkeypatch.setenv('HTTPS_ONLY', 'true')
r = await authed_client.get('/platform/user/me')
assert r.status_code in (401, 500)
+
@pytest.mark.asyncio
async def test_header_injection_is_sanitized(monkeypatch, authed_client):
-
api_name, version = 'hdr', 'v1'
c = await authed_client.post(
'/platform/api',
@@ -54,6 +53,7 @@ async def test_header_injection_is_sanitized(monkeypatch, authed_client):
assert s.status_code in (200, 201)
import services.gateway_service as gs
+
captured = {}
class _FakeHTTPResponse:
@@ -63,14 +63,17 @@ async def test_header_injection_is_sanitized(monkeypatch, authed_client):
self.headers = {'Content-Type': 'application/json'}
self.content = b'{}'
self.text = '{}'
+
def json(self):
return self._json_body
class _FakeAsyncClient:
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -88,61 +91,82 @@ async def test_header_injection_is_sanitized(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
+
async def get(self, url, params=None, headers=None, **kwargs):
captured['headers'] = headers or {}
return _FakeHTTPResponse(200)
+
async def post(self, url, **kwargs):
return _FakeHTTPResponse(200)
+
async def put(self, url, **kwargs):
return _FakeHTTPResponse(200)
+
async def delete(self, url, **kwargs):
return _FakeHTTPResponse(200)
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
import routes.gateway_routes as gr
- async def _no_limit(req): return None
- async def _pass_sub(req): return {'sub': 'admin'}
- async def _pass_group(req, full_path: str = None, user_to_subscribe=None): return {'sub': 'admin'}
+
+ async def _no_limit(req):
+ return None
+
+ async def _pass_sub(req):
+ return {'sub': 'admin'}
+
+ async def _pass_group(req, full_path: str = None, user_to_subscribe=None):
+ return {'sub': 'admin'}
+
monkeypatch.setattr(gr, 'limit_and_throttle', _no_limit)
monkeypatch.setattr(gr, 'subscription_required', _pass_sub)
monkeypatch.setattr(gr, 'group_required', _pass_group)
inj_value = 'abc\r\nInjected: 1'
- r = await authed_client.get(
- '/api/rest/hdr/v1/x',
- headers={'X-Allowed': inj_value},
- )
+ r = await authed_client.get('/api/rest/hdr/v1/x', headers={'X-Allowed': inj_value})
assert r.status_code in (200, 500)
forwarded = captured.get('headers', {}).get('X-Allowed', '')
assert '\r' not in forwarded and '\n' not in forwarded
assert '<' not in forwarded and '>' not in forwarded
+
@pytest.mark.asyncio
async def test_rate_limit_enforced(monkeypatch, authed_client):
-
from utils.database import user_collection
+
user_collection.update_one(
{'username': 'admin'},
- {'$set': {'rate_limit_duration': 1, 'rate_limit_duration_type': 'second',
- 'throttle_duration': 1, 'throttle_duration_type': 'second',
- 'throttle_queue_limit': 1, 'throttle_wait_duration': 0.0, 'throttle_wait_duration_type': 'second'}}
+ {
+ '$set': {
+ 'rate_limit_duration': 1,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 1,
+ 'throttle_duration_type': 'second',
+ 'throttle_queue_limit': 1,
+ 'throttle_wait_duration': 0.0,
+ 'throttle_wait_duration_type': 'second',
+ }
+ },
)
class _FakeRedis:
def __init__(self):
self.store = {}
+
async def incr(self, key):
self.store[key] = self.store.get(key, 0) + 1
return self.store[key]
+
async def expire(self, key, ttl):
return True
from doorman import doorman as app
+
app.state.redis = _FakeRedis()
from utils.doorman_cache_util import doorman_cache
+
try:
doorman_cache.delete_cache('user_cache', 'admin')
except Exception:
@@ -180,6 +204,7 @@ async def test_rate_limit_enforced(monkeypatch, authed_client):
assert s.status_code in (200, 201)
import services.gateway_service as gs
+
class _FakeHTTPResponse:
def __init__(self, status_code=200, json_body=None):
self.status_code = status_code
@@ -187,11 +212,17 @@ async def test_rate_limit_enforced(monkeypatch, authed_client):
self.headers = {'Content-Type': 'application/json'}
self.content = b'{}'
self.text = '{}'
+
def json(self):
return self._json_body
+
class _FakeAsyncClient:
- async def __aenter__(self): return self
- async def __aexit__(self, exc_type, exc, tb): return False
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -209,15 +240,29 @@ async def test_rate_limit_enforced(monkeypatch, authed_client):
return await self.put(url, **kwargs)
else:
return _FakeHTTPResponse(405, json_body={'error': 'Method not allowed'})
- async def get(self, url, params=None, headers=None, **kwargs): return _FakeHTTPResponse(200)
- async def post(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def put(self, url, **kwargs): return _FakeHTTPResponse(200)
- async def delete(self, url, **kwargs): return _FakeHTTPResponse(200)
+
+ async def get(self, url, params=None, headers=None, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def post(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def put(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
+ async def delete(self, url, **kwargs):
+ return _FakeHTTPResponse(200)
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
import routes.gateway_routes as gr
- async def _pass_sub(req): return {'sub': 'admin'}
- async def _pass_group(req, full_path: str = None, user_to_subscribe=None): return {'sub': 'admin'}
+
+ async def _pass_sub(req):
+ return {'sub': 'admin'}
+
+ async def _pass_group(req, full_path: str = None, user_to_subscribe=None):
+ return {'sub': 'admin'}
+
monkeypatch.setattr(gr, 'subscription_required', _pass_sub)
monkeypatch.setattr(gr, 'group_required', _pass_group)
diff --git a/backend-services/tests/test_security_and_metrics.py b/backend-services/tests/test_security_and_metrics.py
index 5f3bfaf..2dc9e90 100644
--- a/backend-services/tests/test_security_and_metrics.py
+++ b/backend-services/tests/test_security_and_metrics.py
@@ -1,11 +1,10 @@
import os
-import json
-import time
+
import pytest
+
@pytest.mark.asyncio
async def test_security_headers_and_hsts(monkeypatch, client):
-
r = await client.get('/platform/monitor/liveness')
assert r.status_code == 200
assert r.headers.get('X-Content-Type-Options') == 'nosniff'
@@ -18,13 +17,17 @@ async def test_security_headers_and_hsts(monkeypatch, client):
assert r.status_code == 200
assert 'Strict-Transport-Security' in r.headers
+
@pytest.mark.asyncio
async def test_body_size_limit_returns_413(monkeypatch, client):
monkeypatch.setenv('MAX_BODY_SIZE_BYTES', '10')
payload = 'x' * 100
- r = await client.post('/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'})
+ r = await client.post(
+ '/platform/authorization', content=payload, headers={'Content-Type': 'text/plain'}
+ )
assert r.status_code == 413
+
@pytest.mark.asyncio
async def test_strict_response_envelope(monkeypatch, authed_client):
monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
@@ -36,10 +39,11 @@ async def test_strict_response_envelope(monkeypatch, authed_client):
assert isinstance(data, dict)
assert 'status_code' in data and data['status_code'] == 200
+
@pytest.mark.asyncio
async def test_metrics_recording_snapshot(authed_client):
-
from conftest import create_api, create_endpoint, subscribe_self
+
name, ver = 'metapi', 'v1'
await create_api(authed_client, name, ver)
await create_endpoint(authed_client, name, ver, 'GET', '/status')
@@ -56,19 +60,22 @@ async def test_metrics_recording_snapshot(authed_client):
assert isinstance(series, list)
assert body.get('response', {}).get('total_requests') or body.get('total_requests') >= 1
+
@pytest.mark.asyncio
async def test_cors_strict_allows_localhost(monkeypatch, client):
-
monkeypatch.setenv('CORS_STRICT', 'true')
monkeypatch.setenv('ALLOWED_ORIGINS', '*')
monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
r = await client.get('/platform/monitor/liveness', headers={'Origin': 'http://localhost:3000'})
- assert r.headers.get('access-control-allow-origin') in ('http://localhost:3000', 'http://localhost')
+ assert r.headers.get('access-control-allow-origin') in (
+ 'http://localhost:3000',
+ 'http://localhost',
+ )
+
@pytest.mark.asyncio
async def test_csp_header_default_and_override(monkeypatch, client):
-
monkeypatch.delenv('CONTENT_SECURITY_POLICY', raising=False)
r = await client.get('/platform/monitor/liveness')
assert r.status_code == 200
@@ -79,9 +86,9 @@ async def test_csp_header_default_and_override(monkeypatch, client):
r2 = await client.get('/platform/monitor/liveness')
assert r2.headers.get('Content-Security-Policy') == "default-src 'self'"
+
@pytest.mark.asyncio
async def test_request_id_header_generation_and_echo(client):
-
r = await client.get('/platform/monitor/liveness')
assert r.status_code == 200
assert r.headers.get('X-Request-ID')
@@ -92,9 +99,9 @@ async def test_request_id_header_generation_and_echo(client):
assert r2.headers.get('X-Request-ID') == incoming
assert r2.headers.get('request_id') == incoming
+
@pytest.mark.asyncio
async def test_memory_dump_and_restore(tmp_path, monkeypatch):
-
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'unit-test-secret')
@@ -102,8 +109,12 @@ async def test_memory_dump_and_restore(tmp_path, monkeypatch):
dump_dir.mkdir(parents=True, exist_ok=True)
monkeypatch.setenv('MEM_DUMP_PATH', str(dump_dir / 'memory_dump.bin'))
- from utils.memory_dump_util import dump_memory_to_file, find_latest_dump_path, restore_memory_from_file
from utils.database import database
+ from utils.memory_dump_util import (
+ dump_memory_to_file,
+ find_latest_dump_path,
+ restore_memory_from_file,
+ )
database.db.users.insert_one({'username': 'tmp', 'email': 't@t.t', 'password': 'x'})
path = dump_memory_to_file(None)
@@ -113,5 +124,5 @@ async def test_memory_dump_and_restore(tmp_path, monkeypatch):
database.db.users._docs.clear()
assert database.db.users.count_documents({}) == 0
- info = restore_memory_from_file(latest)
+ restore_memory_from_file(latest)
assert database.db.users.count_documents({}) >= 1
diff --git a/backend-services/tests/test_security_permissions.py b/backend-services/tests/test_security_permissions.py
index f78d124..85797af 100644
--- a/backend-services/tests/test_security_permissions.py
+++ b/backend-services/tests/test_security_permissions.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_security_settings_requires_permission(authed_client):
-
r = await authed_client.put('/platform/role/admin', json={'manage_security': False})
assert r.status_code in (200, 201)
@@ -17,4 +17,3 @@ async def test_security_settings_requires_permission(authed_client):
gs2 = await authed_client.get('/platform/security/settings')
assert gs2.status_code == 200
-
diff --git a/backend-services/tests/test_security_settings_permissions.py b/backend-services/tests/test_security_settings_permissions.py
new file mode 100644
index 0000000..4c7cdc4
--- /dev/null
+++ b/backend-services/tests/test_security_settings_permissions.py
@@ -0,0 +1,65 @@
+import time
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_security_settings_get_put_permissions(authed_client):
+ # Limited user: 403 on get/put
+ uname = f'sec_limited_{int(time.time())}'
+ pwd = 'SecLimitStrongPass1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+
+ from httpx import AsyncClient
+
+ from doorman import doorman
+
+ limited = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await limited.post(
+ '/platform/authorization', json={'email': f'{uname}@example.com', 'password': pwd}
+ )
+ assert r.status_code == 200
+ get403 = await limited.get('/platform/security/settings')
+ assert get403.status_code == 403
+ put403 = await limited.put('/platform/security/settings', json={'trust_x_forwarded_for': True})
+ assert put403.status_code == 403
+
+ # Role manage_security: 200 on get/put
+ rname = f'sec_mgr_{int(time.time())}'
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': rname, 'manage_security': True}
+ )
+ assert cr.status_code in (200, 201)
+ uname2 = f'sec_mgr_user_{int(time.time())}'
+ cu2 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname2,
+ 'email': f'{uname2}@example.com',
+ 'password': pwd,
+ 'role': rname,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu2.status_code in (200, 201)
+ mgr = AsyncClient(app=doorman, base_url='http://testserver')
+ r2 = await mgr.post(
+ '/platform/authorization', json={'email': f'{uname2}@example.com', 'password': pwd}
+ )
+ assert r2.status_code == 200
+ g = await mgr.get('/platform/security/settings')
+ assert g.status_code == 200
+ u = await mgr.put('/platform/security/settings', json={'trust_x_forwarded_for': True})
+ assert u.status_code == 200
diff --git a/backend-services/tests/test_security_settings_persistence.py b/backend-services/tests/test_security_settings_persistence.py
index 4c2bcb1..884a838 100644
--- a/backend-services/tests/test_security_settings_persistence.py
+++ b/backend-services/tests/test_security_settings_persistence.py
@@ -1,15 +1,17 @@
-import json
import asyncio
+import json
+
import pytest
+
@pytest.mark.asyncio
async def test_load_settings_from_file_in_memory_mode(tmp_path, monkeypatch):
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
from utils import security_settings_util as ssu
+
settings_path = tmp_path / 'sec_settings.json'
monkeypatch.setattr(ssu, 'SETTINGS_FILE', str(settings_path), raising=False)
- from utils.security_settings_util import load_settings, get_cached_settings
- from utils.security_settings_util import _get_collection
+ from utils.security_settings_util import _get_collection, get_cached_settings, load_settings
coll = _get_collection()
try:
@@ -41,6 +43,7 @@ async def test_load_settings_from_file_in_memory_mode(tmp_path, monkeypatch):
assert cached.get('ip_whitelist') == ['203.0.113.1']
assert cached.get('allow_localhost_bypass') is True
+
@pytest.mark.asyncio
async def test_save_settings_writes_file_and_autosave_triggers_dump(tmp_path, monkeypatch):
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
@@ -50,20 +53,24 @@ async def test_save_settings_writes_file_and_autosave_triggers_dump(tmp_path, mo
monkeypatch.setattr(ssu, 'SETTINGS_FILE', str(sec_file), raising=False)
calls = {'count': 0, 'last_path': None}
+
def _fake_dump(path_hint):
calls['count'] += 1
calls['last_path'] = path_hint
return str(tmp_path / 'dump.bin')
+
monkeypatch.setattr(ssu, 'dump_memory_to_file', _fake_dump)
await ssu.start_auto_save_task()
prev_task = getattr(ssu, '_AUTO_TASK', None)
- result = await ssu.save_settings({
- 'enable_auto_save': True,
- 'auto_save_frequency_seconds': 90,
- 'dump_path': str(tmp_path / 'wanted_dump.bin'),
- })
+ result = await ssu.save_settings(
+ {
+ 'enable_auto_save': True,
+ 'auto_save_frequency_seconds': 90,
+ 'dump_path': str(tmp_path / 'wanted_dump.bin'),
+ }
+ )
await asyncio.sleep(0)
diff --git a/backend-services/tests/test_soap_gateway_content_types.py b/backend-services/tests/test_soap_gateway_content_types.py
index 731626f..f3e1092 100644
--- a/backend-services/tests/test_soap_gateway_content_types.py
+++ b/backend-services/tests/test_soap_gateway_content_types.py
@@ -1,5 +1,6 @@
import pytest
+
class _FakeXMLResponse:
def __init__(self, status_code=200, text='', headers=None):
self.status_code = status_code
@@ -10,14 +11,18 @@ class _FakeXMLResponse:
self.headers = base
self.content = self.text.encode('utf-8')
+
def _mk_xml_client(captured):
class _FakeXMLClient:
def __init__(self, *args, **kwargs):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -35,45 +40,62 @@ def _mk_xml_client(captured):
return await self.put(url, **kwargs)
else:
return _FakeXMLResponse(405, 'Method not allowed')
+
async def get(self, url, **kwargs):
return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'})
+
async def post(self, url, content=None, params=None, headers=None, **kwargs):
captured.append({'url': url, 'headers': dict(headers or {}), 'content': content})
return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'})
+
async def put(self, url, **kwargs):
return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'})
+
async def delete(self, url, **kwargs):
return _FakeXMLResponse(200, '', {'X-Upstream': 'yes', 'Content-Type': 'text/xml'})
+
return _FakeXMLClient
+
async def _setup_api(client, name, ver):
- r = await client.post('/platform/api', json={
- 'api_name': name,
- 'api_version': ver,
- 'api_description': f'{name} {ver}',
- 'api_allowed_roles': ['admin'],
- 'api_allowed_groups': ['ALL'],
- 'api_servers': ['http://soap.up'],
- 'api_type': 'REST',
- 'api_allowed_retry_count': 0,
- })
+ r = await client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': f'{name} {ver}',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://soap.up'],
+ 'api_type': 'REST',
+ 'api_allowed_retry_count': 0,
+ },
+ )
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/call',
- 'endpoint_description': 'soap call',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/call',
+ 'endpoint_description': 'soap call',
+ },
+ )
assert r2.status_code in (200, 201)
rme = await client.get('/platform/user/me')
- username = (rme.json().get('username') if rme.status_code == 200 else 'admin')
- rs = await client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': name, 'api_version': ver})
+ username = rme.json().get('username') if rme.status_code == 200 else 'admin'
+ rs = await client.post(
+ '/platform/subscription/subscribe',
+ json={'username': username, 'api_name': name, 'api_version': ver},
+ )
assert rs.status_code in (200, 201)
+
@pytest.mark.asyncio
async def test_soap_incoming_application_xml_sets_text_xml_outgoing(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapct1', 'v1'
await _setup_api(authed_client, name, ver)
captured = []
@@ -89,26 +111,28 @@ async def test_soap_incoming_application_xml_sets_text_xml_outgoing(monkeypatch,
h = captured[0]['headers']
assert h.get('Content-Type') == 'text/xml; charset=utf-8'
+
@pytest.mark.asyncio
async def test_soap_incoming_text_xml_passes_through(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapct2', 'v1'
await _setup_api(authed_client, name, ver)
captured = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_xml_client(captured))
envelope = ''
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call',
- headers={'Content-Type': 'text/xml'},
- content=envelope,
+ f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'text/xml'}, content=envelope
)
assert r.status_code == 200
h = captured[0]['headers']
assert h.get('Content-Type') == 'text/xml'
+
@pytest.mark.asyncio
async def test_soap_incoming_application_soap_xml_passes_through(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapct3', 'v1'
await _setup_api(authed_client, name, ver)
captured = []
@@ -123,9 +147,11 @@ async def test_soap_incoming_application_soap_xml_passes_through(monkeypatch, au
h = captured[0]['headers']
assert h.get('Content-Type') == 'application/soap+xml'
+
@pytest.mark.asyncio
async def test_soap_adds_default_soapaction_when_missing(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapct4', 'v1'
await _setup_api(authed_client, name, ver)
captured = []
@@ -140,9 +166,11 @@ async def test_soap_adds_default_soapaction_when_missing(monkeypatch, authed_cli
h = captured[0]['headers']
assert 'SOAPAction' in h and h['SOAPAction'] == '""'
+
@pytest.mark.asyncio
async def test_soap_parses_xml_response_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapct5', 'v1'
await _setup_api(authed_client, name, ver)
captured = []
diff --git a/backend-services/tests/test_soap_gateway_retries.py b/backend-services/tests/test_soap_gateway_retries.py
index 0cd1d6e..b7d33f8 100644
--- a/backend-services/tests/test_soap_gateway_retries.py
+++ b/backend-services/tests/test_soap_gateway_retries.py
@@ -1,5 +1,6 @@
import pytest
+
class _Resp:
def __init__(self, status_code=200, body='', headers=None):
self.status_code = status_code
@@ -10,16 +11,20 @@ class _Resp:
self.headers = base
self.content = (self.text or '').encode('utf-8')
+
def _mk_retry_xml_client(sequence, seen):
counter = {'i': 0}
class _Client:
def __init__(self, timeout=None, limits=None, http2=False):
pass
+
async def __aenter__(self):
return self
+
async def __aexit__(self, exc_type, exc, tb):
return False
+
async def request(self, method, url, **kwargs):
"""Generic request method used by http_client.request_with_resilience"""
method = method.upper()
@@ -37,32 +42,45 @@ def _mk_retry_xml_client(sequence, seen):
return await self.put(url, **kwargs)
else:
return _Resp(405)
+
async def post(self, url, content=None, params=None, headers=None, **kwargs):
- seen.append({'url': url, 'params': dict(params or {}), 'headers': dict(headers or {}), 'content': content})
+ seen.append(
+ {
+ 'url': url,
+ 'params': dict(params or {}),
+ 'headers': dict(headers or {}),
+ 'content': content,
+ }
+ )
idx = min(counter['i'], len(sequence) - 1)
code = sequence[idx]
counter['i'] = counter['i'] + 1
return _Resp(code)
+
async def get(self, url, **kwargs):
seen.append({'url': url, 'params': {}, 'headers': {}})
idx = min(counter['i'], len(sequence) - 1)
code = sequence[idx]
counter['i'] = counter['i'] + 1
return _Resp(code)
+
async def put(self, url, **kwargs):
seen.append({'url': url, 'params': {}, 'headers': {}})
idx = min(counter['i'], len(sequence) - 1)
code = sequence[idx]
counter['i'] = counter['i'] + 1
return _Resp(code)
+
async def delete(self, url, **kwargs):
seen.append({'url': url, 'params': {}, 'headers': {}})
idx = min(counter['i'], len(sequence) - 1)
code = sequence[idx]
counter['i'] = counter['i'] + 1
return _Resp(code)
+
return _Client
+
async def _setup_soap(client, name, ver, retry_count=0):
payload = {
'api_name': name,
@@ -76,79 +94,102 @@ async def _setup_soap(client, name, ver, retry_count=0):
}
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
- r2 = await client.post('/platform/endpoint', json={
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': 'POST',
- 'endpoint_uri': '/call',
- 'endpoint_description': 'soap call',
- })
+ r2 = await client.post(
+ '/platform/endpoint',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': 'POST',
+ 'endpoint_uri': '/call',
+ 'endpoint_description': 'soap call',
+ },
+ )
assert r2.status_code in (200, 201)
from conftest import subscribe_self
+
await subscribe_self(client, name, ver)
+
@pytest.mark.asyncio
async def test_soap_retry_on_500_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapretry500', 'v1'
await _setup_soap(authed_client, name, ver, retry_count=2)
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([500, 200], seen))
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content=''
+ f'/api/soap/{name}/{ver}/call',
+ headers={'Content-Type': 'application/xml'},
+ content='',
)
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_soap_retry_on_502_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapretry502', 'v1'
await _setup_soap(authed_client, name, ver, retry_count=2)
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([502, 200], seen))
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content=''
+ f'/api/soap/{name}/{ver}/call',
+ headers={'Content-Type': 'application/xml'},
+ content='',
)
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_soap_retry_on_503_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapretry503', 'v1'
await _setup_soap(authed_client, name, ver, retry_count=2)
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([503, 200], seen))
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content=''
+ f'/api/soap/{name}/{ver}/call',
+ headers={'Content-Type': 'application/xml'},
+ content='',
)
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_soap_retry_on_504_then_success(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapretry504', 'v1'
await _setup_soap(authed_client, name, ver, retry_count=2)
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([504, 200], seen))
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content=''
+ f'/api/soap/{name}/{ver}/call',
+ headers={'Content-Type': 'application/xml'},
+ content='',
)
assert r.status_code == 200
assert len(seen) == 2
+
@pytest.mark.asyncio
async def test_soap_no_retry_when_retry_count_zero(monkeypatch, authed_client):
import services.gateway_service as gs
+
name, ver = 'soapretry0', 'v1'
await _setup_soap(authed_client, name, ver, retry_count=0)
seen = []
monkeypatch.setattr(gs.httpx, 'AsyncClient', _mk_retry_xml_client([500, 200], seen))
r = await authed_client.post(
- f'/api/soap/{name}/{ver}/call', headers={'Content-Type': 'application/xml'}, content=''
+ f'/api/soap/{name}/{ver}/call',
+ headers={'Content-Type': 'application/xml'},
+ content='',
)
assert r.status_code == 500
assert len(seen) == 1
-
diff --git a/backend-services/tests/test_soap_validation_no_wsdl.py b/backend-services/tests/test_soap_validation_no_wsdl.py
index 81a7035..65f9669 100644
--- a/backend-services/tests/test_soap_validation_no_wsdl.py
+++ b/backend-services/tests/test_soap_validation_no_wsdl.py
@@ -1,6 +1,7 @@
import pytest
from fastapi import HTTPException
+
@pytest.mark.asyncio
async def test_soap_structural_validation_passes_without_wsdl():
from utils.database import endpoint_validation_collection
@@ -8,38 +9,32 @@ async def test_soap_structural_validation_passes_without_wsdl():
endpoint_id = 'soap-ep-struct-1'
endpoint_validation_collection.delete_one({'endpoint_id': endpoint_id})
- endpoint_validation_collection.insert_one({
- 'endpoint_id': endpoint_id,
- 'validation_enabled': True,
- 'validation_schema': {
- 'username': {
- 'required': True,
- 'type': 'string',
- 'min': 3,
- 'max': 50,
- },
- 'email': {
- 'required': True,
- 'type': 'string',
- 'format': 'email',
+ endpoint_validation_collection.insert_one(
+ {
+ 'endpoint_id': endpoint_id,
+ 'validation_enabled': True,
+ 'validation_schema': {
+ 'username': {'required': True, 'type': 'string', 'min': 3, 'max': 50},
+ 'email': {'required': True, 'type': 'string', 'format': 'email'},
},
}
- })
+ )
envelope = (
""
""
- " "
- " "
- " alice"
- " alice@example.com"
- " "
- " "
- ""
+ ' '
+ ' '
+ ' alice'
+ ' alice@example.com'
+ ' '
+ ' '
+ ''
)
await validation_util.validate_soap_request(endpoint_id, envelope)
+
@pytest.mark.asyncio
async def test_soap_structural_validation_fails_without_wsdl():
from utils.database import endpoint_validation_collection
@@ -47,30 +42,25 @@ async def test_soap_structural_validation_fails_without_wsdl():
endpoint_id = 'soap-ep-struct-2'
endpoint_validation_collection.delete_one({'endpoint_id': endpoint_id})
- endpoint_validation_collection.insert_one({
- 'endpoint_id': endpoint_id,
- 'validation_enabled': True,
- 'validation_schema': {
- 'username': {
- 'required': True,
- 'type': 'string',
- 'min': 3,
- }
+ endpoint_validation_collection.insert_one(
+ {
+ 'endpoint_id': endpoint_id,
+ 'validation_enabled': True,
+ 'validation_schema': {'username': {'required': True, 'type': 'string', 'min': 3}},
}
- })
+ )
bad_envelope = (
""
""
- " "
- " "
- " no-user@example.com"
- " "
- " "
- ""
+ ' '
+ ' '
+ ' no-user@example.com'
+ ' '
+ ' '
+ ''
)
with pytest.raises(HTTPException) as ex:
await validation_util.validate_soap_request(endpoint_id, bad_envelope)
assert ex.value.status_code == 400
-
diff --git a/backend-services/tests/test_subscription_flows.py b/backend-services/tests/test_subscription_flows.py
new file mode 100644
index 0000000..667526c
--- /dev/null
+++ b/backend-services/tests/test_subscription_flows.py
@@ -0,0 +1,38 @@
+import time
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_subscription_lifecycle(authed_client):
+ name, ver = f'subapi_{int(time.time())}', 'v1'
+ # Create API and endpoint
+ c = await authed_client.post(
+ '/platform/api',
+ json={
+ 'api_name': name,
+ 'api_version': ver,
+ 'api_description': 'sub api',
+ 'api_allowed_roles': ['admin'],
+ 'api_allowed_groups': ['ALL'],
+ 'api_servers': ['http://up.invalid'],
+ 'api_type': 'REST',
+ 'active': True,
+ },
+ )
+ assert c.status_code in (200, 201)
+ # Subscribe current user (admin)
+ s = await authed_client.post(
+ '/platform/subscription/subscribe',
+ json={'api_name': name, 'api_version': ver, 'username': 'admin'},
+ )
+ assert s.status_code in (200, 201)
+ # List current user subscriptions
+ ls = await authed_client.get('/platform/subscription/subscriptions')
+ assert ls.status_code == 200
+ # Unsubscribe
+ u = await authed_client.post(
+ '/platform/subscription/unsubscribe',
+ json={'api_name': name, 'api_version': ver, 'username': 'admin'},
+ )
+ assert u.status_code == 200
diff --git a/backend-services/tests/test_subscription_required_parsing.py b/backend-services/tests/test_subscription_required_parsing.py
index abaab0d..99f14db 100644
--- a/backend-services/tests/test_subscription_required_parsing.py
+++ b/backend-services/tests/test_subscription_required_parsing.py
@@ -1,7 +1,9 @@
import pytest
+
def _make_request(path: str, headers: dict | None = None):
from starlette.requests import Request
+
hdrs = []
for k, v in (headers or {}).items():
hdrs.append((k.lower().encode('latin-1'), str(v).encode('latin-1')))
@@ -17,65 +19,108 @@ def _make_request(path: str, headers: dict | None = None):
}
return Request(scope)
+
@pytest.mark.asyncio
async def test_subscription_required_rest_path_parsing(monkeypatch):
import utils.subscription_util as su
async def fake_auth(req):
return {'sub': 'alice'}
+
monkeypatch.setattr(su, 'auth_required', fake_auth)
- monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc1/v1']} if (name, key) == ('user_subscription_cache', 'alice') else None)
+ monkeypatch.setattr(
+ su.doorman_cache,
+ 'get_cache',
+ lambda name, key: {'apis': ['svc1/v1']}
+ if (name, key) == ('user_subscription_cache', 'alice')
+ else None,
+ )
req = _make_request('/api/rest/svc1/v1/resource')
payload = await su.subscription_required(req)
assert payload.get('sub') == 'alice'
+
@pytest.mark.asyncio
async def test_subscription_required_soap_path_parsing(monkeypatch):
import utils.subscription_util as su
+
async def fake_auth(req):
return {'sub': 'alice'}
+
monkeypatch.setattr(su, 'auth_required', fake_auth)
- monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc2/v2']} if (name, key) == ('user_subscription_cache', 'alice') else None)
+ monkeypatch.setattr(
+ su.doorman_cache,
+ 'get_cache',
+ lambda name, key: {'apis': ['svc2/v2']}
+ if (name, key) == ('user_subscription_cache', 'alice')
+ else None,
+ )
req = _make_request('/api/soap/svc2/v2/do')
payload = await su.subscription_required(req)
assert payload.get('sub') == 'alice'
+
@pytest.mark.asyncio
async def test_subscription_required_graphql_header_parsing(monkeypatch):
import utils.subscription_util as su
+
async def fake_auth(req):
return {'sub': 'alice'}
+
monkeypatch.setattr(su, 'auth_required', fake_auth)
- monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc3/v3']} if (name, key) == ('user_subscription_cache', 'alice') else None)
+ monkeypatch.setattr(
+ su.doorman_cache,
+ 'get_cache',
+ lambda name, key: {'apis': ['svc3/v3']}
+ if (name, key) == ('user_subscription_cache', 'alice')
+ else None,
+ )
req = _make_request('/api/graphql/svc3', headers={'X-API-Version': 'v3'})
payload = await su.subscription_required(req)
assert payload.get('sub') == 'alice'
+
@pytest.mark.asyncio
async def test_subscription_required_grpc_path_parsing(monkeypatch):
import utils.subscription_util as su
+
async def fake_auth(req):
return {'sub': 'alice'}
+
monkeypatch.setattr(su, 'auth_required', fake_auth)
- monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc4/v4']} if (name, key) == ('user_subscription_cache', 'alice') else None)
+ monkeypatch.setattr(
+ su.doorman_cache,
+ 'get_cache',
+ lambda name, key: {'apis': ['svc4/v4']}
+ if (name, key) == ('user_subscription_cache', 'alice')
+ else None,
+ )
req = _make_request('/api/grpc/svc4', headers={'X-API-Version': 'v4'})
payload = await su.subscription_required(req)
assert payload.get('sub') == 'alice'
+
@pytest.mark.asyncio
async def test_subscription_required_unknown_prefix_fallback(monkeypatch):
import utils.subscription_util as su
+
async def fake_auth(req):
return {'sub': 'alice'}
+
monkeypatch.setenv('PYTHONASYNCIODEBUG', '0')
monkeypatch.setattr(su, 'auth_required', fake_auth)
- monkeypatch.setattr(su.doorman_cache, 'get_cache', lambda name, key: {'apis': ['svc5/v5']} if (name, key) == ('user_subscription_cache', 'alice') else None)
+ monkeypatch.setattr(
+ su.doorman_cache,
+ 'get_cache',
+ lambda name, key: {'apis': ['svc5/v5']}
+ if (name, key) == ('user_subscription_cache', 'alice')
+ else None,
+ )
req = _make_request('/api/other/svc5/v5/op')
payload = await su.subscription_required(req)
assert payload.get('sub') == 'alice'
-
diff --git a/backend-services/tests/test_subscription_routes_extended.py b/backend-services/tests/test_subscription_routes_extended.py
index 1a1b1f3..7fbc709 100644
--- a/backend-services/tests/test_subscription_routes_extended.py
+++ b/backend-services/tests/test_subscription_routes_extended.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_subscriptions_happy_and_invalid_payload(authed_client):
-
c = await authed_client.post(
'/platform/api',
json={
@@ -35,4 +35,3 @@ async def test_subscriptions_happy_and_invalid_payload(authed_client):
bad = await authed_client.post('/platform/subscription/subscribe', json={'username': 'admin'})
assert bad.status_code in (400, 422)
-
diff --git a/backend-services/tests/test_tools_chaos_toggles.py b/backend-services/tests/test_tools_chaos_toggles.py
new file mode 100644
index 0000000..cd9a4a0
--- /dev/null
+++ b/backend-services/tests/test_tools_chaos_toggles.py
@@ -0,0 +1,34 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_tools_chaos_toggle_and_stats(authed_client):
+ # Baseline stats
+ st0 = await authed_client.get('/platform/tools/chaos/stats')
+ assert st0.status_code == 200
+
+ # Enable redis outage (response reflects enabled state)
+ r1 = await authed_client.post(
+ '/platform/tools/chaos/toggle', json={'backend': 'redis', 'enabled': True}
+ )
+ assert r1.status_code == 200
+ en = r1.json().get('response', r1.json())
+ assert en.get('enabled') is True
+
+ # Immediately disable to avoid auth failures on subsequent calls
+ # Disable using internal util to avoid auth during outage
+ from utils import chaos_util as _cu
+
+ _cu.enable('redis', False)
+
+ # Stats should reflect disabled
+ st2 = await authed_client.get('/platform/tools/chaos/stats')
+ assert st2.status_code == 200
+ body2 = st2.json().get('response', st2.json())
+ assert body2.get('redis_outage') is False
+
+ # Invalid backend -> 400
+ bad = await authed_client.post(
+ '/platform/tools/chaos/toggle', json={'backend': 'notabackend', 'enabled': True}
+ )
+ assert bad.status_code == 400
diff --git a/backend-services/tests/test_tools_cors_checker.py b/backend-services/tests/test_tools_cors_checker.py
index b9dee02..ac87d25 100644
--- a/backend-services/tests/test_tools_cors_checker.py
+++ b/backend-services/tests/test_tools_cors_checker.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_cors_checker_allows_matching_origin(monkeypatch, authed_client):
-
monkeypatch.setenv('ALLOWED_ORIGINS', 'http://localhost:3000')
monkeypatch.setenv('ALLOW_METHODS', 'GET,POST')
monkeypatch.setenv('ALLOW_HEADERS', 'Content-Type,X-CSRF-Token')
@@ -13,14 +13,18 @@ async def test_cors_checker_allows_matching_origin(monkeypatch, authed_client):
'origin': 'http://localhost:3000',
'method': 'GET',
'request_headers': ['Content-Type'],
- 'with_credentials': True
+ 'with_credentials': True,
}
r = await authed_client.post('/platform/tools/cors/check', json=body)
assert r.status_code == 200, r.text
data = r.json()
assert data.get('preflight', {}).get('allowed') is True
assert data.get('actual', {}).get('allowed') is True
- assert data.get('preflight', {}).get('response_headers', {}).get('Access-Control-Allow-Origin') == body['origin']
+ assert (
+ data.get('preflight', {}).get('response_headers', {}).get('Access-Control-Allow-Origin')
+ == body['origin']
+ )
+
@pytest.mark.asyncio
async def test_cors_checker_denies_disallowed_header(monkeypatch, authed_client):
@@ -34,7 +38,7 @@ async def test_cors_checker_denies_disallowed_header(monkeypatch, authed_client)
'origin': 'http://localhost:3000',
'method': 'GET',
'request_headers': ['X-Custom-Header'],
- 'with_credentials': True
+ 'with_credentials': True,
}
r = await authed_client.post('/platform/tools/cors/check', json=body)
assert r.status_code == 200, r.text
diff --git a/backend-services/tests/test_tools_cors_checker_edge_cases.py b/backend-services/tests/test_tools_cors_checker_edge_cases.py
index 5bfb5c8..e111260 100644
--- a/backend-services/tests/test_tools_cors_checker_edge_cases.py
+++ b/backend-services/tests/test_tools_cors_checker_edge_cases.py
@@ -1,8 +1,10 @@
import pytest
+
async def _allow_tools(client):
await client.put('/platform/user/admin', json={'manage_security': True})
+
@pytest.mark.asyncio
async def test_tools_cors_checker_allows_when_method_and_headers_match(monkeypatch, authed_client):
await _allow_tools(authed_client)
@@ -10,12 +12,18 @@ async def test_tools_cors_checker_allows_when_method_and_headers_match(monkeypat
monkeypatch.setenv('ALLOW_METHODS', 'GET,POST')
monkeypatch.setenv('ALLOW_HEADERS', 'Content-Type,X-CSRF-Token')
monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
- body = {'origin': 'http://ok.example', 'method': 'GET', 'request_headers': ['X-CSRF-Token'], 'with_credentials': True}
+ body = {
+ 'origin': 'http://ok.example',
+ 'method': 'GET',
+ 'request_headers': ['X-CSRF-Token'],
+ 'with_credentials': True,
+ }
r = await authed_client.post('/platform/tools/cors/check', json=body)
assert r.status_code == 200
data = r.json()
assert data.get('preflight', {}).get('allowed') is True
+
@pytest.mark.asyncio
async def test_tools_cors_checker_denies_when_method_not_allowed(monkeypatch, authed_client):
await _allow_tools(authed_client)
@@ -28,6 +36,7 @@ async def test_tools_cors_checker_denies_when_method_not_allowed(monkeypatch, au
assert data.get('preflight', {}).get('allowed') is False
assert data.get('preflight', {}).get('method_allowed') is False
+
@pytest.mark.asyncio
async def test_tools_cors_checker_denies_when_headers_not_allowed(monkeypatch, authed_client):
await _allow_tools(authed_client)
@@ -41,6 +50,7 @@ async def test_tools_cors_checker_denies_when_headers_not_allowed(monkeypatch, a
assert data.get('preflight', {}).get('allowed') is False
assert 'X-CSRF-Token' in (data.get('preflight', {}).get('not_allowed_headers') or [])
+
@pytest.mark.asyncio
async def test_tools_cors_checker_credentials_and_wildcard_interaction(monkeypatch, authed_client):
await _allow_tools(authed_client)
@@ -53,4 +63,3 @@ async def test_tools_cors_checker_credentials_and_wildcard_interaction(monkeypat
data = r.json()
assert data.get('preflight', {}).get('allow_origin') is True
assert any('Wildcard origins' in n for n in data.get('notes') or [])
-
diff --git a/backend-services/tests/test_tools_cors_checker_env_edges.py b/backend-services/tests/test_tools_cors_checker_env_edges.py
new file mode 100644
index 0000000..a8ac49e
--- /dev/null
+++ b/backend-services/tests/test_tools_cors_checker_env_edges.py
@@ -0,0 +1,51 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_tools_cors_checker_strict_with_credentials_blocks_wildcard(
+ monkeypatch, authed_client
+):
+ monkeypatch.setenv('ALLOWED_ORIGINS', '*')
+ monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
+ monkeypatch.setenv('CORS_STRICT', 'true')
+
+ r = await authed_client.post(
+ '/platform/tools/cors/check',
+ json={
+ 'origin': 'http://evil.example',
+ 'with_credentials': True,
+ 'method': 'GET',
+ 'request_headers': ['Content-Type'],
+ },
+ )
+ assert r.status_code == 200
+ body = r.json().get('response', r.json())
+ pre = body.get('preflight', {})
+ hdrs = pre.get('response_headers', {})
+ assert hdrs.get('Access-Control-Allow-Origin') in (None, '')
+ assert hdrs.get('Access-Control-Allow-Credentials') == 'true'
+
+
+@pytest.mark.asyncio
+async def test_tools_cors_checker_non_strict_wildcard_with_credentials_allows(
+ monkeypatch, authed_client
+):
+ monkeypatch.setenv('ALLOWED_ORIGINS', '*')
+ monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
+ monkeypatch.setenv('CORS_STRICT', 'false')
+
+ r = await authed_client.post(
+ '/platform/tools/cors/check',
+ json={
+ 'origin': 'http://ok.example',
+ 'with_credentials': True,
+ 'method': 'GET',
+ 'request_headers': ['Content-Type'],
+ },
+ )
+ assert r.status_code == 200
+ body = r.json().get('response', r.json())
+ pre = body.get('preflight', {})
+ hdrs = pre.get('response_headers', {})
+ assert hdrs.get('Access-Control-Allow-Origin') == 'http://ok.example'
+ assert hdrs.get('Access-Control-Allow-Credentials') == 'true'
diff --git a/backend-services/tests/test_tools_cors_checker_permissions.py b/backend-services/tests/test_tools_cors_checker_permissions.py
new file mode 100644
index 0000000..c4ef991
--- /dev/null
+++ b/backend-services/tests/test_tools_cors_checker_permissions.py
@@ -0,0 +1,66 @@
+import time
+
+import pytest
+from httpx import AsyncClient
+
+
+async def _login(email: str, password: str) -> AsyncClient:
+ from doorman import doorman
+
+ c = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await c.post('/platform/authorization', json={'email': email, 'password': password})
+ assert r.status_code == 200, r.text
+ body = r.json() if r.headers.get('content-type', '').startswith('application/json') else {}
+ token = body.get('access_token')
+ if token:
+ c.cookies.set('access_token_cookie', token, domain='testserver', path='/')
+ return c
+
+
+@pytest.mark.asyncio
+async def test_tools_cors_checker_requires_manage_security(authed_client):
+ uname = f'sec_check_{int(time.time())}'
+ pwd = 'SecCheckStrongPass!!'
+ # No manage_security
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+ client = await _login(f'{uname}@example.com', pwd)
+ r = await client.post(
+ '/platform/tools/cors/check', json={'origin': 'http://x', 'method': 'GET'}
+ )
+ assert r.status_code == 403
+
+ # Grant manage_security via role and new user
+ role = f'sec_mgr_{int(time.time())}'
+ cr = await authed_client.post(
+ '/platform/role', json={'role_name': role, 'manage_security': True}
+ )
+ assert cr.status_code in (200, 201)
+ uname2 = f'sec_check2_{int(time.time())}'
+ cu2 = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname2,
+ 'email': f'{uname2}@example.com',
+ 'password': pwd,
+ 'role': role,
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu2.status_code in (200, 201)
+ client2 = await _login(f'{uname2}@example.com', pwd)
+ r2 = await client2.post(
+ '/platform/tools/cors/check', json={'origin': 'http://x', 'method': 'GET'}
+ )
+ assert r2.status_code == 200
diff --git a/backend-services/tests/test_user_endpoints.py b/backend-services/tests/test_user_endpoints.py
index 99bea55..3781e5a 100644
--- a/backend-services/tests/test_user_endpoints.py
+++ b/backend-services/tests/test_user_endpoints.py
@@ -1,8 +1,8 @@
import pytest
+
@pytest.mark.asyncio
async def test_user_me_and_crud(authed_client):
-
me = await authed_client.get('/platform/user/me')
assert me.status_code == 200
assert me.json().get('username') == 'admin'
@@ -25,8 +25,7 @@ async def test_user_me_and_crud(authed_client):
assert uu.status_code == 200
up = await authed_client.put(
- '/platform/user/testuser1/update-password',
- json={'new_password': 'ThisIsANewPwd!456'},
+ '/platform/user/testuser1/update-password', json={'new_password': 'ThisIsANewPwd!456'}
)
assert up.status_code == 200
diff --git a/backend-services/tests/test_user_permissions_negative.py b/backend-services/tests/test_user_permissions_negative.py
index 28fcb46..0118df1 100644
--- a/backend-services/tests/test_user_permissions_negative.py
+++ b/backend-services/tests/test_user_permissions_negative.py
@@ -1,32 +1,30 @@
import pytest
+
@pytest.mark.asyncio
async def test_update_other_user_denied_without_permission(authed_client):
-
await authed_client.post(
- '/platform/role',
- json={'role_name': 'user', 'role_description': 'Standard user'},
+ '/platform/role', json={'role_name': 'user', 'role_description': 'Standard user'}
)
cu = await authed_client.post(
'/platform/user',
- json={'username': 'qa_user', 'email': 'qa@doorman.dev', 'password': 'QaPass123_ValidLen!!', 'role': 'user'},
+ json={
+ 'username': 'qa_user',
+ 'email': 'qa@doorman.dev',
+ 'password': 'QaPass123_ValidLen!!',
+ 'role': 'user',
+ },
)
assert cu.status_code in (200, 201), cu.text
r = await authed_client.put('/platform/role/admin', json={'manage_users': False})
assert r.status_code in (200, 201)
- up = await authed_client.put(
- '/platform/user/qa_user',
- json={'email': 'qa2@doorman.dev'},
- )
+ up = await authed_client.put('/platform/user/qa_user', json={'email': 'qa2@doorman.dev'})
assert up.status_code == 403
r2 = await authed_client.put('/platform/role/admin', json={'manage_users': True})
assert r2.status_code in (200, 201)
- up2 = await authed_client.put(
- '/platform/user/qa_user',
- json={'email': 'qa3@doorman.dev'},
- )
+ up2 = await authed_client.put('/platform/user/qa_user', json={'email': 'qa3@doorman.dev'})
assert up2.status_code in (200, 201)
diff --git a/backend-services/tests/test_validation_audit.py b/backend-services/tests/test_validation_audit.py
index 68fb945..4d37c01 100644
--- a/backend-services/tests/test_validation_audit.py
+++ b/backend-services/tests/test_validation_audit.py
@@ -4,6 +4,7 @@ import pytest
from utils.database import endpoint_collection, endpoint_validation_collection
+
def _mk_endpoint(api_name: str, api_version: str, method: str, uri: str) -> dict:
eid = str(uuid.uuid4())
doc = {
@@ -19,6 +20,7 @@ def _mk_endpoint(api_name: str, api_version: str, method: str, uri: str) -> dict
endpoint_collection.insert_one(doc)
return doc
+
def _run_audit() -> list[str]:
failures: list[str] = []
for vdoc in endpoint_validation_collection.find({'validation_enabled': True}):
@@ -29,59 +31,50 @@ def _run_audit() -> list[str]:
continue
schema = vdoc.get('validation_schema')
if not isinstance(schema, dict) or not schema:
- failures.append(f'Enabled validation missing schema for endpoint {ep.get("endpoint_method")} {ep.get("api_name")}/{ep.get("api_version")} {ep.get("endpoint_uri")} (id={eid})')
+ failures.append(
+ f'Enabled validation missing schema for endpoint {ep.get("endpoint_method")} {ep.get("api_name")}/{ep.get("api_version")} {ep.get("endpoint_uri")} (id={eid})'
+ )
return failures
+
@pytest.mark.asyncio
async def test_validator_activation_audit_passes():
e_rest = _mk_endpoint('customers', 'v1', 'POST', '/create')
e_graphql = _mk_endpoint('graphqlsvc', 'v1', 'POST', '/graphql')
e_grpc = _mk_endpoint('grpcsvc', 'v1', 'POST', '/grpc')
- e_soap = _mk_endpoint('soapsvc', 'v1', 'POST', '/soap')
+ _mk_endpoint('soapsvc', 'v1', 'POST', '/soap')
- endpoint_validation_collection.insert_one({
- 'endpoint_id': e_rest['endpoint_id'],
- 'validation_enabled': True,
- 'validation_schema': {
- 'payload.name': {
- 'required': True,
- 'type': 'string',
- 'min': 1
- }
+ endpoint_validation_collection.insert_one(
+ {
+ 'endpoint_id': e_rest['endpoint_id'],
+ 'validation_enabled': True,
+ 'validation_schema': {'payload.name': {'required': True, 'type': 'string', 'min': 1}},
}
- })
- endpoint_validation_collection.insert_one({
- 'endpoint_id': e_graphql['endpoint_id'],
- 'validation_enabled': True,
- 'validation_schema': {
- 'input.query': {
- 'required': True,
- 'type': 'string',
- 'min': 1
- }
+ )
+ endpoint_validation_collection.insert_one(
+ {
+ 'endpoint_id': e_graphql['endpoint_id'],
+ 'validation_enabled': True,
+ 'validation_schema': {'input.query': {'required': True, 'type': 'string', 'min': 1}},
}
- })
- endpoint_validation_collection.insert_one({
- 'endpoint_id': e_grpc['endpoint_id'],
- 'validation_enabled': True,
- 'validation_schema': {
- 'message.name': {
- 'required': True,
- 'type': 'string',
- 'min': 1
- }
+ )
+ endpoint_validation_collection.insert_one(
+ {
+ 'endpoint_id': e_grpc['endpoint_id'],
+ 'validation_enabled': True,
+ 'validation_schema': {'message.name': {'required': True, 'type': 'string', 'min': 1}},
}
- })
+ )
failures = _run_audit()
assert not failures, '\n'.join(failures)
+
@pytest.mark.asyncio
async def test_validator_activation_audit_detects_missing_schema():
e = _mk_endpoint('soapsvc2', 'v1', 'POST', '/soap')
- endpoint_validation_collection.insert_one({
- 'endpoint_id': e['endpoint_id'],
- 'validation_enabled': True,
- })
+ endpoint_validation_collection.insert_one(
+ {'endpoint_id': e['endpoint_id'], 'validation_enabled': True}
+ )
failures = _run_audit()
assert failures and any('missing schema' in f for f in failures)
diff --git a/backend-services/tests/test_validation_nested_and_schema_errors.py b/backend-services/tests/test_validation_nested_and_schema_errors.py
index 7b432bd..638a314 100644
--- a/backend-services/tests/test_validation_nested_and_schema_errors.py
+++ b/backend-services/tests/test_validation_nested_and_schema_errors.py
@@ -1,9 +1,10 @@
import pytest
-
from tests.test_gateway_routing_limits import _FakeAsyncClient
+
async def _setup(client, api='vtest', ver='v1', method='POST', uri='/data'):
from conftest import create_api, create_endpoint, subscribe_self
+
await create_api(client, api, ver)
await create_endpoint(client, api, ver, method, uri)
await subscribe_self(client, api, ver)
@@ -12,11 +13,13 @@ async def _setup(client, api='vtest', ver='v1', method='POST', uri='/data'):
eid = g.json().get('endpoint_id') or g.json().get('response', {}).get('endpoint_id')
return api, ver, uri, eid
+
async def _apply_schema(client, eid, schema_dict):
payload = {'endpoint_id': eid, 'validation_enabled': True, 'validation_schema': schema_dict}
r = await client.post('/platform/endpoint/endpoint/validation', json=payload)
assert r.status_code in (200, 201, 400)
+
@pytest.mark.asyncio
async def test_validation_nested_object_paths_valid(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='vnest', uri='/ok1')
@@ -25,105 +28,142 @@ async def test_validation_nested_object_paths_valid(monkeypatch, authed_client):
'user': {
'required': True,
'type': 'object',
- 'nested_schema': {
- 'name': {'required': True, 'type': 'string', 'min': 2}
- }
+ 'nested_schema': {'name': {'required': True, 'type': 'string', 'min': 2}},
}
}
}
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'user': {'name': 'John'}})
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_validation_array_items_valid(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='varr1', uri='/ok2')
schema = {
'validation_schema': {
- 'tags': {'required': True, 'type': 'array', 'min': 1, 'array_items': {'type': 'string', 'min': 2, 'required': True}}
+ 'tags': {
+ 'required': True,
+ 'type': 'array',
+ 'min': 1,
+ 'array_items': {'type': 'string', 'min': 2, 'required': True},
+ }
}
}
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'tags': ['ab', 'cd']})
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_validation_array_items_invalid_type(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='varr2', uri='/bad1')
schema = {
'validation_schema': {
- 'tags': {'required': True, 'type': 'array', 'array_items': {'type': 'string', 'required': True}}
+ 'tags': {
+ 'required': True,
+ 'type': 'array',
+ 'array_items': {'type': 'string', 'required': True},
+ }
}
}
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'tags': [1, 2]})
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_validation_required_field_missing(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='vreq', uri='/bad2')
schema = {'validation_schema': {'profile.age': {'required': True, 'type': 'number'}}}
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'profile': {}})
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_validation_enum_restrictions(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='venum', uri='/bad3')
- schema = {'validation_schema': {'status': {'required': True, 'type': 'string', 'enum': ['NEW', 'OPEN']}}}
+ schema = {
+ 'validation_schema': {
+ 'status': {'required': True, 'type': 'string', 'enum': ['NEW', 'OPEN']}
+ }
+ }
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
bad = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'status': 'CLOSED'})
assert bad.status_code == 400
ok = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'status': 'OPEN'})
assert ok.status_code == 200
+
@pytest.mark.asyncio
async def test_validation_custom_validator_success(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='vcust1', uri='/ok3')
- from utils.validation_util import validation_util, ValidationError
+ from utils.validation_util import ValidationError, validation_util
+
def is_upper(value, vdef):
if not isinstance(value, str) or not value.isupper():
raise ValidationError('Not upper', 'code')
+
validation_util.register_custom_validator('isUpper', is_upper)
- schema = {'validation_schema': {'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper'}}}
+ schema = {
+ 'validation_schema': {
+ 'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper'}
+ }
+ }
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'code': 'ABC'})
assert r.status_code == 200
+
@pytest.mark.asyncio
async def test_validation_custom_validator_failure(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='vcust2', uri='/bad4')
- from utils.validation_util import validation_util, ValidationError
+ from utils.validation_util import ValidationError, validation_util
+
def is_upper(value, vdef):
if not isinstance(value, str) or not value.isupper():
raise ValidationError('Not upper', 'code')
+
validation_util.register_custom_validator('isUpper2', is_upper)
- schema = {'validation_schema': {'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper2'}}}
+ schema = {
+ 'validation_schema': {
+ 'code': {'required': True, 'type': 'string', 'custom_validator': 'isUpper2'}
+ }
+ }
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'code': 'Abc'})
assert r.status_code == 400
+
@pytest.mark.asyncio
async def test_validation_invalid_field_path_raises_schema_error(monkeypatch, authed_client):
api, ver, uri, eid = await _setup(authed_client, api='vbadpath', uri='/bad5')
schema = {'validation_schema': {'user..name': {'required': True, 'type': 'string'}}}
await _apply_schema(authed_client, eid, schema)
import services.gateway_service as gs
+
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
r = await authed_client.post(f'/api/rest/{api}/{ver}{uri}', json={'user': {'name': 'ok'}})
assert r.status_code == 400
-
diff --git a/backend-services/tests/test_vault_routes_permissions.py b/backend-services/tests/test_vault_routes_permissions.py
new file mode 100644
index 0000000..008cd99
--- /dev/null
+++ b/backend-services/tests/test_vault_routes_permissions.py
@@ -0,0 +1,36 @@
+import time
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_vault_routes_permissions(authed_client):
+ # Limited user
+ uname = f'vault_limited_{int(time.time())}'
+ pwd = 'VaultLimitStrongPass1!!'
+ cu = await authed_client.post(
+ '/platform/user',
+ json={
+ 'username': uname,
+ 'email': f'{uname}@example.com',
+ 'password': pwd,
+ 'role': 'user',
+ 'groups': ['ALL'],
+ 'ui_access': True,
+ },
+ )
+ assert cu.status_code in (200, 201)
+
+ from httpx import AsyncClient
+
+ from doorman import doorman
+
+ limited = AsyncClient(app=doorman, base_url='http://testserver')
+ r = await limited.post(
+ '/platform/authorization', json={'email': f'{uname}@example.com', 'password': pwd}
+ )
+ assert r.status_code == 200
+ # Vault operations require manage_security
+ sk = await limited.post('/platform/vault', json={'key_name': 'k', 'value': 'v'})
+ # Without manage_security and possibly without VAULT_KEY, expect 403 or 400/404/500 depending on env
+ assert sk.status_code in (403, 400, 404, 500)
diff --git a/backend-services/utils/analytics_aggregator.py b/backend-services/utils/analytics_aggregator.py
index 5e9b036..7d31883 100644
--- a/backend-services/utils/analytics_aggregator.py
+++ b/backend-services/utils/analytics_aggregator.py
@@ -6,144 +6,141 @@ for efficient historical queries without scanning all minute-level data.
"""
from __future__ import annotations
+
import time
from collections import defaultdict, deque
-from typing import Dict, List, Optional, Deque
-from dataclasses import dataclass
from models.analytics_models import (
- AggregationLevel,
AggregatedMetrics,
+ AggregationLevel,
+ EnhancedMinuteBucket,
PercentileMetrics,
- EnhancedMinuteBucket
)
class AnalyticsAggregator:
"""
Multi-level time-series aggregator.
-
+
Aggregates minute-level buckets into:
- 5-minute buckets (for 24-hour views)
- Hourly buckets (for 7-day views)
- Daily buckets (for 30-day+ views)
-
+
Retention policy:
- Minute-level: 24 hours
- 5-minute level: 7 days
- Hourly level: 30 days
- Daily level: 90 days
"""
-
+
def __init__(self):
- self.five_minute_buckets: Deque[AggregatedMetrics] = deque(maxlen=2016) # 7 days * 288 buckets/day
- self.hourly_buckets: Deque[AggregatedMetrics] = deque(maxlen=720) # 30 days * 24 hours
- self.daily_buckets: Deque[AggregatedMetrics] = deque(maxlen=90) # 90 days
-
+ self.five_minute_buckets: deque[AggregatedMetrics] = deque(
+ maxlen=2016
+ ) # 7 days * 288 buckets/day
+ self.hourly_buckets: deque[AggregatedMetrics] = deque(maxlen=720) # 30 days * 24 hours
+ self.daily_buckets: deque[AggregatedMetrics] = deque(maxlen=90) # 90 days
+
self._last_5min_aggregation = 0
self._last_hourly_aggregation = 0
self._last_daily_aggregation = 0
-
+
@staticmethod
def _floor_timestamp(ts: int, seconds: int) -> int:
"""Floor timestamp to nearest interval."""
return (ts // seconds) * seconds
-
- def aggregate_to_5minute(self, minute_buckets: List[EnhancedMinuteBucket]) -> None:
+
+ def aggregate_to_5minute(self, minute_buckets: list[EnhancedMinuteBucket]) -> None:
"""
Aggregate minute-level buckets into 5-minute buckets.
-
+
Should be called every 5 minutes with the last 5 minutes of data.
"""
if not minute_buckets:
return
-
+
# Group by 5-minute intervals
- five_min_groups: Dict[int, List[EnhancedMinuteBucket]] = defaultdict(list)
+ five_min_groups: dict[int, list[EnhancedMinuteBucket]] = defaultdict(list)
for bucket in minute_buckets:
five_min_start = self._floor_timestamp(bucket.start_ts, 300) # 300 seconds = 5 minutes
five_min_groups[five_min_start].append(bucket)
-
+
# Create aggregated buckets
for five_min_start, buckets in five_min_groups.items():
agg = self._aggregate_buckets(
buckets,
start_ts=five_min_start,
end_ts=five_min_start + 300,
- level=AggregationLevel.FIVE_MINUTE
+ level=AggregationLevel.FIVE_MINUTE,
)
self.five_minute_buckets.append(agg)
-
+
self._last_5min_aggregation = int(time.time())
-
- def aggregate_to_hourly(self, five_minute_buckets: Optional[List[AggregatedMetrics]] = None) -> None:
+
+ def aggregate_to_hourly(
+ self, five_minute_buckets: list[AggregatedMetrics] | None = None
+ ) -> None:
"""
Aggregate 5-minute buckets into hourly buckets.
-
+
Should be called every hour.
"""
# Use provided buckets or last 12 from deque (1 hour = 12 * 5-minute buckets)
if five_minute_buckets is None:
five_minute_buckets = list(self.five_minute_buckets)[-12:]
-
+
if not five_minute_buckets:
return
-
+
# Group by hour
- hourly_groups: Dict[int, List[AggregatedMetrics]] = defaultdict(list)
+ hourly_groups: dict[int, list[AggregatedMetrics]] = defaultdict(list)
for bucket in five_minute_buckets:
hour_start = self._floor_timestamp(bucket.start_ts, 3600) # 3600 seconds = 1 hour
hourly_groups[hour_start].append(bucket)
-
+
# Create aggregated buckets
for hour_start, buckets in hourly_groups.items():
agg = self._aggregate_aggregated_buckets(
- buckets,
- start_ts=hour_start,
- end_ts=hour_start + 3600,
- level=AggregationLevel.HOUR
+ buckets, start_ts=hour_start, end_ts=hour_start + 3600, level=AggregationLevel.HOUR
)
self.hourly_buckets.append(agg)
-
+
self._last_hourly_aggregation = int(time.time())
-
- def aggregate_to_daily(self, hourly_buckets: Optional[List[AggregatedMetrics]] = None) -> None:
+
+ def aggregate_to_daily(self, hourly_buckets: list[AggregatedMetrics] | None = None) -> None:
"""
Aggregate hourly buckets into daily buckets.
-
+
Should be called once per day.
"""
# Use provided buckets or last 24 from deque (1 day = 24 hourly buckets)
if hourly_buckets is None:
hourly_buckets = list(self.hourly_buckets)[-24:]
-
+
if not hourly_buckets:
return
-
+
# Group by day
- daily_groups: Dict[int, List[AggregatedMetrics]] = defaultdict(list)
+ daily_groups: dict[int, list[AggregatedMetrics]] = defaultdict(list)
for bucket in hourly_buckets:
day_start = self._floor_timestamp(bucket.start_ts, 86400) # 86400 seconds = 1 day
daily_groups[day_start].append(bucket)
-
+
# Create aggregated buckets
for day_start, buckets in daily_groups.items():
agg = self._aggregate_aggregated_buckets(
- buckets,
- start_ts=day_start,
- end_ts=day_start + 86400,
- level=AggregationLevel.DAY
+ buckets, start_ts=day_start, end_ts=day_start + 86400, level=AggregationLevel.DAY
)
self.daily_buckets.append(agg)
-
+
self._last_daily_aggregation = int(time.time())
-
+
def _aggregate_buckets(
self,
- buckets: List[EnhancedMinuteBucket],
+ buckets: list[EnhancedMinuteBucket],
start_ts: int,
end_ts: int,
- level: AggregationLevel
+ level: AggregationLevel,
) -> AggregatedMetrics:
"""Aggregate minute-level buckets into a single aggregated bucket."""
total_count = sum(b.count for b in buckets)
@@ -151,31 +148,31 @@ class AnalyticsAggregator:
total_ms = sum(b.total_ms for b in buckets)
total_bytes_in = sum(b.bytes_in for b in buckets)
total_bytes_out = sum(b.bytes_out for b in buckets)
-
+
# Merge status counts
- status_counts: Dict[int, int] = defaultdict(int)
+ status_counts: dict[int, int] = defaultdict(int)
for bucket in buckets:
for status, count in bucket.status_counts.items():
status_counts[status] += count
-
+
# Merge API counts
- api_counts: Dict[str, int] = defaultdict(int)
+ api_counts: dict[str, int] = defaultdict(int)
for bucket in buckets:
for api, count in bucket.api_counts.items():
api_counts[api] += count
-
+
# Collect all latencies for percentile calculation
- all_latencies: List[float] = []
+ all_latencies: list[float] = []
for bucket in buckets:
all_latencies.extend(list(bucket.latencies))
-
+
percentiles = PercentileMetrics.calculate(all_latencies) if all_latencies else None
-
+
# Count unique users across all buckets
unique_users = set()
for bucket in buckets:
unique_users.update(bucket.unique_users)
-
+
return AggregatedMetrics(
start_ts=start_ts,
end_ts=end_ts,
@@ -188,15 +185,11 @@ class AnalyticsAggregator:
unique_users=len(unique_users),
status_counts=dict(status_counts),
api_counts=dict(api_counts),
- percentiles=percentiles
+ percentiles=percentiles,
)
-
+
def _aggregate_aggregated_buckets(
- self,
- buckets: List[AggregatedMetrics],
- start_ts: int,
- end_ts: int,
- level: AggregationLevel
+ self, buckets: list[AggregatedMetrics], start_ts: int, end_ts: int, level: AggregationLevel
) -> AggregatedMetrics:
"""Aggregate already-aggregated buckets (5-min → hourly, hourly → daily)."""
total_count = sum(b.count for b in buckets)
@@ -204,26 +197,26 @@ class AnalyticsAggregator:
total_ms = sum(b.total_ms for b in buckets)
total_bytes_in = sum(b.bytes_in for b in buckets)
total_bytes_out = sum(b.bytes_out for b in buckets)
-
+
# Merge status counts
- status_counts: Dict[int, int] = defaultdict(int)
+ status_counts: dict[int, int] = defaultdict(int)
for bucket in buckets:
for status, count in bucket.status_counts.items():
status_counts[status] += count
-
+
# Merge API counts
- api_counts: Dict[str, int] = defaultdict(int)
+ api_counts: dict[str, int] = defaultdict(int)
for bucket in buckets:
for api, count in bucket.api_counts.items():
api_counts[api] += count
-
+
# For percentiles, we'll use weighted average (not perfect, but acceptable)
# Ideally, we'd re-calculate from raw latencies, but those aren't stored in aggregated buckets
weighted_percentiles = self._weighted_average_percentiles(buckets)
-
+
# Unique users: sum (may overcount, but acceptable for aggregated data)
unique_users = sum(b.unique_users for b in buckets)
-
+
return AggregatedMetrics(
start_ts=start_ts,
end_ts=end_ts,
@@ -236,28 +229,30 @@ class AnalyticsAggregator:
unique_users=unique_users,
status_counts=dict(status_counts),
api_counts=dict(api_counts),
- percentiles=weighted_percentiles
+ percentiles=weighted_percentiles,
)
-
- def _weighted_average_percentiles(self, buckets: List[AggregatedMetrics]) -> Optional[PercentileMetrics]:
+
+ def _weighted_average_percentiles(
+ self, buckets: list[AggregatedMetrics]
+ ) -> PercentileMetrics | None:
"""Calculate weighted average of percentiles from aggregated buckets."""
if not buckets:
return None
-
+
total_count = sum(b.count for b in buckets if b.count > 0)
if total_count == 0:
return None
-
+
# Weighted average of each percentile
p50_sum = sum(b.percentiles.p50 * b.count for b in buckets if b.percentiles and b.count > 0)
p75_sum = sum(b.percentiles.p75 * b.count for b in buckets if b.percentiles and b.count > 0)
p90_sum = sum(b.percentiles.p90 * b.count for b in buckets if b.percentiles and b.count > 0)
p95_sum = sum(b.percentiles.p95 * b.count for b in buckets if b.percentiles and b.count > 0)
p99_sum = sum(b.percentiles.p99 * b.count for b in buckets if b.percentiles and b.count > 0)
-
+
min_val = min(b.percentiles.min for b in buckets if b.percentiles)
max_val = max(b.percentiles.max for b in buckets if b.percentiles)
-
+
return PercentileMetrics(
p50=p50_sum / total_count,
p75=p75_sum / total_count,
@@ -265,25 +260,22 @@ class AnalyticsAggregator:
p95=p95_sum / total_count,
p99=p99_sum / total_count,
min=min_val,
- max=max_val
+ max=max_val,
)
-
+
def get_buckets_for_range(
- self,
- start_ts: int,
- end_ts: int,
- preferred_level: Optional[AggregationLevel] = None
- ) -> List[AggregatedMetrics]:
+ self, start_ts: int, end_ts: int, preferred_level: AggregationLevel | None = None
+ ) -> list[AggregatedMetrics]:
"""
Get the most appropriate aggregation level for a time range.
-
+
Automatically selects:
- 5-minute buckets for ranges < 24 hours
- Hourly buckets for ranges < 7 days
- Daily buckets for ranges >= 7 days
"""
range_seconds = end_ts - start_ts
-
+
# Determine best aggregation level
if preferred_level:
level = preferred_level
@@ -293,7 +285,7 @@ class AnalyticsAggregator:
level = AggregationLevel.HOUR
else:
level = AggregationLevel.DAY
-
+
# Select appropriate bucket collection
if level == AggregationLevel.FIVE_MINUTE:
buckets = list(self.five_minute_buckets)
@@ -301,21 +293,21 @@ class AnalyticsAggregator:
buckets = list(self.hourly_buckets)
else:
buckets = list(self.daily_buckets)
-
+
# Filter to time range
return [b for b in buckets if b.start_ts >= start_ts and b.end_ts <= end_ts]
-
- def should_aggregate(self) -> Dict[str, bool]:
+
+ def should_aggregate(self) -> dict[str, bool]:
"""Check if any aggregation jobs should run."""
now = int(time.time())
-
+
return {
'5minute': (now - self._last_5min_aggregation) >= 300, # Every 5 minutes
'hourly': (now - self._last_hourly_aggregation) >= 3600, # Every hour
- 'daily': (now - self._last_daily_aggregation) >= 86400 # Every day
+ 'daily': (now - self._last_daily_aggregation) >= 86400, # Every day
}
-
- def to_dict(self) -> Dict:
+
+ def to_dict(self) -> dict:
"""Serialize aggregator state for persistence."""
return {
'five_minute_buckets': [b.to_dict() for b in self.five_minute_buckets],
@@ -323,10 +315,10 @@ class AnalyticsAggregator:
'daily_buckets': [b.to_dict() for b in self.daily_buckets],
'last_5min_aggregation': self._last_5min_aggregation,
'last_hourly_aggregation': self._last_hourly_aggregation,
- 'last_daily_aggregation': self._last_daily_aggregation
+ 'last_daily_aggregation': self._last_daily_aggregation,
}
-
- def load_dict(self, data: Dict) -> None:
+
+ def load_dict(self, data: dict) -> None:
"""Load aggregator state from persistence."""
# Note: This is a simplified version. Full implementation would
# reconstruct AggregatedMetrics objects from dictionaries
diff --git a/backend-services/utils/analytics_scheduler.py b/backend-services/utils/analytics_scheduler.py
index d5d89f9..2a4d7d1 100644
--- a/backend-services/utils/analytics_scheduler.py
+++ b/backend-services/utils/analytics_scheduler.py
@@ -12,10 +12,9 @@ Runs periodic tasks to:
import asyncio
import logging
import time
-from typing import Optional
-from utils.enhanced_metrics_util import enhanced_metrics_store
from utils.analytics_aggregator import analytics_aggregator
+from utils.enhanced_metrics_util import enhanced_metrics_store
logger = logging.getLogger('doorman.analytics')
@@ -23,7 +22,7 @@ logger = logging.getLogger('doorman.analytics')
class AnalyticsScheduler:
"""
Background task scheduler for analytics aggregation.
-
+
Runs aggregation jobs at appropriate intervals:
- 5-minute aggregation: Every 5 minutes
- Hourly aggregation: Every hour
@@ -31,31 +30,31 @@ class AnalyticsScheduler:
- Persistence: Every 5 minutes
- Cleanup: Once per day
"""
-
+
def __init__(self):
self.running = False
- self._task: Optional[asyncio.Task] = None
+ self._task: asyncio.Task | None = None
self._last_5min = 0
self._last_hourly = 0
self._last_daily = 0
self._last_persist = 0
self._last_cleanup = 0
-
+
async def start(self):
"""Start the scheduler."""
if self.running:
- logger.warning("Analytics scheduler already running")
+ logger.warning('Analytics scheduler already running')
return
-
+
self.running = True
self._task = asyncio.create_task(self._run_loop())
- logger.info("Analytics scheduler started")
-
+ logger.info('Analytics scheduler started')
+
async def stop(self):
"""Stop the scheduler."""
if not self.running:
return
-
+
self.running = False
if self._task:
self._task.cancel()
@@ -63,9 +62,9 @@ class AnalyticsScheduler:
await self._task
except asyncio.CancelledError:
pass
-
- logger.info("Analytics scheduler stopped")
-
+
+ logger.info('Analytics scheduler stopped')
+
async def _run_loop(self):
"""Main scheduler loop."""
while self.running:
@@ -76,136 +75,139 @@ class AnalyticsScheduler:
except asyncio.CancelledError:
break
except Exception as e:
- logger.error(f"Error in analytics scheduler: {str(e)}", exc_info=True)
+ logger.error(f'Error in analytics scheduler: {str(e)}', exc_info=True)
await asyncio.sleep(60)
-
+
async def _check_and_run_jobs(self):
"""Check if any jobs should run and execute them."""
now = int(time.time())
-
+
# 5-minute aggregation (every 5 minutes)
if now - self._last_5min >= 300:
await self._run_5minute_aggregation()
self._last_5min = now
-
+
# Hourly aggregation (every hour)
if now - self._last_hourly >= 3600:
await self._run_hourly_aggregation()
self._last_hourly = now
-
+
# Daily aggregation (once per day)
if now - self._last_daily >= 86400:
await self._run_daily_aggregation()
self._last_daily = now
-
+
# Persist metrics (every 5 minutes)
if now - self._last_persist >= 300:
await self._persist_metrics()
self._last_persist = now
-
+
# Cleanup old data (once per day)
if now - self._last_cleanup >= 86400:
await self._cleanup_old_data()
self._last_cleanup = now
-
+
async def _run_5minute_aggregation(self):
"""Aggregate last 5 minutes of data into 5-minute buckets."""
try:
- logger.info("Running 5-minute aggregation")
+ logger.info('Running 5-minute aggregation')
start_time = time.time()
-
+
# Get last 5 minutes of buckets
minute_buckets = list(enhanced_metrics_store._buckets)[-5:]
-
+
if minute_buckets:
analytics_aggregator.aggregate_to_5minute(minute_buckets)
-
+
duration_ms = (time.time() - start_time) * 1000
- logger.info(f"5-minute aggregation completed in {duration_ms:.2f}ms")
+ logger.info(f'5-minute aggregation completed in {duration_ms:.2f}ms')
else:
- logger.debug("No buckets to aggregate (5-minute)")
-
+ logger.debug('No buckets to aggregate (5-minute)')
+
except Exception as e:
- logger.error(f"Failed to run 5-minute aggregation: {str(e)}", exc_info=True)
-
+ logger.error(f'Failed to run 5-minute aggregation: {str(e)}', exc_info=True)
+
async def _run_hourly_aggregation(self):
"""Aggregate last hour of 5-minute data into hourly buckets."""
try:
- logger.info("Running hourly aggregation")
+ logger.info('Running hourly aggregation')
start_time = time.time()
-
+
analytics_aggregator.aggregate_to_hourly()
-
+
duration_ms = (time.time() - start_time) * 1000
- logger.info(f"Hourly aggregation completed in {duration_ms:.2f}ms")
-
+ logger.info(f'Hourly aggregation completed in {duration_ms:.2f}ms')
+
except Exception as e:
- logger.error(f"Failed to run hourly aggregation: {str(e)}", exc_info=True)
-
+ logger.error(f'Failed to run hourly aggregation: {str(e)}', exc_info=True)
+
async def _run_daily_aggregation(self):
"""Aggregate last day of hourly data into daily buckets."""
try:
- logger.info("Running daily aggregation")
+ logger.info('Running daily aggregation')
start_time = time.time()
-
+
analytics_aggregator.aggregate_to_daily()
-
+
duration_ms = (time.time() - start_time) * 1000
- logger.info(f"Daily aggregation completed in {duration_ms:.2f}ms")
-
+ logger.info(f'Daily aggregation completed in {duration_ms:.2f}ms')
+
except Exception as e:
- logger.error(f"Failed to run daily aggregation: {str(e)}", exc_info=True)
-
+ logger.error(f'Failed to run daily aggregation: {str(e)}', exc_info=True)
+
async def _persist_metrics(self):
"""Save metrics to disk for persistence."""
try:
- logger.debug("Persisting metrics to disk")
+ logger.debug('Persisting metrics to disk')
start_time = time.time()
-
+
# Save minute-level metrics
enhanced_metrics_store.save_to_file('platform-logs/enhanced_metrics.json')
-
+
# Save aggregated metrics
import json
import os
-
+
aggregated_data = analytics_aggregator.to_dict()
path = 'platform-logs/aggregated_metrics.json'
-
+
os.makedirs(os.path.dirname(path), exist_ok=True)
tmp = path + '.tmp'
with open(tmp, 'w', encoding='utf-8') as f:
json.dump(aggregated_data, f)
os.replace(tmp, path)
-
+
duration_ms = (time.time() - start_time) * 1000
- logger.debug(f"Metrics persisted in {duration_ms:.2f}ms")
-
+ logger.debug(f'Metrics persisted in {duration_ms:.2f}ms')
+
except Exception as e:
- logger.error(f"Failed to persist metrics: {str(e)}", exc_info=True)
-
+ logger.error(f'Failed to persist metrics: {str(e)}', exc_info=True)
+
async def _cleanup_old_data(self):
"""Remove data beyond retention policy."""
try:
- logger.info("Running data cleanup")
+ logger.info('Running data cleanup')
start_time = time.time()
-
+
now = int(time.time())
-
+
# Minute-level: Keep only last 24 hours
cutoff_minute = now - (24 * 3600)
- while enhanced_metrics_store._buckets and enhanced_metrics_store._buckets[0].start_ts < cutoff_minute:
+ while (
+ enhanced_metrics_store._buckets
+ and enhanced_metrics_store._buckets[0].start_ts < cutoff_minute
+ ):
enhanced_metrics_store._buckets.popleft()
-
+
# 5-minute: Keep only last 7 days (handled by deque maxlen)
# Hourly: Keep only last 30 days (handled by deque maxlen)
# Daily: Keep only last 90 days (handled by deque maxlen)
-
+
duration_ms = (time.time() - start_time) * 1000
- logger.info(f"Data cleanup completed in {duration_ms:.2f}ms")
-
+ logger.info(f'Data cleanup completed in {duration_ms:.2f}ms')
+
except Exception as e:
- logger.error(f"Failed to cleanup old data: {str(e)}", exc_info=True)
+ logger.error(f'Failed to cleanup old data: {str(e)}', exc_info=True)
# Global scheduler instance
diff --git a/backend-services/utils/api_resolution_util.py b/backend-services/utils/api_resolution_util.py
index 562852d..abe31be 100644
--- a/backend-services/utils/api_resolution_util.py
+++ b/backend-services/utils/api_resolution_util.py
@@ -5,12 +5,14 @@ Reduces duplicate code for GraphQL/gRPC API name/version parsing.
"""
import re
-from typing import Tuple, Optional
-from fastapi import Request, HTTPException
-from utils.doorman_cache_util import doorman_cache
-from utils import api_util
-def parse_graphql_grpc_path(path: str, request: Request) -> Tuple[str, str, str]:
+from fastapi import HTTPException, Request
+
+from utils import api_util
+from utils.doorman_cache_util import doorman_cache
+
+
+def parse_graphql_grpc_path(path: str, request: Request) -> tuple[str, str, str]:
"""Parse GraphQL/gRPC path to extract API name and version.
Args:
@@ -38,7 +40,8 @@ def parse_graphql_grpc_path(path: str, request: Request) -> Tuple[str, str, str]
return api_name, api_version, api_path
-async def resolve_api(api_name: str, api_version: str) -> Optional[dict]:
+
+async def resolve_api(api_name: str, api_version: str) -> dict | None:
"""Resolve API from cache or database.
Args:
@@ -52,7 +55,10 @@ async def resolve_api(api_name: str, api_version: str) -> Optional[dict]:
api_key = doorman_cache.get_cache('api_id_cache', api_path)
return await api_util.get_api(api_key, api_path)
-async def resolve_api_from_request(path: str, request: Request) -> Tuple[Optional[dict], str, str, str]:
+
+async def resolve_api_from_request(
+ path: str, request: Request
+) -> tuple[dict | None, str, str, str]:
"""Parse path, extract API name/version, and resolve API in one call.
Args:
diff --git a/backend-services/utils/api_util.py b/backend-services/utils/api_util.py
index aab2c7a..5f55f25 100644
--- a/backend-services/utils/api_util.py
+++ b/backend-services/utils/api_util.py
@@ -1,10 +1,9 @@
-from typing import Optional, Dict
-
-from utils.doorman_cache_util import doorman_cache
+from utils.async_db import db_find_list, db_find_one
from utils.database_async import api_collection, endpoint_collection
-from utils.async_db import db_find_one, db_find_list
+from utils.doorman_cache_util import doorman_cache
-async def get_api(api_key: Optional[str], api_name_version: str) -> Optional[Dict]:
+
+async def get_api(api_key: str | None, api_name_version: str) -> dict | None:
"""Get API document by key or name/version.
Args:
@@ -25,7 +24,8 @@ async def get_api(api_key: Optional[str], api_name_version: str) -> Optional[Dic
doorman_cache.set_cache('api_id_cache', api_name_version, api_key)
return api
-async def get_api_endpoints(api_id: str) -> Optional[list]:
+
+async def get_api_endpoints(api_id: str) -> list | None:
"""Get list of endpoints for an API.
Args:
@@ -40,13 +40,14 @@ async def get_api_endpoints(api_id: str) -> Optional[list]:
if not endpoints_list:
return None
endpoints = [
- f"{endpoint.get('endpoint_method')}{endpoint.get('endpoint_uri')}"
+ f'{endpoint.get("endpoint_method")}{endpoint.get("endpoint_uri")}'
for endpoint in endpoints_list
]
doorman_cache.set_cache('api_endpoint_cache', api_id, endpoints)
return endpoints
-async def get_endpoint(api: Dict, method: str, endpoint_uri: str) -> Optional[Dict]:
+
+async def get_endpoint(api: dict, method: str, endpoint_uri: str) -> dict | None:
"""Return the endpoint document for a given API, method, and uri.
Uses the same cache key pattern as EndpointService to avoid duplicate queries.
@@ -57,12 +58,15 @@ async def get_endpoint(api: Dict, method: str, endpoint_uri: str) -> Optional[Di
endpoint = doorman_cache.get_cache('endpoint_cache', cache_key)
if endpoint:
return endpoint
- doc = await db_find_one(endpoint_collection, {
- 'api_name': api_name,
- 'api_version': api_version,
- 'endpoint_uri': endpoint_uri,
- 'endpoint_method': method
- })
+ doc = await db_find_one(
+ endpoint_collection,
+ {
+ 'api_name': api_name,
+ 'api_version': api_version,
+ 'endpoint_uri': endpoint_uri,
+ 'endpoint_method': method,
+ },
+ )
if not doc:
return None
doc.pop('_id', None)
diff --git a/backend-services/utils/async_db.py b/backend-services/utils/async_db.py
index b4f40f0..835cb3a 100644
--- a/backend-services/utils/async_db.py
+++ b/backend-services/utils/async_db.py
@@ -9,34 +9,39 @@ from __future__ import annotations
import asyncio
import inspect
-from typing import Any, Dict, List, Optional
+from typing import Any
-async def db_find_one(collection: Any, query: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- fn = getattr(collection, 'find_one')
+
+async def db_find_one(collection: Any, query: dict[str, Any]) -> dict[str, Any] | None:
+ fn = collection.find_one
if inspect.iscoroutinefunction(fn):
return await fn(query)
return await asyncio.to_thread(fn, query)
-async def db_insert_one(collection: Any, doc: Dict[str, Any]) -> Any:
- fn = getattr(collection, 'insert_one')
+
+async def db_insert_one(collection: Any, doc: dict[str, Any]) -> Any:
+ fn = collection.insert_one
if inspect.iscoroutinefunction(fn):
return await fn(doc)
return await asyncio.to_thread(fn, doc)
-async def db_update_one(collection: Any, query: Dict[str, Any], update: Dict[str, Any]) -> Any:
- fn = getattr(collection, 'update_one')
+
+async def db_update_one(collection: Any, query: dict[str, Any], update: dict[str, Any]) -> Any:
+ fn = collection.update_one
if inspect.iscoroutinefunction(fn):
return await fn(query, update)
return await asyncio.to_thread(fn, query, update)
-async def db_delete_one(collection: Any, query: Dict[str, Any]) -> Any:
- fn = getattr(collection, 'delete_one')
+
+async def db_delete_one(collection: Any, query: dict[str, Any]) -> Any:
+ fn = collection.delete_one
if inspect.iscoroutinefunction(fn):
return await fn(query)
return await asyncio.to_thread(fn, query)
-async def db_find_list(collection: Any, query: Dict[str, Any]) -> List[Dict[str, Any]]:
- find = getattr(collection, 'find')
+
+async def db_find_list(collection: Any, query: dict[str, Any]) -> list[dict[str, Any]]:
+ find = collection.find
cursor = find(query)
to_list = getattr(cursor, 'to_list', None)
if callable(to_list):
@@ -44,4 +49,3 @@ async def db_find_list(collection: Any, query: Dict[str, Any]) -> List[Dict[str,
return await to_list(length=None)
return await asyncio.to_thread(to_list, None)
return await asyncio.to_thread(lambda: list(cursor))
-
diff --git a/backend-services/utils/audit_util.py b/backend-services/utils/audit_util.py
index be1155b..b721bf8 100644
--- a/backend-services/utils/audit_util.py
+++ b/backend-services/utils/audit_util.py
@@ -1,38 +1,72 @@
-import logging
import json
+import logging
import re
_logger = logging.getLogger('doorman.audit')
SENSITIVE_KEYS = {
- 'password', 'passwd', 'pwd',
- 'token', 'access_token', 'refresh_token', 'bearer_token', 'auth_token',
- 'authorization', 'auth', 'bearer',
-
- 'api_key', 'apikey', 'api-key',
- 'user_api_key', 'user-api-key',
- 'secret', 'client_secret', 'client-secret', 'api_secret', 'api-secret',
- 'private_key', 'private-key', 'privatekey',
-
- 'session', 'session_id', 'session-id', 'sessionid',
- 'csrf_token', 'csrf-token', 'csrftoken',
- 'x-csrf-token', 'xsrf_token', 'xsrf-token',
-
- 'cookie', 'set-cookie', 'set_cookie',
- 'access_token_cookie', 'refresh_token_cookie',
-
- 'connection_string', 'connection-string', 'connectionstring',
- 'database_password', 'db_password', 'db_passwd',
- 'mongo_password', 'redis_password',
-
- 'id_token', 'id-token',
- 'jwt', 'jwt_token',
- 'oauth_token', 'oauth-token',
- 'code_verifier', 'code-verifier',
-
- 'encryption_key', 'encryption-key',
- 'signing_key', 'signing-key',
- 'key', 'private', 'secret_key',
+ 'password',
+ 'passwd',
+ 'pwd',
+ 'token',
+ 'access_token',
+ 'refresh_token',
+ 'bearer_token',
+ 'auth_token',
+ 'authorization',
+ 'auth',
+ 'bearer',
+ 'api_key',
+ 'apikey',
+ 'api-key',
+ 'user_api_key',
+ 'user-api-key',
+ 'secret',
+ 'client_secret',
+ 'client-secret',
+ 'api_secret',
+ 'api-secret',
+ 'private_key',
+ 'private-key',
+ 'privatekey',
+ 'session',
+ 'session_id',
+ 'session-id',
+ 'sessionid',
+ 'csrf_token',
+ 'csrf-token',
+ 'csrftoken',
+ 'x-csrf-token',
+ 'xsrf_token',
+ 'xsrf-token',
+ 'cookie',
+ 'set-cookie',
+ 'set_cookie',
+ 'access_token_cookie',
+ 'refresh_token_cookie',
+ 'connection_string',
+ 'connection-string',
+ 'connectionstring',
+ 'database_password',
+ 'db_password',
+ 'db_passwd',
+ 'mongo_password',
+ 'redis_password',
+ 'id_token',
+ 'id-token',
+ 'jwt',
+ 'jwt_token',
+ 'oauth_token',
+ 'oauth-token',
+ 'code_verifier',
+ 'code-verifier',
+ 'encryption_key',
+ 'encryption-key',
+ 'signing_key',
+ 'signing-key',
+ 'key',
+ 'private',
+ 'secret_key',
}
SENSITIVE_VALUE_PATTERNS = [
@@ -44,14 +78,18 @@ SENSITIVE_VALUE_PATTERNS = [
re.compile(r'^-----BEGIN[A-Z\s]+PRIVATE KEY-----', re.DOTALL),
]
+
def _is_sensitive_key(key: str) -> bool:
"""Check if a key name indicates sensitive data."""
try:
lk = str(key).lower().replace('-', '_')
- return lk in SENSITIVE_KEYS or any(s in lk for s in ['password', 'secret', 'token', 'key', 'auth'])
+ return lk in SENSITIVE_KEYS or any(
+ s in lk for s in ['password', 'secret', 'token', 'key', 'auth']
+ )
except Exception:
return False
+
def _is_sensitive_value(value) -> bool:
"""Check if a value looks like sensitive data (even if key isn't obviously sensitive)."""
try:
@@ -61,6 +99,7 @@ def _is_sensitive_value(value) -> bool:
except Exception:
return False
+
def _sanitize(obj):
"""Recursively sanitize objects to redact sensitive data.
@@ -88,7 +127,10 @@ def _sanitize(obj):
except Exception:
return None
-def audit(request=None, actor=None, action=None, target=None, status=None, details=None, request_id=None):
+
+def audit(
+ request=None, actor=None, action=None, target=None, status=None, details=None, request_id=None
+):
event = {
'actor': actor,
'action': action,
@@ -104,6 +146,4 @@ def audit(request=None, actor=None, action=None, target=None, status=None, detai
event['request_id'] = request_id
_logger.info(json.dumps(event, separators=(',', ':')))
except Exception:
-
pass
-
diff --git a/backend-services/utils/auth_blacklist.py b/backend-services/utils/auth_blacklist.py
index 2e309ab..82c1969 100644
--- a/backend-services/utils/auth_blacklist.py
+++ b/backend-services/utils/auth_blacklist.py
@@ -44,11 +44,10 @@ see limit_throttle_util.py which uses the async Redis client (app.state.redis).
- doorman.py app_lifespan() for production Redis requirement enforcement
"""
-from datetime import datetime, timedelta
import heapq
import os
-from typing import Optional
import time
+from datetime import datetime, timedelta
try:
from utils.database import database, revocations_collection
@@ -67,6 +66,7 @@ revoked_all_users = set()
_redis_client = None
_redis_enabled = False
+
def _init_redis_if_possible():
global _redis_client, _redis_enabled
if _redis_client is not None:
@@ -84,7 +84,9 @@ def _init_redis_if_possible():
host = os.getenv('REDIS_HOST', 'localhost')
port = int(os.getenv('REDIS_PORT', 6379))
db = int(os.getenv('REDIS_DB', 0))
- pool = redis.ConnectionPool(host=host, port=port, db=db, decode_responses=True, max_connections=100)
+ pool = redis.ConnectionPool(
+ host=host, port=port, db=db, decode_responses=True, max_connections=100
+ )
_redis_client = redis.StrictRedis(connection_pool=pool)
try:
_redis_client.ping()
@@ -96,23 +98,36 @@ def _init_redis_if_possible():
_redis_client = None
_redis_enabled = False
+
def _revoked_jti_key(username: str, jti: str) -> str:
return f'jwt:revoked:{username}:{jti}'
+
def _revoke_all_key(username: str) -> str:
return f'jwt:revoke_all:{username}'
+
def revoke_all_for_user(username: str):
"""Mark all tokens for a user as revoked (durable if Redis is enabled)."""
_init_redis_if_possible()
try:
- if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
+ if (
+ database is not None
+ and getattr(database, 'memory_only', False)
+ and revocations_collection is not None
+ ):
try:
- existing = revocations_collection.find_one({'type': 'revoke_all', 'username': username})
+ existing = revocations_collection.find_one(
+ {'type': 'revoke_all', 'username': username}
+ )
if existing:
- revocations_collection.update_one({'_id': existing.get('_id')}, {'$set': {'revoke_all': True}})
+ revocations_collection.update_one(
+ {'_id': existing.get('_id')}, {'$set': {'revoke_all': True}}
+ )
else:
- revocations_collection.insert_one({'type': 'revoke_all', 'username': username, 'revoke_all': True})
+ revocations_collection.insert_one(
+ {'type': 'revoke_all', 'username': username, 'revoke_all': True}
+ )
except Exception:
revoked_all_users.add(username)
return
@@ -123,11 +138,16 @@ def revoke_all_for_user(username: str):
except Exception:
revoked_all_users.add(username)
+
def unrevoke_all_for_user(username: str):
"""Clear 'revoke all' for a user (durable if Redis is enabled)."""
_init_redis_if_possible()
try:
- if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
+ if (
+ database is not None
+ and getattr(database, 'memory_only', False)
+ and revocations_collection is not None
+ ):
try:
revocations_collection.delete_one({'type': 'revoke_all', 'username': username})
except Exception:
@@ -140,11 +160,16 @@ def unrevoke_all_for_user(username: str):
except Exception:
revoked_all_users.discard(username)
+
def is_user_revoked(username: str) -> bool:
"""Return True if user is under 'revoke all' (durable check if Redis enabled)."""
_init_redis_if_possible()
try:
- if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
+ if (
+ database is not None
+ and getattr(database, 'memory_only', False)
+ and revocations_collection is not None
+ ):
try:
doc = revocations_collection.find_one({'type': 'revoke_all', 'username': username})
return bool(doc and doc.get('revoke_all'))
@@ -156,6 +181,7 @@ def is_user_revoked(username: str) -> bool:
except Exception:
return username in revoked_all_users
+
class TimedHeap:
def __init__(self, purge_after=timedelta(hours=1)):
self.heap = []
@@ -182,7 +208,8 @@ class TimedHeap:
return self.heap[0][1]
return None
-def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None):
+
+def add_revoked_jti(username: str, jti: str, ttl_seconds: int | None = None):
"""Add a specific JTI to the revocation list.
- If Redis is enabled, store key with TTL so it auto-expires.
@@ -192,14 +219,26 @@ def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None):
return
_init_redis_if_possible()
try:
- if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
+ if (
+ database is not None
+ and getattr(database, 'memory_only', False)
+ and revocations_collection is not None
+ ):
try:
- exp = int(time.time()) + (max(1, int(ttl_seconds)) if ttl_seconds is not None else 3600)
- existing = revocations_collection.find_one({'type': 'jti', 'username': username, 'jti': jti})
+ exp = int(time.time()) + (
+ max(1, int(ttl_seconds)) if ttl_seconds is not None else 3600
+ )
+ existing = revocations_collection.find_one(
+ {'type': 'jti', 'username': username, 'jti': jti}
+ )
if existing:
- revocations_collection.update_one({'_id': existing.get('_id')}, {'$set': {'expires_at': exp}})
+ revocations_collection.update_one(
+ {'_id': existing.get('_id')}, {'$set': {'expires_at': exp}}
+ )
else:
- revocations_collection.insert_one({'type': 'jti', 'username': username, 'jti': jti, 'expires_at': exp})
+ revocations_collection.insert_one(
+ {'type': 'jti', 'username': username, 'jti': jti, 'expires_at': exp}
+ )
return
except Exception:
pass
@@ -215,15 +254,22 @@ def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None):
jwt_blacklist[username] = th
th.push(jti)
+
def is_jti_revoked(username: str, jti: str) -> bool:
"""Check whether a specific JTI is revoked (durable if Redis enabled)."""
if not username or not jti:
return False
_init_redis_if_possible()
try:
- if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
+ if (
+ database is not None
+ and getattr(database, 'memory_only', False)
+ and revocations_collection is not None
+ ):
try:
- doc = revocations_collection.find_one({'type': 'jti', 'username': username, 'jti': jti})
+ doc = revocations_collection.find_one(
+ {'type': 'jti', 'username': username, 'jti': jti}
+ )
if not doc:
pass
else:
@@ -248,13 +294,18 @@ def is_jti_revoked(username: str, jti: str) -> bool:
return True
return False
+
async def purge_expired_tokens():
"""No-op when Redis-backed; purge DB/in-memory when memory-only."""
_init_redis_if_possible()
if _redis_enabled:
return
try:
- if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
+ if (
+ database is not None
+ and getattr(database, 'memory_only', False)
+ and revocations_collection is not None
+ ):
now = int(time.time())
to_delete = []
for d in revocations_collection.find({'type': 'jti'}):
diff --git a/backend-services/utils/auth_util.py b/backend-services/utils/auth_util.py
index 0bfc06a..8fd8ac6 100644
--- a/backend-services/utils/auth_util.py
+++ b/backend-services/utils/auth_util.py
@@ -7,32 +7,32 @@ See https://github.com/pypeople-dev/doorman for more information
from datetime import datetime, timedelta
try:
-
from datetime import UTC
except Exception:
- from datetime import timezone as _timezone
- UTC = _timezone.utc
+    from datetime import timezone as _timezone; UTC = _timezone.utc
+import asyncio
+import logging
import os
import uuid
+
from fastapi import HTTPException, Request
-from jose import jwt, JWTError
-import asyncio
+from jose import JWTError, jwt
-from utils.auth_blacklist import is_user_revoked, is_jti_revoked
-from utils.database import user_collection, role_collection
+from utils.auth_blacklist import is_jti_revoked, is_user_revoked
+from utils.database import role_collection, user_collection
from utils.doorman_cache_util import doorman_cache
-import logging
-
logger = logging.getLogger('doorman.gateway')
SECRET_KEY = os.getenv('JWT_SECRET_KEY')
ALGORITHM = 'HS256'
+
def is_jwt_configured() -> bool:
"""Return True if a JWT secret key is configured."""
return bool(os.getenv('JWT_SECRET_KEY'))
+
def _read_int_env(name: str, default: int) -> int:
try:
raw = os.getenv(name)
@@ -47,22 +47,39 @@ def _read_int_env(name: str, default: int) -> int:
logger.warning(f'Invalid value for {name}; using default {default}')
return default
+
def _normalize_unit(unit: str) -> str:
u = (unit or '').strip().lower()
mapping = {
- 's': 'seconds', 'sec': 'seconds', 'second': 'seconds', 'seconds': 'seconds',
- 'm': 'minutes', 'min': 'minutes', 'minute': 'minutes', 'minutes': 'minutes',
- 'h': 'hours', 'hr': 'hours', 'hour': 'hours', 'hours': 'hours',
- 'd': 'days', 'day': 'days', 'days': 'days',
- 'w': 'weeks', 'wk': 'weeks', 'week': 'weeks', 'weeks': 'weeks',
+ 's': 'seconds',
+ 'sec': 'seconds',
+ 'second': 'seconds',
+ 'seconds': 'seconds',
+ 'm': 'minutes',
+ 'min': 'minutes',
+ 'minute': 'minutes',
+ 'minutes': 'minutes',
+ 'h': 'hours',
+ 'hr': 'hours',
+ 'hour': 'hours',
+ 'hours': 'hours',
+ 'd': 'days',
+ 'day': 'days',
+ 'days': 'days',
+ 'w': 'weeks',
+ 'wk': 'weeks',
+ 'week': 'weeks',
+ 'weeks': 'weeks',
}
return mapping.get(u, 'minutes')
-def _expiry_from_env(value_key: str, unit_key: str, default_value: int, default_unit: str) -> timedelta:
+
+def _expiry_from_env(
+ value_key: str, unit_key: str, default_value: int, default_unit: str
+) -> timedelta:
value = _read_int_env(value_key, default_value)
unit = _normalize_unit(os.getenv(unit_key, default_unit))
try:
-
return timedelta(**{unit: value})
except Exception:
logger.warning(
@@ -70,6 +87,7 @@ def _expiry_from_env(value_key: str, unit_key: str, default_value: int, default_
)
return timedelta(**{_normalize_unit(default_unit): default_value})
+
async def validate_csrf_double_submit(header_token: str, cookie_token: str) -> bool:
try:
if not header_token or not cookie_token:
@@ -78,6 +96,7 @@ async def validate_csrf_double_submit(header_token: str, cookie_token: str) -> b
except Exception:
return False
+
async def auth_required(request: Request) -> dict:
"""Validate JWT token and CSRF for HTTPS
@@ -96,10 +115,7 @@ async def auth_required(request: Request) -> dict:
raise HTTPException(status_code=401, detail='Invalid CSRF token')
try:
payload = jwt.decode(
- token,
- SECRET_KEY,
- algorithms=[ALGORITHM],
- options={'verify_signature': True}
+ token, SECRET_KEY, algorithms=[ALGORITHM], options={'verify_signature': True}
)
username = payload.get('sub')
jti = payload.get('jti')
@@ -112,8 +128,10 @@ async def auth_required(request: Request) -> dict:
user = await asyncio.to_thread(user_collection.find_one, {'username': username})
if not user:
raise HTTPException(status_code=404, detail='User not found')
- if user.get('_id'): del user['_id']
- if user.get('password'): del user['password']
+ if user.get('_id'):
+ del user['_id']
+ if user.get('password'):
+ del user['password']
doorman_cache.set_cache('user_cache', username, user)
if not user:
raise HTTPException(status_code=404, detail='User not found')
@@ -127,6 +145,7 @@ async def auth_required(request: Request) -> dict:
logger.error(f'Unexpected error in auth_required: {str(e)}')
raise HTTPException(status_code=401, detail='Unauthorized')
+
def create_access_token(data: dict, refresh: bool = False) -> str:
"""Create a JWT access token with user permissions.
@@ -153,8 +172,10 @@ def create_access_token(data: dict, refresh: bool = False) -> str:
if not user:
user = user_collection.find_one({'username': username})
if user:
- if user.get('_id'): del user['_id']
- if user.get('password'): del user['password']
+ if user.get('_id'):
+ del user['_id']
+ if user.get('password'):
+ del user['password']
doorman_cache.set_cache('user_cache', username, user)
if not user:
@@ -168,7 +189,8 @@ def create_access_token(data: dict, refresh: bool = False) -> str:
if not role:
role = role_collection.find_one({'role_name': role_name})
if role:
- if role.get('_id'): del role['_id']
+ if role.get('_id'):
+ del role['_id']
doorman_cache.set_cache('role_cache', role_name, role)
accesses = {
@@ -186,11 +208,9 @@ def create_access_token(data: dict, refresh: bool = False) -> str:
'view_logs': role.get('view_logs', False) if role else False,
}
- to_encode.update({
- 'exp': datetime.now(UTC) + expire,
- 'jti': str(uuid.uuid4()),
- 'accesses': accesses
- })
+ to_encode.update(
+ {'exp': datetime.now(UTC) + expire, 'jti': str(uuid.uuid4()), 'accesses': accesses}
+ )
logger.info(f'Creating token for user {username} with accesses: {accesses}')
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
diff --git a/backend-services/utils/bandwidth_util.py b/backend-services/utils/bandwidth_util.py
index 4ce563c..accd6c0 100644
--- a/backend-services/utils/bandwidth_util.py
+++ b/backend-services/utils/bandwidth_util.py
@@ -1,13 +1,14 @@
from __future__ import annotations
-from fastapi import Request, HTTPException
import time
-from typing import Optional
-from utils.doorman_cache_util import doorman_cache
+from fastapi import HTTPException, Request
+
from utils.database import user_collection
+from utils.doorman_cache_util import doorman_cache
-def _window_to_seconds(win: Optional[str]) -> int:
+
+def _window_to_seconds(win: str | None) -> int:
mapping = {
'second': 1,
'minute': 60,
@@ -21,14 +22,16 @@ def _window_to_seconds(win: Optional[str]) -> int:
w = win.lower().rstrip('s')
return mapping.get(w, 86400)
-def _bucket_key(username: str, window: str, now: Optional[int] = None) -> tuple[str, int]:
+
+def _bucket_key(username: str, window: str, now: int | None = None) -> tuple[str, int]:
sec = _window_to_seconds(window)
now = now or int(time.time())
bucket = (now // sec) * sec
key = f'bandwidth_usage:{username}:{sec}:{bucket}'
return key, sec
-def _get_user(username: str) -> Optional[dict]:
+
+def _get_user(username: str) -> dict | None:
user = doorman_cache.get_cache('user_cache', username)
if not user:
user = user_collection.find_one({'username': username})
@@ -36,10 +39,12 @@ def _get_user(username: str) -> Optional[dict]:
del user['_id']
return user
+
def _get_client():
return doorman_cache.cache if getattr(doorman_cache, 'is_redis', False) else None
-def get_current_usage(username: str, window: Optional[str]) -> int:
+
+def get_current_usage(username: str, window: str | None) -> int:
win = window or 'day'
key, ttl = _bucket_key(username, win)
client = _get_client()
@@ -55,7 +60,8 @@ def get_current_usage(username: str, window: Optional[str]) -> int:
except Exception:
return 0
-def add_usage(username: str, delta_bytes: int, window: Optional[str]) -> None:
+
+def add_usage(username: str, delta_bytes: int, window: str | None) -> None:
if not delta_bytes:
return
win = window or 'day'
@@ -75,7 +81,8 @@ def add_usage(username: str, delta_bytes: int, window: Optional[str]) -> None:
except Exception:
pass
-async def enforce_pre_request_limit(request: Request, username: Optional[str]) -> None:
+
+async def enforce_pre_request_limit(request: Request, username: str | None) -> None:
if not username:
return
user = _get_user(username)
diff --git a/backend-services/utils/cache_manager_util.py b/backend-services/utils/cache_manager_util.py
index 1e256c6..155216a 100644
--- a/backend-services/utils/cache_manager_util.py
+++ b/backend-services/utils/cache_manager_util.py
@@ -4,17 +4,18 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from fastapi import FastAPI
+import os
+
from aiocache import Cache, caches
from aiocache.decorators import cached
from dotenv import load_dotenv
-import os
+from fastapi import FastAPI
load_dotenv()
+
class CacheManager:
def __init__(self):
-
redis_host = os.getenv('REDIS_HOST')
redis_port = os.getenv('REDIS_PORT')
redis_db = os.getenv('REDIS_DB')
@@ -29,26 +30,17 @@ class CacheManager:
'endpoint': redis_host,
'port': port,
'db': db,
- 'timeout': 300
+ 'timeout': 300,
}
}
except Exception:
-
self.cache_backend = Cache.MEMORY
self.cache_config = {
- 'default': {
- 'cache': 'aiocache.SimpleMemoryCache',
- 'timeout': 300
- }
+ 'default': {'cache': 'aiocache.SimpleMemoryCache', 'timeout': 300}
}
else:
self.cache_backend = Cache.MEMORY
- self.cache_config = {
- 'default': {
- 'cache': 'aiocache.SimpleMemoryCache',
- 'timeout': 300
- }
- }
+ self.cache_config = {'default': {'cache': 'aiocache.SimpleMemoryCache', 'timeout': 300}}
caches.set_config(self.cache_config)
def init_app(self, app: FastAPI):
@@ -58,4 +50,5 @@ class CacheManager:
def cached(self, ttl=300, key=None):
return cached(ttl=ttl, key=key, cache=self.cache_backend)
+
cache_manager = CacheManager()
diff --git a/backend-services/utils/chaos_util.py b/backend-services/utils/chaos_util.py
index 0467f2c..cbcefa9 100644
--- a/backend-services/utils/chaos_util.py
+++ b/backend-services/utils/chaos_util.py
@@ -1,16 +1,12 @@
-import threading
-import time
import logging
+import threading
-_state = {
- 'redis_outage': False,
- 'mongo_outage': False,
- 'error_budget_burn': 0,
-}
+_state = {'redis_outage': False, 'mongo_outage': False, 'error_budget_burn': 0}
_lock = threading.RLock()
_logger = logging.getLogger('doorman.chaos')
+
def enable(backend: str, on: bool):
with _lock:
key = _key_for(backend)
@@ -18,12 +14,14 @@ def enable(backend: str, on: bool):
_state[key] = bool(on)
_logger.warning(f'chaos: {backend} outage set to {on}')
+
def enable_for(backend: str, duration_ms: int):
enable(backend, True)
t = threading.Timer(duration_ms / 1000.0, lambda: enable(backend, False))
t.daemon = True
t.start()
+
def _key_for(backend: str):
b = (backend or '').strip().lower()
if b == 'redis':
@@ -32,6 +30,7 @@ def _key_for(backend: str):
return 'mongo_outage'
return None
+
def should_fail(backend: str) -> bool:
key = _key_for(backend)
if not key:
@@ -39,12 +38,15 @@ def should_fail(backend: str) -> bool:
with _lock:
return bool(_state.get(key))
+
def burn_error_budget(backend: str):
with _lock:
_state['error_budget_burn'] += 1
- _logger.warning(f'chaos: error_budget_burn+1 backend={backend} total={_state["error_budget_burn"]}')
+ _logger.warning(
+ f'chaos: error_budget_burn+1 backend={backend} total={_state["error_budget_burn"]}'
+ )
+
def stats() -> dict:
with _lock:
return dict(_state)
-
diff --git a/backend-services/utils/constants.py b/backend-services/utils/constants.py
index 0f167a9..72fdb8a 100644
--- a/backend-services/utils/constants.py
+++ b/backend-services/utils/constants.py
@@ -1,6 +1,7 @@
class Headers:
REQUEST_ID = 'request_id'
+
class Defaults:
PAGE = 1
PAGE_SIZE = 10
@@ -9,6 +10,7 @@ class Defaults:
MAX_MULTIPART_SIZE_BYTES_ENV = 'MAX_MULTIPART_SIZE_BYTES'
MAX_MULTIPART_SIZE_BYTES_DEFAULT = 5_242_880
+
class Roles:
MANAGE_USERS = 'manage_users'
MANAGE_APIS = 'manage_apis'
@@ -18,6 +20,7 @@ class Roles:
EXPORT_LOGS = 'export_logs'
MANAGE_ROLES = 'manage_roles'
+
class ErrorCodes:
UNEXPECTED = 'GTW999'
HTTP_EXCEPTION = 'GTW998'
@@ -29,6 +32,7 @@ class ErrorCodes:
REQUEST_FILE_TYPE = 'REQ003'
PAGE_SIZE = 'PAG001'
+
class Messages:
UNEXPECTED = 'An unexpected error occurred'
FILE_TOO_LARGE = 'Uploaded file too large'
diff --git a/backend-services/utils/correlation_util.py b/backend-services/utils/correlation_util.py
index 901e7e9..3c56c27 100644
--- a/backend-services/utils/correlation_util.py
+++ b/backend-services/utils/correlation_util.py
@@ -4,27 +4,29 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from contextvars import ContextVar
-from typing import Optional
-import uuid
import logging
+import uuid
+from contextvars import ContextVar
-correlation_id: ContextVar[Optional[str]] = ContextVar('correlation_id', default=None)
+correlation_id: ContextVar[str | None] = ContextVar('correlation_id', default=None)
logger = logging.getLogger('doorman.gateway')
-def get_correlation_id() -> Optional[str]:
+
+def get_correlation_id() -> str | None:
"""
Get the current correlation ID from context.
"""
return correlation_id.get()
+
def set_correlation_id(value: str) -> None:
"""
Set the correlation ID in the current context.
"""
correlation_id.set(value)
+
def ensure_correlation_id() -> str:
"""
Get existing correlation ID or generate a new one.
@@ -35,16 +37,18 @@ def ensure_correlation_id() -> str:
correlation_id.set(cid)
return cid
+
def log_with_correlation(level: str, message: str, **kwargs) -> None:
"""
Log a message with the correlation ID automatically prepended.
"""
cid = get_correlation_id() or 'no-correlation-id'
- log_message = f"{cid} | {message}"
+ log_message = f'{cid} | {message}'
log_func = getattr(logger, level.lower(), logger.info)
log_func(log_message, **kwargs)
-async def run_with_correlation(coro, correlation_id_value: Optional[str] = None):
+
+async def run_with_correlation(coro, correlation_id_value: str | None = None):
"""
Run an async coroutine with a correlation ID.
"""
@@ -56,10 +60,12 @@ async def run_with_correlation(coro, correlation_id_value: Optional[str] = None)
finally:
pass
+
class CorrelationContext:
"""
Context manager for setting correlation ID in a scope.
"""
+
def __init__(self, correlation_id_value: str):
self.correlation_id_value = correlation_id_value
self.token = None
@@ -72,7 +78,10 @@ class CorrelationContext:
if self.token:
correlation_id.reset(self.token)
-async def run_async_with_correlation(func, *args, correlation_id_value: Optional[str] = None, **kwargs):
+
+async def run_async_with_correlation(
+ func, *args, correlation_id_value: str | None = None, **kwargs
+):
"""
Run an async function with correlation ID.
"""
diff --git a/backend-services/utils/credit_util.py b/backend-services/utils/credit_util.py
index 1c259e6..d194b57 100644
--- a/backend-services/utils/credit_util.py
+++ b/backend-services/utils/credit_util.py
@@ -1,7 +1,9 @@
-from utils.database_async import user_credit_collection, credit_def_collection
+from datetime import UTC, datetime
+
from utils.async_db import db_find_one, db_update_one
+from utils.database_async import credit_def_collection, user_credit_collection
from utils.encryption_util import decrypt_value
-from datetime import datetime, timezone
+
async def deduct_credit(api_credit_group, username):
if not api_credit_group:
@@ -14,9 +16,14 @@ async def deduct_credit(api_credit_group, username):
if not info or info.get('available_credits', 0) <= 0:
return False
available_credits = info.get('available_credits', 0) - 1
- await db_update_one(user_credit_collection, {'username': username}, {'$set': {f'users_credits.{api_credit_group}.available_credits': available_credits}})
+ await db_update_one(
+ user_credit_collection,
+ {'username': username},
+ {'$set': {f'users_credits.{api_credit_group}.available_credits': available_credits}},
+ )
return True
+
async def get_user_api_key(api_credit_group, username):
if not api_credit_group:
return None
@@ -29,6 +36,7 @@ async def get_user_api_key(api_credit_group, username):
dec = decrypt_value(enc)
return dec if dec is not None else enc
+
async def get_credit_api_header(api_credit_group):
"""
Get credit API header and key, supporting rotation.
@@ -61,7 +69,9 @@ async def get_credit_api_header(api_credit_group):
if api_key_new_encrypted and rotation_expires:
if isinstance(rotation_expires, str):
try:
- rotation_expires_dt = datetime.fromisoformat(rotation_expires.replace('Z', '+00:00'))
+ rotation_expires_dt = datetime.fromisoformat(
+ rotation_expires.replace('Z', '+00:00')
+ )
except Exception:
rotation_expires_dt = None
elif isinstance(rotation_expires, datetime):
@@ -69,7 +79,7 @@ async def get_credit_api_header(api_credit_group):
else:
rotation_expires_dt = None
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
if rotation_expires_dt and now < rotation_expires_dt:
api_key_new = decrypt_value(api_key_new_encrypted)
api_key_new = api_key_new if api_key_new is not None else api_key_new_encrypted
diff --git a/backend-services/utils/database.py b/backend-services/utils/database.py
index 10a1eb7..f495223 100644
--- a/backend-services/utils/database.py
+++ b/backend-services/utils/database.py
@@ -4,24 +4,22 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from pymongo import MongoClient, IndexModel, ASCENDING
-from dotenv import load_dotenv
-import os
-import uuid
import copy
-import json
-import threading
-import secrets
-import string as _string
import logging
+import os
+import threading
+import uuid
-from utils import password_util
-from utils import chaos_util
+from dotenv import load_dotenv
+from pymongo import ASCENDING, IndexModel, MongoClient
+
+from utils import chaos_util, password_util
load_dotenv()
logger = logging.getLogger('doorman.gateway')
+
def _build_admin_seed_doc(email: str, pwd_hash: str) -> dict:
"""Canonical admin bootstrap document used for both memory and Mongo modes.
@@ -47,9 +45,9 @@ def _build_admin_seed_doc(email: str, pwd_hash: str) -> dict:
'active': True,
}
+
class Database:
def __init__(self):
-
mem_flag = os.getenv('MEM_OR_EXTERNAL')
if mem_flag is None:
mem_flag = os.getenv('MEM_OR_REDIS', 'MEM')
@@ -75,17 +73,15 @@ class Database:
self.db_existed = True
if len(host_list) > 1 and replica_set_name:
- connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman?replicaSet={replica_set_name}"
+ connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman?replicaSet={replica_set_name}'
else:
- connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman"
+ connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman'
self.client = MongoClient(
- connection_uri,
- serverSelectionTimeoutMS=5000,
- maxPoolSize=100,
- minPoolSize=5
+ connection_uri, serverSelectionTimeoutMS=5000, maxPoolSize=100, minPoolSize=5
)
self.db = self.client.get_database()
+
def initialize_collections(self):
if self.memory_only:
# Resolve admin seed credentials consistently across modes (no auto-generation)
@@ -93,7 +89,9 @@ class Database:
email = os.getenv('DOORMAN_ADMIN_EMAIL') or 'admin@doorman.dev'
pwd = os.getenv('DOORMAN_ADMIN_PASSWORD')
if not pwd:
- raise RuntimeError('DOORMAN_ADMIN_PASSWORD is required for admin initialization')
+ raise RuntimeError(
+ 'DOORMAN_ADMIN_PASSWORD is required for admin initialization'
+ )
return email, password_util.hash_password(pwd)
users = self.db.users
@@ -101,37 +99,43 @@ class Database:
groups = self.db.groups
if not roles.find_one({'role_name': 'admin'}):
- roles.insert_one({
- 'role_name': 'admin',
- 'role_description': 'Administrator role',
- 'manage_users': True,
- 'manage_apis': True,
- 'manage_endpoints': True,
- 'manage_groups': True,
- 'manage_roles': True,
- 'manage_routings': True,
- 'manage_gateway': True,
- 'manage_subscriptions': True,
- 'manage_credits': True,
- 'manage_auth': True,
- 'manage_security': True,
- 'view_logs': True,
- 'export_logs': True,
- 'ui_access': True
- })
+ roles.insert_one(
+ {
+ 'role_name': 'admin',
+ 'role_description': 'Administrator role',
+ 'manage_users': True,
+ 'manage_apis': True,
+ 'manage_endpoints': True,
+ 'manage_groups': True,
+ 'manage_roles': True,
+ 'manage_routings': True,
+ 'manage_gateway': True,
+ 'manage_subscriptions': True,
+ 'manage_credits': True,
+ 'manage_auth': True,
+ 'manage_security': True,
+ 'view_logs': True,
+ 'export_logs': True,
+ 'ui_access': True,
+ }
+ )
if not groups.find_one({'group_name': 'admin'}):
- groups.insert_one({
- 'group_name': 'admin',
- 'group_description': 'Administrator group with full access',
- 'api_access': []
- })
+ groups.insert_one(
+ {
+ 'group_name': 'admin',
+ 'group_description': 'Administrator group with full access',
+ 'api_access': [],
+ }
+ )
if not groups.find_one({'group_name': 'ALL'}):
- groups.insert_one({
- 'group_name': 'ALL',
- 'group_description': 'Default group with access to all APIs',
- 'api_access': []
- })
+ groups.insert_one(
+ {
+ 'group_name': 'ALL',
+ 'group_description': 'Default group with access to all APIs',
+ 'api_access': [],
+ }
+ )
if not users.find_one({'username': 'admin'}):
_email, _pwd_hash = _admin_seed_creds()
@@ -150,16 +154,24 @@ class Database:
try:
env_pwd = os.getenv('DOORMAN_ADMIN_PASSWORD')
if env_pwd:
- users.update_one({'username': 'admin'}, {'$set': {'password': password_util.hash_password(env_pwd)}})
+ users.update_one(
+ {'username': 'admin'},
+ {'$set': {'password': password_util.hash_password(env_pwd)}},
+ )
except Exception:
pass
try:
from datetime import datetime
+
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
env_logs = os.getenv('LOGS_DIR')
- logs_dir = os.path.abspath(env_logs) if env_logs else os.path.join(base_dir, 'platform-logs')
+ logs_dir = (
+ os.path.abspath(env_logs)
+ if env_logs
+ else os.path.join(base_dir, 'platform-logs')
+ )
os.makedirs(logs_dir, exist_ok=True)
log_path = os.path.join(logs_dir, 'doorman.log')
now = datetime.now()
@@ -169,7 +181,7 @@ class Database:
('orders', '/orders/v1/status'),
('weather', '/weather/v1/status'),
]
- for api_name, ep in samples:
+ for _api_name, ep in samples:
ts = now.strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
rid = str(uuid.uuid4())
msg = f'{rid} | Username: admin | From: 127.0.0.1:54321 | Endpoint: GET {ep} | Total time: 42ms'
@@ -180,7 +192,23 @@ class Database:
pass
logger.info('Memory-only mode: Core data initialized (admin user/role/groups)')
return
- collections = ['users', 'apis', 'endpoints', 'groups', 'roles', 'subscriptions', 'routings', 'credit_defs', 'user_credits', 'endpoint_validations', 'settings', 'revocations', 'vault_entries', 'tiers', 'user_tier_assignments']
+ collections = [
+ 'users',
+ 'apis',
+ 'endpoints',
+ 'groups',
+ 'roles',
+ 'subscriptions',
+ 'routings',
+ 'credit_defs',
+ 'user_credits',
+ 'endpoint_validations',
+ 'settings',
+ 'revocations',
+ 'vault_entries',
+ 'tiers',
+ 'user_tier_assignments',
+ ]
for collection in collections:
if collection not in self.db.list_collection_names():
self.db_existed = False
@@ -193,8 +221,11 @@ class Database:
email = os.getenv('DOORMAN_ADMIN_EMAIL') or 'admin@doorman.dev'
pwd = os.getenv('DOORMAN_ADMIN_PASSWORD')
if not pwd:
- raise RuntimeError('DOORMAN_ADMIN_PASSWORD is required for admin initialization')
+ raise RuntimeError(
+ 'DOORMAN_ADMIN_PASSWORD is required for admin initialization'
+ )
return email, password_util.hash_password(pwd)
+
_email, _pwd_hash = _admin_seed_creds_mongo()
self.db.users.insert_one(_build_admin_seed_doc(_email, _pwd_hash))
try:
@@ -210,90 +241,109 @@ class Database:
if env_pwd:
self.db.users.update_one(
{'username': 'admin'},
- {'$set': {'password': password_util.hash_password(env_pwd)}}
+ {'$set': {'password': password_util.hash_password(env_pwd)}},
)
logger.warning('Admin user lacked password; set from DOORMAN_ADMIN_PASSWORD')
else:
- raise RuntimeError('Admin user missing password and DOORMAN_ADMIN_PASSWORD not set')
+ raise RuntimeError(
+ 'Admin user missing password and DOORMAN_ADMIN_PASSWORD not set'
+ )
except Exception:
pass
if not self.db.roles.find_one({'role_name': 'admin'}):
- self.db.roles.insert_one({
- 'role_name': 'admin',
- 'role_description': 'Administrator role',
- 'manage_users': True,
- 'manage_apis': True,
- 'manage_endpoints': True,
- 'manage_groups': True,
- 'manage_roles': True,
- 'manage_routings': True,
- 'manage_gateway': True,
- 'manage_subscriptions': True,
- 'manage_credits': True,
- 'manage_auth': True,
- 'view_logs': True,
- 'export_logs': True,
- 'manage_security': True
- })
+ self.db.roles.insert_one(
+ {
+ 'role_name': 'admin',
+ 'role_description': 'Administrator role',
+ 'manage_users': True,
+ 'manage_apis': True,
+ 'manage_endpoints': True,
+ 'manage_groups': True,
+ 'manage_roles': True,
+ 'manage_routings': True,
+ 'manage_gateway': True,
+ 'manage_subscriptions': True,
+ 'manage_credits': True,
+ 'manage_auth': True,
+ 'view_logs': True,
+ 'export_logs': True,
+ 'manage_security': True,
+ }
+ )
if not self.db.groups.find_one({'group_name': 'admin'}):
- self.db.groups.insert_one({
- 'group_name': 'admin',
- 'group_description': 'Administrator group with full access',
- 'api_access': []
- })
+ self.db.groups.insert_one(
+ {
+ 'group_name': 'admin',
+ 'group_description': 'Administrator group with full access',
+ 'api_access': [],
+ }
+ )
if not self.db.groups.find_one({'group_name': 'ALL'}):
- self.db.groups.insert_one({
- 'group_name': 'ALL',
- 'group_description': 'Default group with access to all APIs',
- 'api_access': []
- })
+ self.db.groups.insert_one(
+ {
+ 'group_name': 'ALL',
+ 'group_description': 'Default group with access to all APIs',
+ 'api_access': [],
+ }
+ )
def create_indexes(self):
if self.memory_only:
logger.debug('Memory-only mode: Skipping MongoDB index creation')
return
- self.db.apis.create_indexes([
- IndexModel([('api_id', ASCENDING)], unique=True),
- IndexModel([('name', ASCENDING), ('version', ASCENDING)])
- ])
- self.db.endpoints.create_indexes([
- IndexModel([('endpoint_method', ASCENDING), ('api_name', ASCENDING), ('api_version', ASCENDING), ('endpoint_uri', ASCENDING)], unique=True),
- ])
- self.db.users.create_indexes([
- IndexModel([('username', ASCENDING)], unique=True),
- IndexModel([('email', ASCENDING)], unique=True)
- ])
- self.db.groups.create_indexes([
- IndexModel([('group_name', ASCENDING)], unique=True)
- ])
- self.db.roles.create_indexes([
- IndexModel([('role_name', ASCENDING)], unique=True)
- ])
- self.db.subscriptions.create_indexes([
- IndexModel([('username', ASCENDING)], unique=True)
- ])
- self.db.routings.create_indexes([
- IndexModel([('client_key', ASCENDING)], unique=True)
- ])
- self.db.credit_defs.create_indexes([
- IndexModel([('api_credit_group', ASCENDING)], unique=True),
- IndexModel([('username', ASCENDING)], unique=True)
- ])
- self.db.endpoint_validations.create_indexes([
- IndexModel([('endpoint_id', ASCENDING)], unique=True)
- ])
- self.db.vault_entries.create_indexes([
- IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True),
- IndexModel([('username', ASCENDING)])
- ])
- self.db.tiers.create_indexes([
- IndexModel([('tier_id', ASCENDING)], unique=True),
- IndexModel([('name', ASCENDING)])
- ])
- self.db.user_tier_assignments.create_indexes([
- IndexModel([('user_id', ASCENDING)], unique=True),
- IndexModel([('tier_id', ASCENDING)])
- ])
+ self.db.apis.create_indexes(
+ [
+ IndexModel([('api_id', ASCENDING)], unique=True),
+ IndexModel([('name', ASCENDING), ('version', ASCENDING)]),
+ ]
+ )
+ self.db.endpoints.create_indexes(
+ [
+ IndexModel(
+ [
+ ('endpoint_method', ASCENDING),
+ ('api_name', ASCENDING),
+ ('api_version', ASCENDING),
+ ('endpoint_uri', ASCENDING),
+ ],
+ unique=True,
+ )
+ ]
+ )
+ self.db.users.create_indexes(
+ [
+ IndexModel([('username', ASCENDING)], unique=True),
+ IndexModel([('email', ASCENDING)], unique=True),
+ ]
+ )
+ self.db.groups.create_indexes([IndexModel([('group_name', ASCENDING)], unique=True)])
+ self.db.roles.create_indexes([IndexModel([('role_name', ASCENDING)], unique=True)])
+ self.db.subscriptions.create_indexes([IndexModel([('username', ASCENDING)], unique=True)])
+ self.db.routings.create_indexes([IndexModel([('client_key', ASCENDING)], unique=True)])
+ self.db.credit_defs.create_indexes(
+ [
+ IndexModel([('api_credit_group', ASCENDING)], unique=True),
+ IndexModel([('username', ASCENDING)], unique=True),
+ ]
+ )
+ self.db.endpoint_validations.create_indexes(
+ [IndexModel([('endpoint_id', ASCENDING)], unique=True)]
+ )
+ self.db.vault_entries.create_indexes(
+ [
+ IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True),
+ IndexModel([('username', ASCENDING)]),
+ ]
+ )
+ self.db.tiers.create_indexes(
+ [IndexModel([('tier_id', ASCENDING)], unique=True), IndexModel([('name', ASCENDING)])]
+ )
+ self.db.user_tier_assignments.create_indexes(
+ [
+ IndexModel([('user_id', ASCENDING)], unique=True),
+ IndexModel([('tier_id', ASCENDING)]),
+ ]
+ )
def is_memory_only(self) -> bool:
return self.memory_only
@@ -303,27 +353,30 @@ class Database:
'mode': 'memory_only' if self.memory_only else 'mongodb',
'mongodb_connected': not self.memory_only and self.client is not None,
'collections_available': not self.memory_only,
- 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS'))
+ 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS')),
}
+
class InMemoryInsertResult:
def __init__(self, inserted_id):
self.acknowledged = True
self.inserted_id = inserted_id
+
class InMemoryUpdateResult:
def __init__(self, modified_count):
self.acknowledged = True
self.modified_count = modified_count
+
class InMemoryDeleteResult:
def __init__(self, deleted_count):
self.acknowledged = True
self.deleted_count = deleted_count
+
class InMemoryCursor:
def __init__(self, docs):
-
self._docs = [copy.deepcopy(d) for d in docs]
self._index = 0
@@ -362,11 +415,11 @@ class InMemoryCursor:
if length is None:
return data
try:
-
return data[: int(length)]
except Exception:
return data
+
class InMemoryCollection:
def __init__(self, name):
self.name = name
@@ -453,7 +506,6 @@ class InMemoryCollection:
if set_data:
for k, v in set_data.items():
-
if isinstance(k, str) and '.' in k:
parts = k.split('.')
cur = updated
@@ -545,49 +597,48 @@ class InMemoryCollection:
return None
def create_indexes(self, *args, **kwargs):
-
return []
class AsyncInMemoryCollection:
"""Async wrapper around InMemoryCollection for async/await compatibility"""
-
+
def __init__(self, sync_collection):
self._sync = sync_collection
self.name = sync_collection.name
-
+
async def find_one(self, query=None):
"""Async find_one"""
return self._sync.find_one(query)
-
+
def find(self, query=None):
"""Returns cursor (sync method, but cursor supports async iteration)"""
return self._sync.find(query)
-
+
async def insert_one(self, doc):
"""Async insert_one"""
return self._sync.insert_one(doc)
-
+
async def update_one(self, query, update):
"""Async update_one"""
return self._sync.update_one(query, update)
-
+
async def delete_one(self, query):
"""Async delete_one"""
return self._sync.delete_one(query)
-
+
async def count_documents(self, query=None):
"""Async count_documents"""
return self._sync.count_documents(query)
-
+
async def replace_one(self, query, replacement):
"""Async replace_one"""
return self._sync.replace_one(query, replacement)
-
+
async def find_one_and_update(self, query, update, return_document=False):
"""Async find_one_and_update"""
return self._sync.find_one_and_update(query, update, return_document)
-
+
def create_indexes(self, *args, **kwargs):
return self._sync.create_indexes(*args, **kwargs)
@@ -595,8 +646,7 @@ class AsyncInMemoryCollection:
class InMemoryDB:
def __init__(self, async_mode=False):
self._async_mode = async_mode
- CollectionClass = AsyncInMemoryCollection if async_mode else InMemoryCollection
-
+
# Create base sync collections
self._sync_users = InMemoryCollection('users')
self._sync_apis = InMemoryCollection('apis')
@@ -614,7 +664,7 @@ class InMemoryDB:
self._sync_tiers = InMemoryCollection('tiers')
self._sync_user_tier_assignments = InMemoryCollection('user_tier_assignments')
self._sync_rate_limit_rules = InMemoryCollection('rate_limit_rules')
-
+
# Expose as async or sync based on mode
if async_mode:
self.users = AsyncInMemoryCollection(self._sync_users)
@@ -653,10 +703,22 @@ class InMemoryDB:
def list_collection_names(self):
return [
- 'users', 'apis', 'endpoints', 'groups', 'roles',
- 'subscriptions', 'routings', 'credit_defs', 'user_credits',
- 'endpoint_validations', 'settings', 'revocations', 'vault_entries',
- 'tiers', 'user_tier_assignments', 'rate_limit_rules'
+ 'users',
+ 'apis',
+ 'endpoints',
+ 'groups',
+ 'roles',
+ 'subscriptions',
+ 'routings',
+ 'credit_defs',
+ 'user_credits',
+ 'endpoint_validations',
+ 'settings',
+ 'revocations',
+ 'vault_entries',
+ 'tiers',
+ 'user_tier_assignments',
+ 'rate_limit_rules',
]
def create_collection(self, name):
@@ -709,11 +771,11 @@ class InMemoryDB:
load_coll(self._sync_tiers, data.get('tiers', []))
load_coll(self._sync_user_tier_assignments, data.get('user_tier_assignments', []))
+
database = Database()
database.initialize_collections()
database.create_indexes()
if database.memory_only:
-
db = database.db
mongodb_client = None
api_collection = db.apis
@@ -750,6 +812,7 @@ else:
except Exception:
vault_entries_collection = None
+
def close_database_connections():
"""
Close all database connections for graceful shutdown.
@@ -758,6 +821,6 @@ def close_database_connections():
try:
if mongodb_client:
mongodb_client.close()
- logger.info("MongoDB connections closed")
+ logger.info('MongoDB connections closed')
except Exception as e:
- logger.warning(f"Error closing MongoDB connections: {e}")
+ logger.warning(f'Error closing MongoDB connections: {e}')
diff --git a/backend-services/utils/database_async.py b/backend-services/utils/database_async.py
index b023e19..fefaa78 100644
--- a/backend-services/utils/database_async.py
+++ b/backend-services/utils/database_async.py
@@ -10,19 +10,19 @@ try:
from motor.motor_asyncio import AsyncIOMotorClient
except Exception:
AsyncIOMotorClient = None
-from dotenv import load_dotenv
-import os
-import asyncio
-from typing import Optional
import logging
+import os
+
+from dotenv import load_dotenv
-from utils.database import InMemoryDB, InMemoryCollection
from utils import password_util
+from utils.database import InMemoryDB
load_dotenv()
logger = logging.getLogger('doorman.gateway')
+
class AsyncDatabase:
"""Async database wrapper that supports both Motor (MongoDB) and in-memory modes."""
@@ -37,6 +37,7 @@ class AsyncDatabase:
# async-compatible collection interfaces (wrapping the same sync collections).
try:
from utils.database import database as _sync_db
+
self.client = None
self.db_existed = getattr(_sync_db, 'db_existed', False)
@@ -46,6 +47,7 @@ class AsyncDatabase:
self._sync = sync_db
# Wrap collections with async facade using the same underlying storage
from utils.database import AsyncInMemoryCollection as _AIC
+
self.users = _AIC(sync_db._sync_users)
self.apis = _AIC(sync_db._sync_apis)
self.endpoints = _AIC(sync_db._sync_endpoints)
@@ -95,17 +97,16 @@ class AsyncDatabase:
self.db_existed = True
if len(host_list) > 1 and replica_set_name:
- connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman?replicaSet={replica_set_name}"
+ connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman?replicaSet={replica_set_name}'
else:
- connection_uri = f"mongodb://{mongo_user}:{mongo_pass}@{','.join(host_list)}/doorman"
+ connection_uri = f'mongodb://{mongo_user}:{mongo_pass}@{",".join(host_list)}/doorman'
if AsyncIOMotorClient is None:
- raise RuntimeError('motor is required for async MongoDB mode; install motor or set MEM_OR_EXTERNAL=MEM')
+ raise RuntimeError(
+ 'motor is required for async MongoDB mode; install motor or set MEM_OR_EXTERNAL=MEM'
+ )
self.client = AsyncIOMotorClient(
- connection_uri,
- serverSelectionTimeoutMS=5000,
- maxPoolSize=100,
- minPoolSize=5
+ connection_uri, serverSelectionTimeoutMS=5000, maxPoolSize=100, minPoolSize=5
)
self.db = self.client.get_database()
@@ -113,13 +114,24 @@ class AsyncDatabase:
"""Initialize collections and default data."""
if self.memory_only:
from utils.database import database
+
database.initialize_collections()
return
collections = [
- 'users', 'apis', 'endpoints', 'groups', 'roles', 'subscriptions',
- 'routings', 'credit_defs', 'user_credits', 'endpoint_validations',
- 'settings', 'revocations', 'vault_entries'
+ 'users',
+ 'apis',
+ 'endpoints',
+ 'groups',
+ 'roles',
+ 'subscriptions',
+ 'routings',
+ 'credit_defs',
+ 'user_credits',
+ 'endpoint_validations',
+ 'settings',
+ 'revocations',
+ 'vault_entries',
]
existing_collections = await self.db.list_collection_names()
@@ -136,82 +148,91 @@ class AsyncDatabase:
email = os.getenv('DOORMAN_ADMIN_EMAIL') or 'admin@doorman.dev'
pwd = os.getenv('DOORMAN_ADMIN_PASSWORD')
if not pwd:
- raise RuntimeError('DOORMAN_ADMIN_PASSWORD is required for admin initialization')
+ raise RuntimeError(
+ 'DOORMAN_ADMIN_PASSWORD is required for admin initialization'
+ )
pwd_hash = password_util.hash_password(pwd)
- await self.db.users.insert_one({
- 'username': 'admin',
- 'email': email,
- 'password': pwd_hash,
- 'role': 'admin',
- 'groups': ['ALL', 'admin'],
- 'rate_limit_duration': 1,
- 'rate_limit_duration_type': 'second',
- 'throttle_duration': 1,
- 'throttle_duration_type': 'second',
- 'throttle_wait_duration': 0,
- 'throttle_wait_duration_type': 'second',
- 'custom_attributes': {'custom_key': 'custom_value'},
- 'active': True,
- 'throttle_queue_limit': 1,
- 'throttle_enabled': None,
- 'ui_access': True
- })
+ await self.db.users.insert_one(
+ {
+ 'username': 'admin',
+ 'email': email,
+ 'password': pwd_hash,
+ 'role': 'admin',
+ 'groups': ['ALL', 'admin'],
+ 'rate_limit_duration': 1,
+ 'rate_limit_duration_type': 'second',
+ 'throttle_duration': 1,
+ 'throttle_duration_type': 'second',
+ 'throttle_wait_duration': 0,
+ 'throttle_wait_duration_type': 'second',
+ 'custom_attributes': {'custom_key': 'custom_value'},
+ 'active': True,
+ 'throttle_queue_limit': 1,
+ 'throttle_enabled': None,
+ 'ui_access': True,
+ }
+ )
try:
adm = await self.db.users.find_one({'username': 'admin'})
if adm and adm.get('ui_access') is not True:
- await self.db.users.update_one(
- {'username': 'admin'},
- {'$set': {'ui_access': True}}
- )
+ await self.db.users.update_one({'username': 'admin'}, {'$set': {'ui_access': True}})
if adm and not adm.get('password'):
env_pwd = os.getenv('DOORMAN_ADMIN_PASSWORD')
if env_pwd:
await self.db.users.update_one(
{'username': 'admin'},
- {'$set': {'password': password_util.hash_password(env_pwd)}}
+ {'$set': {'password': password_util.hash_password(env_pwd)}},
)
logger.warning('Admin user lacked password; set from DOORMAN_ADMIN_PASSWORD')
else:
- raise RuntimeError('Admin user missing password and DOORMAN_ADMIN_PASSWORD not set')
+ raise RuntimeError(
+ 'Admin user missing password and DOORMAN_ADMIN_PASSWORD not set'
+ )
except Exception:
pass
admin_role = await self.db.roles.find_one({'role_name': 'admin'})
if not admin_role:
- await self.db.roles.insert_one({
- 'role_name': 'admin',
- 'role_description': 'Administrator role',
- 'manage_users': True,
- 'manage_apis': True,
- 'manage_endpoints': True,
- 'manage_groups': True,
- 'manage_roles': True,
- 'manage_routings': True,
- 'manage_gateway': True,
- 'manage_subscriptions': True,
- 'manage_credits': True,
- 'manage_auth': True,
- 'view_logs': True,
- 'export_logs': True,
- 'manage_security': True
- })
+ await self.db.roles.insert_one(
+ {
+ 'role_name': 'admin',
+ 'role_description': 'Administrator role',
+ 'manage_users': True,
+ 'manage_apis': True,
+ 'manage_endpoints': True,
+ 'manage_groups': True,
+ 'manage_roles': True,
+ 'manage_routings': True,
+ 'manage_gateway': True,
+ 'manage_subscriptions': True,
+ 'manage_credits': True,
+ 'manage_auth': True,
+ 'view_logs': True,
+ 'export_logs': True,
+ 'manage_security': True,
+ }
+ )
admin_group = await self.db.groups.find_one({'group_name': 'admin'})
if not admin_group:
- await self.db.groups.insert_one({
- 'group_name': 'admin',
- 'group_description': 'Administrator group with full access',
- 'api_access': []
- })
+ await self.db.groups.insert_one(
+ {
+ 'group_name': 'admin',
+ 'group_description': 'Administrator group with full access',
+ 'api_access': [],
+ }
+ )
all_group = await self.db.groups.find_one({'group_name': 'ALL'})
if not all_group:
- await self.db.groups.insert_one({
- 'group_name': 'ALL',
- 'group_description': 'Default group with access to all APIs',
- 'api_access': []
- })
+ await self.db.groups.insert_one(
+ {
+ 'group_name': 'ALL',
+ 'group_description': 'Default group with access to all APIs',
+ 'api_access': [],
+ }
+ )
async def create_indexes(self):
"""Create database indexes for performance."""
@@ -219,56 +240,65 @@ class AsyncDatabase:
logger.debug('Async Memory-only mode: Skipping MongoDB index creation')
return
- from pymongo import IndexModel, ASCENDING
+ from pymongo import ASCENDING, IndexModel
- await self.db.apis.create_indexes([
- IndexModel([('api_id', ASCENDING)], unique=True),
- IndexModel([('name', ASCENDING), ('version', ASCENDING)])
- ])
+ await self.db.apis.create_indexes(
+ [
+ IndexModel([('api_id', ASCENDING)], unique=True),
+ IndexModel([('name', ASCENDING), ('version', ASCENDING)]),
+ ]
+ )
- await self.db.endpoints.create_indexes([
- IndexModel([
- ('endpoint_method', ASCENDING),
- ('api_name', ASCENDING),
- ('api_version', ASCENDING),
- ('endpoint_uri', ASCENDING)
- ], unique=True),
- ])
+ await self.db.endpoints.create_indexes(
+ [
+ IndexModel(
+ [
+ ('endpoint_method', ASCENDING),
+ ('api_name', ASCENDING),
+ ('api_version', ASCENDING),
+ ('endpoint_uri', ASCENDING),
+ ],
+ unique=True,
+ )
+ ]
+ )
- await self.db.users.create_indexes([
- IndexModel([('username', ASCENDING)], unique=True),
- IndexModel([('email', ASCENDING)], unique=True)
- ])
+ await self.db.users.create_indexes(
+ [
+ IndexModel([('username', ASCENDING)], unique=True),
+ IndexModel([('email', ASCENDING)], unique=True),
+ ]
+ )
- await self.db.groups.create_indexes([
- IndexModel([('group_name', ASCENDING)], unique=True)
- ])
+ await self.db.groups.create_indexes([IndexModel([('group_name', ASCENDING)], unique=True)])
- await self.db.roles.create_indexes([
- IndexModel([('role_name', ASCENDING)], unique=True)
- ])
+ await self.db.roles.create_indexes([IndexModel([('role_name', ASCENDING)], unique=True)])
- await self.db.subscriptions.create_indexes([
- IndexModel([('username', ASCENDING)], unique=True)
- ])
+ await self.db.subscriptions.create_indexes(
+ [IndexModel([('username', ASCENDING)], unique=True)]
+ )
- await self.db.routings.create_indexes([
- IndexModel([('client_key', ASCENDING)], unique=True)
- ])
+ await self.db.routings.create_indexes(
+ [IndexModel([('client_key', ASCENDING)], unique=True)]
+ )
- await self.db.credit_defs.create_indexes([
- IndexModel([('api_credit_group', ASCENDING)], unique=True),
- IndexModel([('username', ASCENDING)], unique=True)
- ])
+ await self.db.credit_defs.create_indexes(
+ [
+ IndexModel([('api_credit_group', ASCENDING)], unique=True),
+ IndexModel([('username', ASCENDING)], unique=True),
+ ]
+ )
- await self.db.endpoint_validations.create_indexes([
- IndexModel([('endpoint_id', ASCENDING)], unique=True)
- ])
+ await self.db.endpoint_validations.create_indexes(
+ [IndexModel([('endpoint_id', ASCENDING)], unique=True)]
+ )
- await self.db.vault_entries.create_indexes([
- IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True),
- IndexModel([('username', ASCENDING)])
- ])
+ await self.db.vault_entries.create_indexes(
+ [
+ IndexModel([('username', ASCENDING), ('key_name', ASCENDING)], unique=True),
+ IndexModel([('username', ASCENDING)]),
+ ]
+ )
def is_memory_only(self) -> bool:
"""Check if running in memory-only mode."""
@@ -280,14 +310,15 @@ class AsyncDatabase:
'mode': 'memory_only' if self.memory_only else 'mongodb',
'mongodb_connected': not self.memory_only and self.client is not None,
'collections_available': not self.memory_only,
- 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS'))
+ 'cache_backend': os.getenv('MEM_OR_EXTERNAL', os.getenv('MEM_OR_REDIS', 'REDIS')),
}
async def close(self):
"""Close database connections gracefully."""
if self.client:
self.client.close()
- logger.info("Async MongoDB connections closed")
+ logger.info('Async MongoDB connections closed')
+
async_database = AsyncDatabase()
@@ -332,6 +363,7 @@ else:
except Exception:
vault_entries_collection = None
+
async def close_async_database_connections():
"""Close all async database connections for graceful shutdown."""
await async_database.close()
diff --git a/backend-services/utils/demo_seed_util.py b/backend-services/utils/demo_seed_util.py
index c2ce2c0..dfc7b7e 100644
--- a/backend-services/utils/demo_seed_util.py
+++ b/backend-services/utils/demo_seed_util.py
@@ -1,39 +1,73 @@
from __future__ import annotations
+
import os
-import uuid
import random
import string
+import uuid
from datetime import datetime, timedelta
from utils import password_util
from utils.database import (
-
api_collection,
+ credit_def_collection,
endpoint_collection,
group_collection,
role_collection,
subscriptions_collection,
- credit_def_collection,
- user_credit_collection,
user_collection,
+ user_credit_collection,
)
-from utils.metrics_util import metrics_store, MinuteBucket
from utils.encryption_util import encrypt_value
+from utils.metrics_util import MinuteBucket, metrics_store
+
def _rand_choice(seq):
return random.choice(seq)
+
def _rand_word(min_len=4, max_len=10) -> str:
length = random.randint(min_len, max_len)
return ''.join(random.choices(string.ascii_lowercase, k=length))
+
def _rand_name() -> str:
- firsts = ['alex','casey','morgan','sam','taylor','riley','jamie','jordan','drew','quinn','kyle','parker','blake','devon']
- lasts = ['lee','kim','patel','garcia','nguyen','williams','brown','davis','miller','wilson','moore','taylor','thomas']
+ firsts = [
+ 'alex',
+ 'casey',
+ 'morgan',
+ 'sam',
+ 'taylor',
+ 'riley',
+ 'jamie',
+ 'jordan',
+ 'drew',
+ 'quinn',
+ 'kyle',
+ 'parker',
+ 'blake',
+ 'devon',
+ ]
+ lasts = [
+ 'lee',
+ 'kim',
+ 'patel',
+ 'garcia',
+ 'nguyen',
+ 'williams',
+ 'brown',
+ 'davis',
+ 'miller',
+ 'wilson',
+ 'moore',
+ 'taylor',
+ 'thomas',
+ ]
return f'{_rand_choice(firsts)}.{_rand_choice(lasts)}'
+
def _rand_domain() -> str:
- return _rand_choice(['example.com','acme.io','contoso.net','demo.dev'])
+ return _rand_choice(['example.com', 'acme.io', 'contoso.net', 'demo.dev'])
+
def _rand_password() -> str:
upp = _rand_choice(string.ascii_uppercase)
@@ -44,11 +78,23 @@ def _rand_password() -> str:
raw = upp + low + dig + spc + tail
return ''.join(random.sample(raw, len(raw)))
+
def ensure_roles() -> list[str]:
- roles = [('developer', dict(manage_apis=True, manage_endpoints=True, manage_subscriptions=True, manage_credits=True, view_logs=True)),
- ('analyst', dict(view_logs=True, export_logs=True)),
- ('viewer', dict(view_logs=True)),
- ('ops', dict(manage_gateway=True, view_logs=True, export_logs=True, manage_security=True))]
+ roles = [
+ (
+ 'developer',
+ dict(
+ manage_apis=True,
+ manage_endpoints=True,
+ manage_subscriptions=True,
+ manage_credits=True,
+ view_logs=True,
+ ),
+ ),
+ ('analyst', dict(view_logs=True, export_logs=True)),
+ ('viewer', dict(view_logs=True)),
+ ('ops', dict(manage_gateway=True, view_logs=True, export_logs=True, manage_security=True)),
+ ]
created = []
for role_name, extra in roles:
if not role_collection.find_one({'role_name': role_name}):
@@ -58,58 +104,103 @@ def ensure_roles() -> list[str]:
created.append(role_name)
return ['admin', *created]
+
def seed_groups(n: int, api_keys: list[str]) -> list[str]:
names = []
for i in range(n):
- gname = f'team-{_rand_word(3,6)}-{i}'
+ gname = f'team-{_rand_word(3, 6)}-{i}'
if group_collection.find_one({'group_name': gname}):
names.append(gname)
continue
- access = sorted(set(random.sample(api_keys, k=min(len(api_keys), random.randint(1, max(1, len(api_keys)//3))))) ) if api_keys else []
- group_collection.insert_one({'group_name': gname, 'group_description': f'Auto group {gname}', 'api_access': access})
+ access = (
+ sorted(
+ set(
+ random.sample(
+ api_keys,
+ k=min(len(api_keys), random.randint(1, max(1, len(api_keys) // 3))),
+ )
+ )
+ )
+ if api_keys
+ else []
+ )
+ group_collection.insert_one(
+ {'group_name': gname, 'group_description': f'Auto group {gname}', 'api_access': access}
+ )
names.append(gname)
for base in ('ALL', 'admin'):
if not group_collection.find_one({'group_name': base}):
- group_collection.insert_one({'group_name': base, 'group_description': f'{base} group', 'api_access': []})
+ group_collection.insert_one(
+ {'group_name': base, 'group_description': f'{base} group', 'api_access': []}
+ )
if base not in names:
names.append(base)
return names
+
def seed_users(n: int, roles: list[str], groups: list[str]) -> list[str]:
usernames = []
for i in range(n):
uname = f'{_rand_name()}_{i}'
- email = f"{uname.replace('.', '_')}@{_rand_domain()}"
+ email = f'{uname.replace(".", "_")}@{_rand_domain()}'
if user_collection.find_one({'username': uname}):
usernames.append(uname)
continue
hashed = password_util.hash_password(_rand_password())
- ugrps = sorted(set(random.sample(groups, k=min(len(groups), random.randint(1, min(3, max(1, len(groups))))))))
+ ugrps = sorted(
+ set(
+ random.sample(
+ groups, k=min(len(groups), random.randint(1, min(3, max(1, len(groups)))))
+ )
+ )
+ )
role = _rand_choice(roles)
- user_collection.insert_one({
- 'username': uname,
- 'email': email,
- 'password': hashed,
- 'role': role,
- 'groups': ugrps,
- 'rate_limit_duration': random.randint(100, 10000),
- 'rate_limit_duration_type': _rand_choice(['minute','hour','day']),
- 'throttle_duration': random.randint(1000, 100000),
- 'throttle_duration_type': _rand_choice(['second','minute']),
- 'throttle_wait_duration': random.randint(100, 10000),
- 'throttle_wait_duration_type': _rand_choice(['seconds','minutes']),
- 'custom_attributes': {'dept': _rand_choice(['sales','eng','support','ops'])},
- 'active': True,
- 'ui_access': _rand_choice([True, False])
- })
+ user_collection.insert_one(
+ {
+ 'username': uname,
+ 'email': email,
+ 'password': hashed,
+ 'role': role,
+ 'groups': ugrps,
+ 'rate_limit_duration': random.randint(100, 10000),
+ 'rate_limit_duration_type': _rand_choice(['minute', 'hour', 'day']),
+ 'throttle_duration': random.randint(1000, 100000),
+ 'throttle_duration_type': _rand_choice(['second', 'minute']),
+ 'throttle_wait_duration': random.randint(100, 10000),
+ 'throttle_wait_duration_type': _rand_choice(['seconds', 'minutes']),
+ 'custom_attributes': {'dept': _rand_choice(['sales', 'eng', 'support', 'ops'])},
+ 'active': True,
+ 'ui_access': _rand_choice([True, False]),
+ }
+ )
usernames.append(uname)
return usernames
-def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str,str]]:
+
+def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str, str]]:
pairs = []
- for i in range(n):
- name = _rand_choice(['customers','orders','billing','weather','news','crypto','search','inventory','shipping','payments','alerts','metrics','recommendations']) + f'-{_rand_word(3,6)}'
- ver = _rand_choice(['v1','v2','v3'])
+ for _i in range(n):
+ name = (
+ _rand_choice(
+ [
+ 'customers',
+ 'orders',
+ 'billing',
+ 'weather',
+ 'news',
+ 'crypto',
+ 'search',
+ 'inventory',
+ 'shipping',
+ 'payments',
+ 'alerts',
+ 'metrics',
+ 'recommendations',
+ ]
+ )
+ + f'-{_rand_word(3, 6)}'
+ )
+ ver = _rand_choice(['v1', 'v2', 'v3'])
if api_collection.find_one({'api_name': name, 'api_version': ver}):
pairs.append((name, ver))
continue
@@ -118,9 +209,17 @@ def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str,str
'api_name': name,
'api_version': ver,
'api_description': f'Auto API {name}/{ver}',
- 'api_allowed_roles': sorted(set(random.sample(roles, k=min(len(roles), random.randint(1, min(3, len(roles))))))),
- 'api_allowed_groups': sorted(set(random.sample(groups, k=min(len(groups), random.randint(1, min(5, len(groups))))))),
- 'api_servers': [f'http://localhost:{8000+random.randint(0,999)}'],
+ 'api_allowed_roles': sorted(
+ set(random.sample(roles, k=min(len(roles), random.randint(1, min(3, len(roles))))))
+ ),
+ 'api_allowed_groups': sorted(
+ set(
+ random.sample(
+ groups, k=min(len(groups), random.randint(1, min(5, len(groups))))
+ )
+ )
+ ),
+ 'api_servers': [f'http://localhost:{8000 + random.randint(0, 999)}'],
'api_type': 'REST',
'api_allowed_retry_count': random.randint(0, 3),
'api_id': api_id,
@@ -130,10 +229,22 @@ def seed_apis(n: int, roles: list[str], groups: list[str]) -> list[tuple[str,str
pairs.append((name, ver))
return pairs
-def seed_endpoints(apis: list[tuple[str,str]], per_api: int) -> None:
- methods = ['GET','POST','PUT','DELETE','PATCH']
- bases = ['/status','/health','/items','/items/{id}','/search','/reports','/export','/metrics','/list','/detail/{id}']
- for (name, ver) in apis:
+
+def seed_endpoints(apis: list[tuple[str, str]], per_api: int) -> None:
+ methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']
+ bases = [
+ '/status',
+ '/health',
+ '/items',
+ '/items/{id}',
+ '/search',
+ '/reports',
+ '/export',
+ '/metrics',
+ '/list',
+ '/detail/{id}',
+ ]
+ for name, ver in apis:
created = set()
for _ in range(per_api):
m = _rand_choice(methods)
@@ -142,24 +253,49 @@ def seed_endpoints(apis: list[tuple[str,str]], per_api: int) -> None:
if key in created:
continue
created.add(key)
- if endpoint_collection.find_one({'api_name': name, 'api_version': ver, 'endpoint_method': m, 'endpoint_uri': u}):
+ if endpoint_collection.find_one(
+ {'api_name': name, 'api_version': ver, 'endpoint_method': m, 'endpoint_uri': u}
+ ):
continue
- endpoint_collection.insert_one({
- 'api_name': name,
- 'api_version': ver,
- 'endpoint_method': m,
- 'endpoint_uri': u,
- 'endpoint_description': f'{m} {u} for {name}',
- 'api_id': api_collection.find_one({'api_name': name, 'api_version': ver}).get('api_id'),
- 'endpoint_id': str(uuid.uuid4()),
- })
+ endpoint_collection.insert_one(
+ {
+ 'api_name': name,
+ 'api_version': ver,
+ 'endpoint_method': m,
+ 'endpoint_uri': u,
+ 'endpoint_description': f'{m} {u} for {name}',
+ 'api_id': api_collection.find_one({'api_name': name, 'api_version': ver}).get(
+ 'api_id'
+ ),
+ 'endpoint_id': str(uuid.uuid4()),
+ }
+ )
+
def seed_credits() -> list[str]:
- groups = ['ai-basic','ai-pro','maps-basic','maps-pro','news-tier','weather-tier']
+ groups = ['ai-basic', 'ai-pro', 'maps-basic', 'maps-pro', 'news-tier', 'weather-tier']
tiers_catalog = [
- {'tier_name': 'basic', 'credits': 100, 'input_limit': 100, 'output_limit': 100, 'reset_frequency': 'monthly'},
- {'tier_name': 'pro', 'credits': 1000, 'input_limit': 500, 'output_limit': 500, 'reset_frequency': 'monthly'},
- {'tier_name': 'enterprise', 'credits': 10000, 'input_limit': 2000, 'output_limit': 2000, 'reset_frequency': 'monthly'},
+ {
+ 'tier_name': 'basic',
+ 'credits': 100,
+ 'input_limit': 100,
+ 'output_limit': 100,
+ 'reset_frequency': 'monthly',
+ },
+ {
+ 'tier_name': 'pro',
+ 'credits': 1000,
+ 'input_limit': 500,
+ 'output_limit': 500,
+ 'reset_frequency': 'monthly',
+ },
+ {
+ 'tier_name': 'enterprise',
+ 'credits': 10000,
+ 'input_limit': 2000,
+ 'output_limit': 2000,
+ 'reset_frequency': 'monthly',
+ },
]
created = []
for g in groups:
@@ -167,64 +303,84 @@ def seed_credits() -> list[str]:
created.append(g)
continue
tiers = random.sample(tiers_catalog, k=random.randint(1, 3))
- credit_def_collection.insert_one({
- 'api_credit_group': g,
- 'api_key': encrypt_value(uuid.uuid4().hex),
- 'api_key_header': _rand_choice(['x-api-key','authorization','x-token']),
- 'credit_tiers': tiers,
- })
+ credit_def_collection.insert_one(
+ {
+ 'api_credit_group': g,
+ 'api_key': encrypt_value(uuid.uuid4().hex),
+ 'api_key_header': _rand_choice(['x-api-key', 'authorization', 'x-token']),
+ 'credit_tiers': tiers,
+ }
+ )
created.append(g)
return created
+
def seed_user_credits(usernames: list[str], credit_groups: list[str]) -> None:
- pick_users = random.sample(usernames, k=min(len(usernames), max(1, len(usernames)//2))) if usernames else []
+ pick_users = (
+ random.sample(usernames, k=min(len(usernames), max(1, len(usernames) // 2)))
+ if usernames
+ else []
+ )
for u in pick_users:
users_credits = {}
for g in random.sample(credit_groups, k=random.randint(1, min(3, len(credit_groups)))):
users_credits[g] = {
- 'tier_name': _rand_choice(['basic','pro','enterprise']),
+ 'tier_name': _rand_choice(['basic', 'pro', 'enterprise']),
'available_credits': random.randint(10, 10000),
- 'reset_date': (datetime.utcnow() + timedelta(days=random.randint(1, 30))).strftime('%Y-%m-%d'),
+ 'reset_date': (datetime.utcnow() + timedelta(days=random.randint(1, 30))).strftime(
+ '%Y-%m-%d'
+ ),
'user_api_key': encrypt_value(uuid.uuid4().hex),
}
existing = user_credit_collection.find_one({'username': u})
if existing:
- user_credit_collection.update_one({'username': u}, {'$set': {'users_credits': users_credits}})
+ user_credit_collection.update_one(
+ {'username': u}, {'$set': {'users_credits': users_credits}}
+ )
else:
user_credit_collection.insert_one({'username': u, 'users_credits': users_credits})
-def seed_subscriptions(usernames: list[str], apis: list[tuple[str,str]]) -> None:
+
+def seed_subscriptions(usernames: list[str], apis: list[tuple[str, str]]) -> None:
api_keys = [f'{a}/{v}' for a, v in apis]
for u in usernames:
- subs = sorted(set(random.sample(api_keys, k=random.randint(1, min(5, len(api_keys))))) ) if api_keys else []
+ subs = (
+ sorted(set(random.sample(api_keys, k=random.randint(1, min(5, len(api_keys))))))
+ if api_keys
+ else []
+ )
existing = subscriptions_collection.find_one({'username': u})
if existing:
subscriptions_collection.update_one({'username': u}, {'$set': {'apis': subs}})
else:
subscriptions_collection.insert_one({'username': u, 'apis': subs})
-def seed_logs(n: int, usernames: list[str], apis: list[tuple[str,str]]) -> None:
+
+def seed_logs(n: int, usernames: list[str], apis: list[tuple[str, str]]) -> None:
base_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(base_dir, '..'))
logs_dir = os.path.join(base_dir, 'logs')
os.makedirs(logs_dir, exist_ok=True)
log_path = os.path.join(logs_dir, 'doorman.log')
- methods = ['GET','POST','PUT','DELETE','PATCH']
- uris = ['/status','/list','/items','/items/123','/search?q=test','/export','/metrics']
+ methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']
+ uris = ['/status', '/list', '/items', '/items/123', '/search?q=test', '/export', '/metrics']
now = datetime.now()
with open(log_path, 'a', encoding='utf-8') as lf:
for _ in range(n):
- api = _rand_choice(apis) if apis else ('demo','v1')
+ api = _rand_choice(apis) if apis else ('demo', 'v1')
method = _rand_choice(methods)
uri = _rand_choice(uris)
- ts = (now - timedelta(seconds=random.randint(0, 3600))).strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+ ts = (now - timedelta(seconds=random.randint(0, 3600))).strftime(
+ '%Y-%m-%d %H:%M:%S,%f'
+ )[:-3]
rid = str(uuid.uuid4())
user = _rand_choice(usernames) if usernames else 'admin'
port = random.randint(10000, 65000)
- msg = f'{rid} | Username: {user} | From: 127.0.0.1:{port} | Endpoint: {method} /{api[0]}/{api[1]}{uri} | Total time: {random.randint(5,500)}ms'
+ msg = f'{rid} | Username: {user} | From: 127.0.0.1:{port} | Endpoint: {method} /{api[0]}/{api[1]}{uri} | Total time: {random.randint(5, 500)}ms'
lf.write(f'{ts} - doorman.gateway - INFO - {msg}\n')
-def seed_protos(n: int, apis: list[tuple[str,str]]) -> None:
+
+def seed_protos(n: int, apis: list[tuple[str, str]]) -> None:
base_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(base_dir, '..'))
proto_dir = os.path.join(base_dir, 'proto')
@@ -235,7 +391,7 @@ def seed_protos(n: int, apis: list[tuple[str,str]]) -> None:
for name, ver in picked:
key = f'{name}_{ver}'.replace('-', '_')
svc = ''.join([p.capitalize() for p in name.split('-')])
- content = f'''syntax = "proto3";
+ content = f"""syntax = "proto3";
package {key};
@@ -251,12 +407,13 @@ message StatusReply {{
string status = 1;
string message = 2;
}}
-'''
+"""
path = os.path.join(proto_dir, f'{key}.proto')
with open(path, 'w', encoding='utf-8') as f:
f.write(content)
-def seed_metrics(usernames: list[str], apis: list[tuple[str,str]], minutes: int = 400) -> None:
+
+def seed_metrics(usernames: list[str], apis: list[tuple[str, str]], minutes: int = 400) -> None:
now = datetime.utcnow()
for i in range(minutes, 0, -1):
minute_start = int(((now - timedelta(minutes=i)).timestamp()) // 60) * 60
@@ -264,7 +421,7 @@ def seed_metrics(usernames: list[str], apis: list[tuple[str,str]], minutes: int
count = random.randint(0, 50)
for _ in range(count):
dur = random.uniform(10, 400)
- status = _rand_choice([200,200,200,201,204,400,401,403,404,500])
+ status = _rand_choice([200, 200, 200, 201, 204, 400, 401, 403, 404, 500])
b.add(dur, status)
metrics_store.total_requests += 1
metrics_store.total_ms += dur
@@ -276,11 +433,12 @@ def seed_metrics(usernames: list[str], apis: list[tuple[str,str]], minutes: int
metrics_store.api_counts[f'rest:{_rand_choice(apis)[0]}'] += 1
metrics_store._buckets.append(b)
+
def run_seed(users=30, apis=12, endpoints=5, groups=6, protos=5, logs=1000, seed=None):
if seed is not None:
random.seed(seed)
roles = ensure_roles()
- api_pairs = seed_apis(apis, roles, ['ALL','admin'])
+ api_pairs = seed_apis(apis, roles, ['ALL', 'admin'])
group_names = seed_groups(groups, [f'{a}/{v}' for a, v in api_pairs])
usernames = seed_users(users, roles, group_names)
seed_endpoints(api_pairs, endpoints)
diff --git a/backend-services/utils/doorman_cache_async.py b/backend-services/utils/doorman_cache_async.py
index e128edd..b23c320 100644
--- a/backend-services/utils/doorman_cache_async.py
+++ b/backend-services/utils/doorman_cache_async.py
@@ -6,16 +6,18 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-import redis.asyncio as aioredis
import json
-import os
-from typing import Dict, Any, Optional
import logging
+import os
+from typing import Any
+
+import redis.asyncio as aioredis
from utils.doorman_cache_util import MemoryCache
logger = logging.getLogger('doorman.gateway')
+
class AsyncDoormanCacheManager:
"""Async cache manager supporting both Redis (async) and in-memory modes."""
@@ -52,7 +54,7 @@ class AsyncDoormanCacheManager:
'endpoint_server_cache': 'endpoint_server_cache:',
'client_routing_cache': 'client_routing_cache:',
'token_def_cache': 'token_def_cache:',
- 'credit_def_cache': 'credit_def_cache:'
+ 'credit_def_cache': 'credit_def_cache:',
}
self.default_ttls = {
@@ -70,7 +72,7 @@ class AsyncDoormanCacheManager:
'endpoint_server_cache': 86400,
'client_routing_cache': 86400,
'token_def_cache': 86400,
- 'credit_def_cache': 86400
+ 'credit_def_cache': 86400,
}
def _to_json_serializable(self, value):
@@ -99,6 +101,7 @@ class AsyncDoormanCacheManager:
if self._init_lock:
import asyncio
+
while self._init_lock:
await asyncio.sleep(0.01)
return
@@ -114,7 +117,7 @@ class AsyncDoormanCacheManager:
port=redis_port,
db=redis_db,
decode_responses=True,
- max_connections=100
+ max_connections=100,
)
self.cache = aioredis.Redis(connection_pool=self._redis_pool)
@@ -148,7 +151,7 @@ class AsyncDoormanCacheManager:
else:
self.cache.setex(cache_key, ttl, payload)
- async def get_cache(self, cache_name: str, key: str) -> Optional[Any]:
+ async def get_cache(self, cache_name: str, key: str) -> Any | None:
"""Get cache value (async)."""
if self.is_redis:
await self._ensure_redis_connection()
@@ -200,13 +203,13 @@ class AsyncDoormanCacheManager:
for cache_name in self.prefixes.keys():
await self.clear_cache(cache_name)
- async def get_cache_info(self) -> Dict[str, Any]:
+ async def get_cache_info(self) -> dict[str, Any]:
"""Get cache information (async)."""
info = {
'type': self.cache_type,
'is_redis': self.is_redis,
'prefixes': list(self.prefixes.keys()),
- 'default_ttl': self.default_ttls
+ 'default_ttl': self.default_ttls,
}
if not self.is_redis and hasattr(self.cache, 'get_cache_stats'):
@@ -257,6 +260,7 @@ class AsyncDoormanCacheManager:
"""
try:
import inspect
+
if inspect.iscoroutine(operation):
result = await operation
else:
@@ -268,7 +272,7 @@ class AsyncDoormanCacheManager:
await self.delete_cache(cache_name, key)
return result
- except Exception as e:
+ except Exception:
await self.delete_cache(cache_name, key)
raise
@@ -278,10 +282,12 @@ class AsyncDoormanCacheManager:
await self.cache.close()
if self._redis_pool:
await self._redis_pool.disconnect()
- logger.info("Async Redis connections closed")
+ logger.info('Async Redis connections closed')
+
async_doorman_cache = AsyncDoormanCacheManager()
+
async def close_async_cache_connections():
"""Close all async cache connections for graceful shutdown."""
await async_doorman_cache.close()
diff --git a/backend-services/utils/doorman_cache_util.py b/backend-services/utils/doorman_cache_util.py
index 7499e0b..689aecc 100644
--- a/backend-services/utils/doorman_cache_util.py
+++ b/backend-services/utils/doorman_cache_util.py
@@ -4,18 +4,21 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-import redis
+import asyncio
import json
+import logging
import os
import threading
-from typing import Dict, Any, Optional
-import asyncio
-import logging
+from typing import Any
+
+import redis
+
from utils import chaos_util
+
class MemoryCache:
def __init__(self, maxsize: int = 10000):
- self._cache: Dict[str, Dict[str, Any]] = {}
+ self._cache: dict[str, dict[str, Any]] = {}
self._lock = threading.RLock()
self._maxsize = maxsize
self._access_order = []
@@ -29,16 +32,13 @@ class MemoryCache:
lru_key = self._access_order.pop(0)
self._cache.pop(lru_key, None)
- self._cache[key] = {
- 'value': value,
- 'expires_at': self._get_current_time() + ttl
- }
+ self._cache[key] = {'value': value, 'expires_at': self._get_current_time() + ttl}
if key in self._access_order:
self._access_order.remove(key)
self._access_order.append(key)
- def get(self, key: str) -> Optional[str]:
+ def get(self, key: str) -> str | None:
with self._lock:
if key in self._cache:
cache_entry = self._cache[key]
@@ -70,42 +70,45 @@ class MemoryCache:
def _get_current_time(self) -> int:
import time
+
return int(time.time())
- def get_cache_stats(self) -> Dict[str, Any]:
+ def get_cache_stats(self) -> dict[str, Any]:
with self._lock:
current_time = self._get_current_time()
total_entries = len(self._cache)
- expired_entries = sum(1 for entry in self._cache.values()
- if current_time >= entry['expires_at'])
+ expired_entries = sum(
+ 1 for entry in self._cache.values() if current_time >= entry['expires_at']
+ )
active_entries = total_entries - expired_entries
return {
'total_entries': total_entries,
'active_entries': active_entries,
'expired_entries': expired_entries,
'maxsize': self._maxsize,
- 'usage_percent': (total_entries / self._maxsize * 100) if self._maxsize > 0 else 0
+ 'usage_percent': (total_entries / self._maxsize * 100) if self._maxsize > 0 else 0,
}
def _cleanup_expired(self):
current_time = self._get_current_time()
expired_keys = [
- key for key, entry in self._cache.items()
- if current_time >= entry['expires_at']
+ key for key, entry in self._cache.items() if current_time >= entry['expires_at']
]
for key in expired_keys:
del self._cache[key]
if key in self._access_order:
self._access_order.remove(key)
if expired_keys:
- logging.getLogger('doorman.cache').info(f'Cleaned up {len(expired_keys)} expired cache entries')
+ logging.getLogger('doorman.cache').info(
+ f'Cleaned up {len(expired_keys)} expired cache entries'
+ )
def stop_auto_save(self):
return
+
class DoormanCacheManager:
def __init__(self):
-
cache_flag = os.getenv('MEM_OR_EXTERNAL')
if cache_flag is None:
cache_flag = os.getenv('MEM_OR_REDIS', 'MEM')
@@ -124,12 +127,14 @@ class DoormanCacheManager:
port=redis_port,
db=redis_db,
decode_responses=True,
- max_connections=100
+ max_connections=100,
)
self.cache = redis.StrictRedis(connection_pool=pool)
self.is_redis = True
except Exception as e:
- logging.getLogger('doorman.cache').warning(f'Redis connection failed, falling back to memory cache: {e}')
+ logging.getLogger('doorman.cache').warning(
+ f'Redis connection failed, falling back to memory cache: {e}'
+ )
maxsize = int(os.getenv('CACHE_MAX_SIZE', 10000))
self.cache = MemoryCache(maxsize=maxsize)
self.is_redis = False
@@ -150,7 +155,7 @@ class DoormanCacheManager:
'endpoint_server_cache': 'endpoint_server_cache:',
'client_routing_cache': 'client_routing_cache:',
'token_def_cache': 'token_def_cache:',
- 'credit_def_cache': 'credit_def_cache:'
+ 'credit_def_cache': 'credit_def_cache:',
}
self.default_ttls = {
'api_cache': 86400,
@@ -167,7 +172,7 @@ class DoormanCacheManager:
'endpoint_server_cache': 86400,
'client_routing_cache': 86400,
'token_def_cache': 86400,
- 'credit_def_cache': 86400
+ 'credit_def_cache': 86400,
}
def _get_key(self, cache_name, key):
@@ -255,7 +260,7 @@ class DoormanCacheManager:
'type': self.cache_type,
'is_redis': self.is_redis,
'prefixes': list(self.prefixes.keys()),
- 'default_ttl': self.default_ttls
+ 'default_ttl': self.default_ttls,
}
if not self.is_redis and hasattr(self.cache, 'get_cache_stats'):
info['memory_stats'] = self.cache.get_cache_stats()
@@ -267,7 +272,6 @@ class DoormanCacheManager:
self.cache._cleanup_expired()
def force_save_cache(self):
-
return
def stop_cache_persistence(self):
@@ -317,8 +321,9 @@ class DoormanCacheManager:
elif hasattr(result, 'deleted_count') and result.deleted_count > 0:
self.delete_cache(cache_name, key)
return result
- except Exception as e:
+ except Exception:
self.delete_cache(cache_name, key)
raise
+
doorman_cache = DoormanCacheManager()
diff --git a/backend-services/utils/encryption_util.py b/backend-services/utils/encryption_util.py
index a11150e..6d1a781 100644
--- a/backend-services/utils/encryption_util.py
+++ b/backend-services/utils/encryption_util.py
@@ -1,10 +1,11 @@
-from typing import Optional
-import os
import base64
import hashlib
+import os
+
from cryptography.fernet import Fernet
-def _get_cipher() -> Optional[Fernet]:
+
+def _get_cipher() -> Fernet | None:
"""Return a Fernet cipher derived from TOKEN_ENCRYPTION_KEY or MEM_ENCRYPTION_KEY.
If neither is set, returns None (plaintext compatibility mode).
"""
@@ -12,16 +13,15 @@ def _get_cipher() -> Optional[Fernet]:
if not key:
return None
try:
-
Fernet(key)
fkey = key
except Exception:
-
digest = hashlib.sha256(key.encode('utf-8')).digest()
fkey = base64.urlsafe_b64encode(digest)
return Fernet(fkey)
-def encrypt_value(value: Optional[str]) -> Optional[str]:
+
+def encrypt_value(value: str | None) -> str | None:
if value is None:
return None
cipher = _get_cipher()
@@ -30,7 +30,8 @@ def encrypt_value(value: Optional[str]) -> Optional[str]:
token = cipher.encrypt(value.encode('utf-8')).decode('utf-8')
return f'enc:{token}'
-def decrypt_value(value: Optional[str]) -> Optional[str]:
+
+def decrypt_value(value: str | None) -> str | None:
if value is None:
return None
if not isinstance(value, str):
@@ -39,11 +40,9 @@ def decrypt_value(value: Optional[str]) -> Optional[str]:
return value
cipher = _get_cipher()
if not cipher:
-
return None
try:
raw = value[4:]
return cipher.decrypt(raw.encode('utf-8')).decode('utf-8')
except Exception:
return None
-
diff --git a/backend-services/utils/enhanced_metrics_util.py b/backend-services/utils/enhanced_metrics_util.py
index f123ff5..b645de9 100644
--- a/backend-services/utils/enhanced_metrics_util.py
+++ b/backend-services/utils/enhanced_metrics_util.py
@@ -10,31 +10,26 @@ Extends the existing metrics_util.py with:
"""
from __future__ import annotations
-import time
-import os
-from collections import defaultdict, deque
-from typing import Dict, List, Optional, Deque
-from models.analytics_models import (
- EnhancedMinuteBucket,
- PercentileMetrics,
- EndpointMetrics,
- AnalyticsSnapshot
-)
+import os
+import time
+from collections import defaultdict, deque
+
+from models.analytics_models import AnalyticsSnapshot, EnhancedMinuteBucket, PercentileMetrics
from utils.analytics_aggregator import analytics_aggregator
class EnhancedMetricsStore:
"""
Enhanced version of MetricsStore with analytics capabilities.
-
+
Backward compatible with existing metrics_util.py while adding:
- Per-endpoint performance tracking
- Full percentile calculations
- Unique user counting
- Automatic aggregation to 5-min/hourly/daily buckets
"""
-
+
def __init__(self, max_minutes: int = 60 * 24): # 24 hours of minute-level data
# Global counters (backward compatible)
self.total_requests: int = 0
@@ -43,60 +38,60 @@ class EnhancedMetricsStore:
self.total_bytes_out: int = 0
self.total_upstream_timeouts: int = 0
self.total_retries: int = 0
- self.status_counts: Dict[int, int] = defaultdict(int)
- self.username_counts: Dict[str, int] = defaultdict(int)
- self.api_counts: Dict[str, int] = defaultdict(int)
-
+ self.status_counts: dict[int, int] = defaultdict(int)
+ self.username_counts: dict[str, int] = defaultdict(int)
+ self.api_counts: dict[str, int] = defaultdict(int)
+
# Enhanced: Use EnhancedMinuteBucket instead of MinuteBucket
- self._buckets: Deque[EnhancedMinuteBucket] = deque()
+ self._buckets: deque[EnhancedMinuteBucket] = deque()
self._max_minutes = max_minutes
-
+
# Track last aggregation times
self._last_aggregation_check = 0
-
+
@staticmethod
def _minute_floor(ts: float) -> int:
"""Floor timestamp to nearest minute."""
return int(ts // 60) * 60
-
+
def _ensure_bucket(self, minute_start: int) -> EnhancedMinuteBucket:
"""Get or create bucket for the given minute."""
if self._buckets and self._buckets[-1].start_ts == minute_start:
return self._buckets[-1]
-
+
# Create new enhanced bucket
mb = EnhancedMinuteBucket(start_ts=minute_start)
self._buckets.append(mb)
-
+
# Maintain max size
while len(self._buckets) > self._max_minutes:
- old_bucket = self._buckets.popleft()
+ self._buckets.popleft()
# Trigger aggregation for old buckets
self._maybe_aggregate()
-
+
return mb
-
+
def record(
self,
status: int,
duration_ms: float,
- username: Optional[str] = None,
- api_key: Optional[str] = None,
- endpoint_uri: Optional[str] = None,
- method: Optional[str] = None,
+ username: str | None = None,
+ api_key: str | None = None,
+ endpoint_uri: str | None = None,
+ method: str | None = None,
bytes_in: int = 0,
- bytes_out: int = 0
+ bytes_out: int = 0,
) -> None:
"""
Record a request with enhanced tracking.
-
+
Backward compatible with existing record() calls while supporting
new parameters for per-endpoint tracking.
"""
now = time.time()
minute_start = self._minute_floor(now)
bucket = self._ensure_bucket(minute_start)
-
+
# Record with enhanced tracking
bucket.add_request(
ms=duration_ms,
@@ -106,9 +101,9 @@ class EnhancedMetricsStore:
endpoint_uri=endpoint_uri,
method=method,
bytes_in=bytes_in,
- bytes_out=bytes_out
+ bytes_out=bytes_out,
)
-
+
# Update global counters (backward compatible)
self.total_requests += 1
self.total_ms += duration_ms
@@ -117,17 +112,17 @@ class EnhancedMetricsStore:
self.total_bytes_out += int(bytes_out or 0)
except Exception:
pass
-
+
self.status_counts[status] += 1
if username:
self.username_counts[username] += 1
if api_key:
self.api_counts[api_key] += 1
-
+
# Check if aggregation should run
self._maybe_aggregate()
-
- def record_retry(self, api_key: Optional[str] = None) -> None:
+
+ def record_retry(self, api_key: str | None = None) -> None:
"""Record a retry event."""
now = time.time()
minute_start = self._minute_floor(now)
@@ -137,8 +132,8 @@ class EnhancedMetricsStore:
self.total_retries += 1
except Exception:
pass
-
- def record_upstream_timeout(self, api_key: Optional[str] = None) -> None:
+
+ def record_upstream_timeout(self, api_key: str | None = None) -> None:
"""Record an upstream timeout event."""
now = time.time()
minute_start = self._minute_floor(now)
@@ -148,48 +143,45 @@ class EnhancedMetricsStore:
self.total_upstream_timeouts += 1
except Exception:
pass
-
+
def _maybe_aggregate(self) -> None:
"""
Check if aggregation should run and trigger if needed.
-
+
Runs aggregation jobs based on time elapsed:
- 5-minute aggregation: Every 5 minutes
- Hourly aggregation: Every hour
- Daily aggregation: Once per day
"""
now = int(time.time())
-
+
# Only check every minute to avoid overhead
if now - self._last_aggregation_check < 60:
return
-
+
self._last_aggregation_check = now
-
+
# Check what aggregations should run
should_run = analytics_aggregator.should_aggregate()
-
+
if should_run.get('5minute'):
# Get last 5 minutes of buckets
minute_buckets = list(self._buckets)[-5:]
if minute_buckets:
analytics_aggregator.aggregate_to_5minute(minute_buckets)
-
+
if should_run.get('hourly'):
analytics_aggregator.aggregate_to_hourly()
-
+
if should_run.get('daily'):
analytics_aggregator.aggregate_to_daily()
-
+
def get_snapshot(
- self,
- start_ts: int,
- end_ts: int,
- granularity: str = 'auto'
+ self, start_ts: int, end_ts: int, granularity: str = 'auto'
) -> AnalyticsSnapshot:
"""
Get analytics snapshot for a time range.
-
+
Automatically selects best aggregation level based on range.
"""
# Determine which buckets to use
@@ -207,7 +199,7 @@ class EnhancedMetricsStore:
buckets = [b for b in self._buckets if start_ts <= b.start_ts <= end_ts]
else:
buckets = analytics_aggregator.get_buckets_for_range(start_ts, end_ts)
-
+
if not buckets:
# Return empty snapshot
return AnalyticsSnapshot(
@@ -225,98 +217,107 @@ class EnhancedMetricsStore:
top_apis=[],
top_users=[],
top_endpoints=[],
- status_distribution={}
+ status_distribution={},
)
-
+
# Aggregate data from buckets
total_requests = sum(b.count for b in buckets)
total_errors = sum(b.error_count for b in buckets)
total_ms = sum(b.total_ms for b in buckets)
total_bytes_in = sum(b.bytes_in for b in buckets)
total_bytes_out = sum(b.bytes_out for b in buckets)
-
+
# Collect all latencies for percentile calculation
- all_latencies: List[float] = []
+ all_latencies: list[float] = []
for bucket in buckets:
all_latencies.extend(list(bucket.latencies))
-
- percentiles = PercentileMetrics.calculate(all_latencies) if all_latencies else PercentileMetrics()
-
+
+ percentiles = (
+ PercentileMetrics.calculate(all_latencies) if all_latencies else PercentileMetrics()
+ )
+
# Count unique users
unique_users = set()
for bucket in buckets:
unique_users.update(bucket.unique_users)
-
+
# Aggregate status counts
- status_distribution: Dict[str, int] = defaultdict(int)
+ status_distribution: dict[str, int] = defaultdict(int)
for bucket in buckets:
for status, count in bucket.status_counts.items():
status_distribution[str(status)] += count
-
+
# Aggregate API counts
- api_counts: Dict[str, int] = defaultdict(int)
+ api_counts: dict[str, int] = defaultdict(int)
for bucket in buckets:
for api, count in bucket.api_counts.items():
api_counts[api] += count
-
+
# Aggregate user counts
- user_counts: Dict[str, int] = defaultdict(int)
+ user_counts: dict[str, int] = defaultdict(int)
for bucket in buckets:
for user, count in bucket.user_counts.items():
user_counts[user] += count
-
+
# Aggregate endpoint metrics
- endpoint_metrics: Dict[str, Dict] = defaultdict(lambda: {
- 'count': 0,
- 'error_count': 0,
- 'total_ms': 0.0,
- 'latencies': []
- })
+ endpoint_metrics: dict[str, dict] = defaultdict(
+ lambda: {'count': 0, 'error_count': 0, 'total_ms': 0.0, 'latencies': []}
+ )
for bucket in buckets:
for endpoint_key, ep_metrics in bucket.endpoint_metrics.items():
endpoint_metrics[endpoint_key]['count'] += ep_metrics.count
endpoint_metrics[endpoint_key]['error_count'] += ep_metrics.error_count
endpoint_metrics[endpoint_key]['total_ms'] += ep_metrics.total_ms
endpoint_metrics[endpoint_key]['latencies'].extend(list(ep_metrics.latencies))
-
+
# Build top endpoints list
top_endpoints = []
for endpoint_key, metrics in endpoint_metrics.items():
method, uri = endpoint_key.split(':', 1)
avg_ms = metrics['total_ms'] / metrics['count'] if metrics['count'] > 0 else 0.0
- ep_percentiles = PercentileMetrics.calculate(metrics['latencies']) if metrics['latencies'] else PercentileMetrics()
-
- top_endpoints.append({
- 'endpoint_uri': uri,
- 'method': method,
- 'count': metrics['count'],
- 'error_count': metrics['error_count'],
- 'error_rate': metrics['error_count'] / metrics['count'] if metrics['count'] > 0 else 0.0,
- 'avg_ms': avg_ms,
- 'percentiles': ep_percentiles.to_dict()
- })
-
+ ep_percentiles = (
+ PercentileMetrics.calculate(metrics['latencies'])
+ if metrics['latencies']
+ else PercentileMetrics()
+ )
+
+ top_endpoints.append(
+ {
+ 'endpoint_uri': uri,
+ 'method': method,
+ 'count': metrics['count'],
+ 'error_count': metrics['error_count'],
+ 'error_rate': metrics['error_count'] / metrics['count']
+ if metrics['count'] > 0
+ else 0.0,
+ 'avg_ms': avg_ms,
+ 'percentiles': ep_percentiles.to_dict(),
+ }
+ )
+
# Sort by count (most used)
top_endpoints.sort(key=lambda x: x['count'], reverse=True)
-
+
# Build time-series data
series = []
for bucket in buckets:
avg_ms = bucket.total_ms / bucket.count if bucket.count > 0 else 0.0
bucket_percentiles = bucket.get_percentiles()
-
- series.append({
- 'timestamp': bucket.start_ts,
- 'count': bucket.count,
- 'error_count': bucket.error_count,
- 'error_rate': bucket.error_count / bucket.count if bucket.count > 0 else 0.0,
- 'avg_ms': avg_ms,
- 'percentiles': bucket_percentiles.to_dict(),
- 'bytes_in': bucket.bytes_in,
- 'bytes_out': bucket.bytes_out,
- 'unique_users': bucket.get_unique_user_count()
- })
-
+
+ series.append(
+ {
+ 'timestamp': bucket.start_ts,
+ 'count': bucket.count,
+ 'error_count': bucket.error_count,
+ 'error_rate': bucket.error_count / bucket.count if bucket.count > 0 else 0.0,
+ 'avg_ms': avg_ms,
+ 'percentiles': bucket_percentiles.to_dict(),
+ 'bytes_in': bucket.bytes_in,
+ 'bytes_out': bucket.bytes_out,
+ 'unique_users': bucket.get_unique_user_count(),
+ }
+ )
+
# Create snapshot
return AnalyticsSnapshot(
start_ts=start_ts,
@@ -333,32 +334,30 @@ class EnhancedMetricsStore:
top_apis=sorted(api_counts.items(), key=lambda x: x[1], reverse=True)[:10],
top_users=sorted(user_counts.items(), key=lambda x: x[1], reverse=True)[:10],
top_endpoints=top_endpoints[:10],
- status_distribution=dict(status_distribution)
+ status_distribution=dict(status_distribution),
)
-
- def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> Dict:
+
+ def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> dict:
"""
Backward compatible snapshot method for existing monitor endpoints.
"""
- range_to_minutes = {
- '1h': 60,
- '24h': 60 * 24,
- '7d': 60 * 24 * 7,
- '30d': 60 * 24 * 30,
- }
+ range_to_minutes = {'1h': 60, '24h': 60 * 24, '7d': 60 * 24 * 7, '30d': 60 * 24 * 30}
minutes = range_to_minutes.get(range_key, 60 * 24)
- buckets: List[EnhancedMinuteBucket] = list(self._buckets)[-minutes:]
+ buckets: list[EnhancedMinuteBucket] = list(self._buckets)[-minutes:]
series = []
-
+
if group == 'day':
from collections import defaultdict
- day_map: Dict[int, Dict[str, float]] = defaultdict(lambda: {
- 'count': 0,
- 'error_count': 0,
- 'total_ms': 0.0,
- 'bytes_in': 0,
- 'bytes_out': 0,
- })
+
+ day_map: dict[int, dict[str, float]] = defaultdict(
+ lambda: {
+ 'count': 0,
+ 'error_count': 0,
+ 'total_ms': 0.0,
+ 'bytes_in': 0,
+ 'bytes_out': 0,
+ }
+ )
for b in buckets:
day_ts = int((b.start_ts // 86400) * 86400)
d = day_map[day_ts]
@@ -369,38 +368,44 @@ class EnhancedMetricsStore:
d['bytes_out'] += b.bytes_out
for day_ts, d in day_map.items():
avg_ms = (d['total_ms'] / d['count']) if d['count'] else 0.0
- series.append({
- 'timestamp': day_ts,
- 'count': int(d['count']),
- 'error_count': int(d['error_count']),
- 'avg_ms': avg_ms,
- 'bytes_in': int(d['bytes_in']),
- 'bytes_out': int(d['bytes_out']),
- 'error_rate': (int(d['error_count']) / int(d['count'])) if d['count'] else 0.0,
- })
+ series.append(
+ {
+ 'timestamp': day_ts,
+ 'count': int(d['count']),
+ 'error_count': int(d['error_count']),
+ 'avg_ms': avg_ms,
+ 'bytes_in': int(d['bytes_in']),
+ 'bytes_out': int(d['bytes_out']),
+ 'error_rate': (int(d['error_count']) / int(d['count']))
+ if d['count']
+ else 0.0,
+ }
+ )
else:
for b in buckets:
avg_ms = (b.total_ms / b.count) if b.count else 0.0
percentiles = b.get_percentiles()
- series.append({
- 'timestamp': b.start_ts,
- 'count': b.count,
- 'error_count': b.error_count,
- 'avg_ms': avg_ms,
- 'p95_ms': percentiles.p95,
- 'bytes_in': b.bytes_in,
- 'bytes_out': b.bytes_out,
- 'error_rate': (b.error_count / b.count) if b.count else 0.0,
- 'upstream_timeouts': b.upstream_timeouts,
- 'retries': b.retries,
- })
-
- reverse = (str(sort).lower() == 'desc')
+ series.append(
+ {
+ 'timestamp': b.start_ts,
+ 'count': b.count,
+ 'error_count': b.error_count,
+ 'avg_ms': avg_ms,
+ 'p95_ms': percentiles.p95,
+ 'bytes_in': b.bytes_in,
+ 'bytes_out': b.bytes_out,
+ 'error_rate': (b.error_count / b.count) if b.count else 0.0,
+ 'upstream_timeouts': b.upstream_timeouts,
+ 'retries': b.retries,
+ }
+ )
+
+ reverse = str(sort).lower() == 'desc'
try:
series.sort(key=lambda x: x.get('timestamp', 0), reverse=reverse)
except Exception:
pass
-
+
total = self.total_requests
avg_total_ms = (self.total_ms / total) if total else 0.0
status = {str(k): v for k, v in self.status_counts.items()}
@@ -413,33 +418,36 @@ class EnhancedMetricsStore:
'total_retries': self.total_retries,
'status_counts': status,
'series': series,
- 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[:10],
+ 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[
+ :10
+ ],
'top_apis': sorted(self.api_counts.items(), key=lambda kv: kv[1], reverse=True)[:10],
}
-
+
def save_to_file(self, path: str) -> None:
"""Save metrics to file for persistence."""
try:
os.makedirs(os.path.dirname(path), exist_ok=True)
except Exception:
pass
-
+
try:
import json
+
tmp = path + '.tmp'
data = {
'total_requests': self.total_requests,
'total_ms': self.total_ms,
'total_bytes_in': self.total_bytes_in,
'total_bytes_out': self.total_bytes_out,
- 'buckets': [b.to_dict() for b in list(self._buckets)]
+ 'buckets': [b.to_dict() for b in list(self._buckets)],
}
with open(tmp, 'w', encoding='utf-8') as f:
json.dump(data, f)
os.replace(tmp, path)
except Exception:
pass
-
+
def load_from_file(self, path: str) -> None:
"""Load metrics from file."""
# Note: Simplified version - full implementation would reconstruct
diff --git a/backend-services/utils/error_codes.py b/backend-services/utils/error_codes.py
index 5a40ecf..8c4e0c0 100644
--- a/backend-services/utils/error_codes.py
+++ b/backend-services/utils/error_codes.py
@@ -22,6 +22,7 @@ Usage:
).dict()
"""
+
class ErrorCode:
"""
Centralized error code constants.
@@ -217,7 +218,9 @@ class ErrorCode:
RATE_LIMIT_EXCEEDED = 'RATE_LIMIT_EXCEEDED'
+
# Alias for backward compatibility
class ErrorCodes(ErrorCode):
"""Deprecated: Use ErrorCode instead."""
+
pass
diff --git a/backend-services/utils/error_util.py b/backend-services/utils/error_util.py
index d84efb0..c7bfb96 100644
--- a/backend-services/utils/error_util.py
+++ b/backend-services/utils/error_util.py
@@ -2,18 +2,20 @@
Standardized error response utilities
"""
+from typing import Any
+
from models.response_model import ResponseModel
from utils.response_util import process_response
-from typing import Optional, Dict, Any
+
def create_error_response(
status_code: int,
error_code: str,
error_message: str,
- request_id: Optional[str] = None,
- data: Optional[Dict[str, Any]] = None,
- api_type: str = 'rest'
-) -> Dict[str, Any]:
+ request_id: str | None = None,
+ data: dict[str, Any] | None = None,
+ api_type: str = 'rest',
+) -> dict[str, Any]:
"""
Create a standardized error response using ResponseModel.
"""
@@ -25,17 +27,18 @@ def create_error_response(
response_headers=response_headers,
error_code=error_code,
error_message=error_message,
- response=data
+ response=data,
)
return process_response(response_model.dict(), api_type)
+
def success_response(
status_code: int = 200,
- message: Optional[str] = None,
- data: Optional[Dict[str, Any]] = None,
- request_id: Optional[str] = None,
- api_type: str = 'rest'
-) -> Dict[str, Any]:
+ message: str | None = None,
+ data: dict[str, Any] | None = None,
+ request_id: str | None = None,
+ api_type: str = 'rest',
+) -> dict[str, Any]:
"""
Create a standardized success response using ResponseModel.
"""
@@ -43,9 +46,6 @@ def success_response(
if request_id:
response_headers['request_id'] = request_id
response_model = ResponseModel(
- status_code=status_code,
- response_headers=response_headers,
- message=message,
- response=data
+ status_code=status_code, response_headers=response_headers, message=message, response=data
)
return process_response(response_model.dict(), api_type)
diff --git a/backend-services/utils/gateway_utils.py b/backend-services/utils/gateway_utils.py
index 7be23ce..92c28ac 100644
--- a/backend-services/utils/gateway_utils.py
+++ b/backend-services/utils/gateway_utils.py
@@ -1,7 +1,7 @@
-import re
-from typing import Dict, List
-from fastapi import Request
import logging
+import re
+
+from fastapi import Request
_logger = logging.getLogger('doorman.gateway')
@@ -17,6 +17,7 @@ SENSITIVE_HEADERS = {
'csrf-token',
}
+
def sanitize_headers(value: str):
"""Sanitize header values to prevent injection attacks.
@@ -37,6 +38,7 @@ def sanitize_headers(value: str):
except Exception:
return ''
+
def redact_sensitive_header(header_name: str, header_value: str) -> str:
"""Redact sensitive header values for logging purposes.
@@ -69,7 +71,8 @@ def redact_sensitive_header(header_name: str, header_value: str) -> str:
except Exception:
return '[REDACTION_ERROR]'
-def log_headers_safely(request: Request, allowed_headers: List[str] = None, redact: bool = True):
+
+def log_headers_safely(request: Request, allowed_headers: list[str] = None, redact: bool = True):
"""Log request headers safely with redaction.
Args:
@@ -100,12 +103,13 @@ def log_headers_safely(request: Request, allowed_headers: List[str] = None, reda
headers_to_log[key] = sanitized
if headers_to_log:
- _logger.debug(f"Request headers: {headers_to_log}")
+ _logger.debug(f'Request headers: {headers_to_log}')
except Exception as e:
- _logger.debug(f"Failed to log headers safely: {e}")
+ _logger.debug(f'Failed to log headers safely: {e}')
-async def get_headers(request: Request, allowed_headers: List[str]):
+
+async def get_headers(request: Request, allowed_headers: list[str]):
"""Extract and sanitize allowed headers from request.
This function is used for forwarding headers to upstream services.
diff --git a/backend-services/utils/geo_lookup.py b/backend-services/utils/geo_lookup.py
index 96ed6f9..6967c0c 100644
--- a/backend-services/utils/geo_lookup.py
+++ b/backend-services/utils/geo_lookup.py
@@ -6,9 +6,7 @@ Uses MaxMind GeoLite2 database (free) or can be extended to use commercial servi
"""
import logging
-from typing import Optional, Dict, List, Set
from dataclasses import dataclass
-from datetime import datetime
from utils.redis_client import RedisClient, get_redis_client
@@ -18,45 +16,46 @@ logger = logging.getLogger(__name__)
@dataclass
class GeoLocation:
"""Geographic location information"""
+
ip: str
- country_code: Optional[str] = None
- country_name: Optional[str] = None
- region: Optional[str] = None
- city: Optional[str] = None
- latitude: Optional[float] = None
- longitude: Optional[float] = None
- timezone: Optional[str] = None
+ country_code: str | None = None
+ country_name: str | None = None
+ region: str | None = None
+ city: str | None = None
+ latitude: float | None = None
+ longitude: float | None = None
+ timezone: str | None = None
class GeoLookup:
"""
Geographic IP lookup and country-based rate limiting
-
+
Note: This is a simplified implementation. In production, integrate with:
- MaxMind GeoLite2 (free): https://dev.maxmind.com/geoip/geolite2-free-geolocation-data
- MaxMind GeoIP2 (paid): More accurate
- IP2Location: Alternative service
"""
-
- def __init__(self, redis_client: Optional[RedisClient] = None):
+
+ def __init__(self, redis_client: RedisClient | None = None):
"""Initialize geo lookup"""
self.redis = redis_client or get_redis_client()
-
+
# Cache TTL for geo lookups (24 hours)
self.cache_ttl = 86400
-
+
def lookup_ip(self, ip: str) -> GeoLocation:
"""
Lookup geographic information for IP
-
+
In production, this would use MaxMind GeoIP2 or similar service.
For now, returns cached data or placeholder.
"""
try:
# Check cache first
- cache_key = f"geo:cache:{ip}"
+ cache_key = f'geo:cache:{ip}'
cached = self.redis.hgetall(cache_key)
-
+
if cached:
return GeoLocation(
ip=ip,
@@ -66,9 +65,9 @@ class GeoLookup:
city=cached.get('city'),
latitude=float(cached['latitude']) if cached.get('latitude') else None,
longitude=float(cached['longitude']) if cached.get('longitude') else None,
- timezone=cached.get('timezone')
+ timezone=cached.get('timezone'),
)
-
+
# TODO: Integrate with MaxMind GeoIP2
# Example integration:
# import geoip2.database
@@ -79,23 +78,23 @@ class GeoLookup:
# city = response.city.name
# latitude = response.location.latitude
# longitude = response.location.longitude
-
+
# For now, return placeholder with unknown location
geo = GeoLocation(ip=ip, country_code='UNKNOWN')
-
+
# Cache the result
self._cache_geo_data(ip, geo)
-
+
return geo
-
+
except Exception as e:
- logger.error(f"Error looking up IP geolocation: {e}")
+ logger.error(f'Error looking up IP geolocation: {e}')
return GeoLocation(ip=ip)
-
+
def _cache_geo_data(self, ip: str, geo: GeoLocation) -> None:
"""Cache geo data in Redis"""
try:
- cache_key = f"geo:cache:{ip}"
+ cache_key = f'geo:cache:{ip}'
data = {
'country_code': geo.country_code or '',
'country_name': geo.country_name or '',
@@ -103,197 +102,197 @@ class GeoLookup:
'city': geo.city or '',
'latitude': str(geo.latitude) if geo.latitude else '',
'longitude': str(geo.longitude) if geo.longitude else '',
- 'timezone': geo.timezone or ''
+ 'timezone': geo.timezone or '',
}
self.redis.hmset(cache_key, data)
self.redis.expire(cache_key, self.cache_ttl)
except Exception as e:
- logger.error(f"Error caching geo data: {e}")
-
+ logger.error(f'Error caching geo data: {e}')
+
def is_country_blocked(self, country_code: str) -> bool:
"""Check if country is blocked"""
try:
return bool(self.redis.sismember('geo:blocked_countries', country_code))
except Exception as e:
- logger.error(f"Error checking blocked country: {e}")
+ logger.error(f'Error checking blocked country: {e}')
return False
-
+
def is_country_allowed(self, country_code: str) -> bool:
"""
Check if country is in allowlist
-
+
If allowlist is empty, all countries are allowed.
If allowlist has entries, only those countries are allowed.
"""
try:
# Get allowlist
allowed = self.redis.smembers('geo:allowed_countries')
-
+
# If no allowlist, all countries allowed
if not allowed:
return True
-
+
# Check if country in allowlist
return country_code in allowed
except Exception as e:
- logger.error(f"Error checking allowed country: {e}")
+ logger.error(f'Error checking allowed country: {e}')
return True
-
+
def block_country(self, country_code: str) -> bool:
"""Add country to blocklist"""
try:
self.redis.sadd('geo:blocked_countries', country_code)
- logger.info(f"Blocked country: {country_code}")
+ logger.info(f'Blocked country: {country_code}')
return True
except Exception as e:
- logger.error(f"Error blocking country: {e}")
+ logger.error(f'Error blocking country: {e}')
return False
-
+
def unblock_country(self, country_code: str) -> bool:
"""Remove country from blocklist"""
try:
self.redis.srem('geo:blocked_countries', country_code)
- logger.info(f"Unblocked country: {country_code}")
+ logger.info(f'Unblocked country: {country_code}')
return True
except Exception as e:
- logger.error(f"Error unblocking country: {e}")
+ logger.error(f'Error unblocking country: {e}')
return False
-
+
def add_to_allowlist(self, country_code: str) -> bool:
"""Add country to allowlist"""
try:
self.redis.sadd('geo:allowed_countries', country_code)
- logger.info(f"Added country to allowlist: {country_code}")
+ logger.info(f'Added country to allowlist: {country_code}')
return True
except Exception as e:
- logger.error(f"Error adding to allowlist: {e}")
+ logger.error(f'Error adding to allowlist: {e}')
return False
-
+
def remove_from_allowlist(self, country_code: str) -> bool:
"""Remove country from allowlist"""
try:
self.redis.srem('geo:allowed_countries', country_code)
- logger.info(f"Removed country from allowlist: {country_code}")
+ logger.info(f'Removed country from allowlist: {country_code}')
return True
except Exception as e:
- logger.error(f"Error removing from allowlist: {e}")
+ logger.error(f'Error removing from allowlist: {e}')
return False
-
- def get_country_rate_limit(self, country_code: str) -> Optional[int]:
+
+ def get_country_rate_limit(self, country_code: str) -> int | None:
"""
Get custom rate limit for country
-
+
Returns None if no custom limit set.
"""
try:
- limit_key = f"geo:rate_limit:{country_code}"
+ limit_key = f'geo:rate_limit:{country_code}'
limit = self.redis.get(limit_key)
return int(limit) if limit else None
except Exception as e:
- logger.error(f"Error getting country rate limit: {e}")
+ logger.error(f'Error getting country rate limit: {e}')
return None
-
+
def set_country_rate_limit(self, country_code: str, limit: int) -> bool:
"""Set custom rate limit for country"""
try:
- limit_key = f"geo:rate_limit:{country_code}"
+ limit_key = f'geo:rate_limit:{country_code}'
self.redis.set(limit_key, str(limit))
- logger.info(f"Set rate limit for {country_code}: {limit}")
+ logger.info(f'Set rate limit for {country_code}: {limit}')
return True
except Exception as e:
- logger.error(f"Error setting country rate limit: {e}")
+ logger.error(f'Error setting country rate limit: {e}')
return False
-
- def check_geographic_access(self, ip: str) -> tuple[bool, Optional[str]]:
+
+ def check_geographic_access(self, ip: str) -> tuple[bool, str | None]:
"""
Check if IP's geographic location is allowed
-
+
Returns:
(allowed, reason) tuple
"""
try:
geo = self.lookup_ip(ip)
-
+
if not geo.country_code or geo.country_code == 'UNKNOWN':
# Allow unknown locations by default
return True, None
-
+
# Check blocklist
if self.is_country_blocked(geo.country_code):
- return False, f"Country {geo.country_code} is blocked"
-
+ return False, f'Country {geo.country_code} is blocked'
+
# Check allowlist
if not self.is_country_allowed(geo.country_code):
- return False, f"Country {geo.country_code} is not in allowlist"
-
+ return False, f'Country {geo.country_code} is not in allowlist'
+
return True, None
-
+
except Exception as e:
- logger.error(f"Error checking geographic access: {e}")
+ logger.error(f'Error checking geographic access: {e}')
# Allow by default on error
return True, None
-
+
def track_country_request(self, country_code: str) -> None:
"""Track request from country for analytics"""
try:
- counter_key = f"geo:requests:{country_code}"
+ counter_key = f'geo:requests:{country_code}'
self.redis.incr(counter_key)
self.redis.expire(counter_key, 86400) # 24 hour window
except Exception as e:
- logger.error(f"Error tracking country request: {e}")
-
- def get_geographic_distribution(self) -> List[tuple]:
+ logger.error(f'Error tracking country request: {e}')
+
+ def get_geographic_distribution(self) -> list[tuple]:
"""
Get request distribution by country
-
+
Returns:
List of (country_code, request_count) tuples
"""
try:
- pattern = "geo:requests:*"
+ pattern = 'geo:requests:*'
keys = []
-
+
cursor = 0
while True:
cursor, batch = self.redis.scan(cursor, match=pattern, count=100)
keys.extend(batch)
if cursor == 0:
break
-
+
# Get counts for each country
country_counts = []
for key in keys:
country_code = key.replace('geo:requests:', '')
count = int(self.redis.get(key) or 0)
country_counts.append((country_code, count))
-
+
# Sort by count
country_counts.sort(key=lambda x: x[1], reverse=True)
return country_counts
-
+
except Exception as e:
- logger.error(f"Error getting geographic distribution: {e}")
+ logger.error(f'Error getting geographic distribution: {e}')
return []
-
- def get_blocked_countries(self) -> Set[str]:
+
+ def get_blocked_countries(self) -> set[str]:
"""Get list of blocked countries"""
try:
return self.redis.smembers('geo:blocked_countries')
except Exception as e:
- logger.error(f"Error getting blocked countries: {e}")
+ logger.error(f'Error getting blocked countries: {e}')
return set()
-
- def get_allowed_countries(self) -> Set[str]:
+
+ def get_allowed_countries(self) -> set[str]:
"""Get list of allowed countries"""
try:
return self.redis.smembers('geo:allowed_countries')
except Exception as e:
- logger.error(f"Error getting allowed countries: {e}")
+ logger.error(f'Error getting allowed countries: {e}')
return set()
# Global instance
-_geo_lookup: Optional[GeoLookup] = None
+_geo_lookup: GeoLookup | None = None
def get_geo_lookup() -> GeoLookup:
diff --git a/backend-services/utils/group_util.py b/backend-services/utils/group_util.py
index 66e70bd..602da96 100644
--- a/backend-services/utils/group_util.py
+++ b/backend-services/utils/group_util.py
@@ -5,17 +5,19 @@ See https://github.com/pypeople-dev/doorman for more information
"""
import logging
+
from fastapi import HTTPException, Request
-from utils.doorman_cache_util import doorman_cache
from services.user_service import UserService
-from utils.database_async import api_collection
from utils.async_db import db_find_one
from utils.auth_util import auth_required
+from utils.database_async import api_collection
+from utils.doorman_cache_util import doorman_cache
logger = logging.getLogger('doorman.gateway')
-async def group_required(request: Request = None, full_path: str = None, user_to_subscribe = None):
+
+async def group_required(request: Request = None, full_path: str = None, user_to_subscribe=None):
try:
payload = await auth_required(request)
username = payload.get('sub')
@@ -37,7 +39,7 @@ async def group_required(request: Request = None, full_path: str = None, user_to
prefix = '/api/grpc/'
if request:
postfix = '/' + request.headers.get('X-API-Version', 'v0')
- path = full_path[len(prefix):] if full_path.startswith(prefix) else full_path
+ path = full_path[len(prefix) :] if full_path.startswith(prefix) else full_path
api_and_version = '/'.join(path.split('/')[:2]) + postfix
if not api_and_version or '/' not in api_and_version:
raise HTTPException(status_code=404, detail='Invalid API path format')
@@ -46,11 +48,15 @@ async def group_required(request: Request = None, full_path: str = None, user_to
user = await UserService.get_user_by_username_helper(user_to_subscribe)
else:
user = await UserService.get_user_by_username_helper(username)
- api = doorman_cache.get_cache('api_cache', api_and_version) or await db_find_one(api_collection, {'api_name': api_name, 'api_version': api_version})
+ api = doorman_cache.get_cache('api_cache', api_and_version) or await db_find_one(
+ api_collection, {'api_name': api_name, 'api_version': api_version}
+ )
if not api:
raise HTTPException(status_code=404, detail='API not found')
if not set(user.get('groups') or []).intersection(api.get('api_allowed_groups') or []):
- raise HTTPException(status_code=401, detail='You do not have the correct group for this')
+ raise HTTPException(
+ status_code=401, detail='You do not have the correct group for this'
+ )
except HTTPException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
return payload
diff --git a/backend-services/utils/health_check_util.py b/backend-services/utils/health_check_util.py
index f314a16..4a104e1 100644
--- a/backend-services/utils/health_check_util.py
+++ b/backend-services/utils/health_check_util.py
@@ -4,23 +4,24 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-import psutil
-import time
import logging
-from datetime import timedelta
-from redis.asyncio import Redis
import os
+import time
+from datetime import timedelta
-from utils.database import mongodb_client, database
+import psutil
+from redis.asyncio import Redis
+
+from utils.database import database, mongodb_client
from utils.doorman_cache_util import doorman_cache
logger = logging.getLogger('doorman.gateway')
START_TIME = time.time()
+
async def check_mongodb():
try:
-
if database.memory_only:
return True
if mongodb_client is None:
@@ -31,14 +32,14 @@ async def check_mongodb():
logger.error(f'MongoDB health check failed: {str(e)}')
return False
+
async def check_redis():
try:
-
if not getattr(doorman_cache, 'is_redis', False):
return True
redis = Redis.from_url(
f'redis://{os.getenv("REDIS_HOST")}:{os.getenv("REDIS_PORT")}/{os.getenv("REDIS_DB")}',
- decode_responses=True
+ decode_responses=True,
)
await redis.ping()
@@ -47,6 +48,7 @@ async def check_redis():
logger.error(f'Redis health check failed: {str(e)}')
return False
+
def get_memory_usage():
try:
process = psutil.Process(os.getpid())
@@ -58,6 +60,7 @@ def get_memory_usage():
logger.error(f'Memory usage check failed: {str(e)}')
return 'unknown'
+
def get_active_connections():
try:
process = psutil.Process(os.getpid())
@@ -67,6 +70,7 @@ def get_active_connections():
logger.error(f'Active connections check failed: {str(e)}')
return 0
+
def get_uptime():
try:
uptime_seconds = time.time() - START_TIME
diff --git a/backend-services/utils/hot_reload_config.py b/backend-services/utils/hot_reload_config.py
index d8c7ba1..c43bb82 100644
--- a/backend-services/utils/hot_reload_config.py
+++ b/backend-services/utils/hot_reload_config.py
@@ -18,16 +18,19 @@ Usage:
hot_config.reload()
"""
-import os
import json
-import yaml
import logging
+import os
import threading
-from typing import Any, Dict, Callable, Optional
+from collections.abc import Callable
from pathlib import Path
+from typing import Any
+
+import yaml
logger = logging.getLogger('doorman.gateway')
+
class HotReloadConfig:
"""
Thread-safe configuration manager with hot reload support.
@@ -39,10 +42,10 @@ class HotReloadConfig:
- Callbacks for configuration changes
"""
- def __init__(self, config_file: Optional[str] = None):
+ def __init__(self, config_file: str | None = None):
self._lock = threading.RLock()
- self._config: Dict[str, Any] = {}
- self._callbacks: Dict[str, list] = {}
+ self._config: dict[str, Any] = {}
+ self._callbacks: dict[str, list] = {}
self._config_file = config_file or os.getenv('DOORMAN_CONFIG_FILE')
self._load_initial_config()
@@ -64,7 +67,7 @@ class HotReloadConfig:
"""Load configuration from YAML or JSON file"""
path = Path(filepath)
- with open(filepath, 'r') as f:
+ with open(filepath) as f:
if path.suffix in ['.yaml', '.yml']:
file_config = yaml.safe_load(f) or {}
elif path.suffix == '.json':
@@ -80,29 +83,22 @@ class HotReloadConfig:
'LOG_LEVEL',
'LOG_FORMAT',
'LOG_FILE',
-
'GATEWAY_TIMEOUT',
'UPSTREAM_TIMEOUT',
'CONNECTION_TIMEOUT',
-
'RATE_LIMIT_ENABLED',
'RATE_LIMIT_REQUESTS',
'RATE_LIMIT_WINDOW',
-
'CACHE_TTL',
'CACHE_MAX_SIZE',
-
'CIRCUIT_BREAKER_ENABLED',
'CIRCUIT_BREAKER_THRESHOLD',
'CIRCUIT_BREAKER_TIMEOUT',
-
'RETRY_ENABLED',
'RETRY_MAX_ATTEMPTS',
'RETRY_BACKOFF',
-
'METRICS_ENABLED',
'METRICS_INTERVAL',
-
'FEATURE_REQUEST_REPLAY',
'FEATURE_AB_TESTING',
'FEATURE_COST_ANALYTICS',
@@ -249,7 +245,7 @@ class HotReloadConfig:
logger.info('Configuration reload complete')
- def dump(self) -> Dict[str, Any]:
+ def dump(self) -> dict[str, Any]:
"""Dump current configuration (for debugging)"""
with self._lock:
config = self._config.copy()
@@ -259,10 +255,12 @@ class HotReloadConfig:
config[key] = self._parse_value(env_value)
return config
+
hot_config = HotReloadConfig()
+
# Convenience functions for common config patterns
-def get_timeout_config() -> Dict[str, int]:
+def get_timeout_config() -> dict[str, int]:
"""Get all timeout configurations"""
return {
'gateway_timeout': hot_config.get_int('GATEWAY_TIMEOUT', 30),
@@ -270,7 +268,8 @@ def get_timeout_config() -> Dict[str, int]:
'connection_timeout': hot_config.get_int('CONNECTION_TIMEOUT', 10),
}
-def get_rate_limit_config() -> Dict[str, Any]:
+
+def get_rate_limit_config() -> dict[str, Any]:
"""Get rate limiting configuration"""
return {
'enabled': hot_config.get_bool('RATE_LIMIT_ENABLED', True),
@@ -278,14 +277,16 @@ def get_rate_limit_config() -> Dict[str, Any]:
'window': hot_config.get_int('RATE_LIMIT_WINDOW', 60),
}
-def get_cache_config() -> Dict[str, Any]:
+
+def get_cache_config() -> dict[str, Any]:
"""Get cache configuration"""
return {
'ttl': hot_config.get_int('CACHE_TTL', 300),
'max_size': hot_config.get_int('CACHE_MAX_SIZE', 1000),
}
-def get_circuit_breaker_config() -> Dict[str, Any]:
+
+def get_circuit_breaker_config() -> dict[str, Any]:
"""Get circuit breaker configuration"""
return {
'enabled': hot_config.get_bool('CIRCUIT_BREAKER_ENABLED', True),
@@ -293,7 +294,8 @@ def get_circuit_breaker_config() -> Dict[str, Any]:
'timeout': hot_config.get_int('CIRCUIT_BREAKER_TIMEOUT', 60),
}
-def get_retry_config() -> Dict[str, Any]:
+
+def get_retry_config() -> dict[str, Any]:
"""Get retry configuration"""
return {
'enabled': hot_config.get_bool('RETRY_ENABLED', True),
diff --git a/backend-services/utils/http_client.py b/backend-services/utils/http_client.py
index f3985d3..1d1919f 100644
--- a/backend-services/utils/http_client.py
+++ b/backend-services/utils/http_client.py
@@ -15,30 +15,34 @@ Usage:
from __future__ import annotations
import asyncio
+import logging
import os
import random
import time
from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any
import httpx
-import logging
+
from utils.metrics_util import metrics_store
logger = logging.getLogger('doorman.gateway')
+
class CircuitOpenError(Exception):
pass
+
@dataclass
class _BreakerState:
failures: int = 0
opened_at: float = 0.0
state: str = 'closed'
+
class _CircuitManager:
def __init__(self) -> None:
- self._states: Dict[str, _BreakerState] = {}
+ self._states: dict[str, _BreakerState] = {}
def get(self, key: str) -> _BreakerState:
st = self._states.get(key)
@@ -75,9 +79,11 @@ class _CircuitManager:
st.state = 'open'
st.opened_at = self.now()
+
circuit_manager = _CircuitManager()
-def _build_timeout(api_config: Optional[dict]) -> httpx.Timeout:
+
+def _build_timeout(api_config: dict | None) -> httpx.Timeout:
# Per-API overrides if present on document; otherwise env defaults
def _f(key: str, env_key: str, default: float) -> float:
try:
@@ -93,28 +99,31 @@ def _build_timeout(api_config: Optional[dict]) -> httpx.Timeout:
pool = _f('api_pool_timeout', 'HTTP_TIMEOUT', 30.0)
return httpx.Timeout(connect=connect, read=read, write=write, pool=pool)
+
def _should_retry_status(status: int) -> bool:
return status in (500, 502, 503, 504)
+
def _backoff_delay(attempt: int) -> float:
base = float(os.getenv('HTTP_RETRY_BASE_DELAY', 0.25))
cap = float(os.getenv('HTTP_RETRY_MAX_DELAY', 2.0))
delay = min(cap, base * (2 ** max(0, attempt - 1)))
return random.uniform(0, delay)
+
async def request_with_resilience(
client: httpx.AsyncClient,
method: str,
url: str,
*,
api_key: str,
- headers: Optional[Dict[str, str]] = None,
- params: Optional[Dict[str, Any]] = None,
+ headers: dict[str, str] | None = None,
+ params: dict[str, Any] | None = None,
data: Any = None,
json: Any = None,
content: Any = None,
retries: int = 0,
- api_config: Optional[dict] = None,
+ api_config: dict | None = None,
) -> httpx.Response:
"""Perform an HTTP request with retries, backoff, and circuit breaker.
@@ -132,8 +141,8 @@ async def request_with_resilience(
if enabled:
circuit_manager.check(api_key, open_seconds)
- last_exc: Optional[BaseException] = None
- response: Optional[httpx.Response] = None
+ last_exc: BaseException | None = None
+ response: httpx.Response | None = None
for attempt in range(1, attempts + 1):
if attempt > 1:
try:
@@ -143,13 +152,18 @@ async def request_with_resilience(
await asyncio.sleep(_backoff_delay(attempt))
try:
try:
- requester = getattr(client, 'request')
+ requester = client.request
except Exception:
requester = None
if requester is not None:
response = await requester(
- method.upper(), url,
- headers=headers, params=params, data=data, json=json, content=content,
+ method.upper(),
+ url,
+ headers=headers,
+ params=params,
+ data=data,
+ json=json,
+ content=content,
timeout=timeout,
)
else:
@@ -165,10 +179,7 @@ async def request_with_resilience(
kwargs['json'] = json
elif data is not None:
kwargs['json'] = data
- response = await meth(
- url,
- **kwargs,
- )
+ response = await meth(url, **kwargs)
if _should_retry_status(response.status_code) and attempt < attempts:
if enabled:
diff --git a/backend-services/utils/ip_policy_util.py b/backend-services/utils/ip_policy_util.py
index 404b222..e64bf50 100644
--- a/backend-services/utils/ip_policy_util.py
+++ b/backend-services/utils/ip_policy_util.py
@@ -1,13 +1,14 @@
from __future__ import annotations
-from fastapi import Request, HTTPException
-from typing import Optional, List
-
-from utils.security_settings_util import get_cached_settings
-from utils.audit_util import audit
import os
-def _get_client_ip(request: Request, trust_xff: bool) -> Optional[str]:
+from fastapi import HTTPException, Request
+
+from utils.audit_util import audit
+from utils.security_settings_util import get_cached_settings
+
+
+def _get_client_ip(request: Request, trust_xff: bool) -> str | None:
"""Determine client IP with optional proxy trust.
When `trust_xff` is True, this prefers headers supplied by trusted proxies.
@@ -28,7 +29,14 @@ def _get_client_ip(request: Request, trust_xff: bool) -> Optional[str]:
return _ip_in_list(src_ip, trusted) if src_ip else False
if trust_xff and _from_trusted_proxy():
- for header in ('x-forwarded-for', 'X-Forwarded-For', 'x-real-ip', 'X-Real-IP', 'cf-connecting-ip', 'CF-Connecting-IP'):
+ for header in (
+ 'x-forwarded-for',
+ 'X-Forwarded-For',
+ 'x-real-ip',
+ 'X-Real-IP',
+ 'cf-connecting-ip',
+ 'CF-Connecting-IP',
+ ):
val = request.headers.get(header)
if val:
ip = val.split(',')[0].strip()
@@ -38,11 +46,13 @@ def _get_client_ip(request: Request, trust_xff: bool) -> Optional[str]:
except Exception:
return request.client.host if request.client else None
-def _ip_in_list(ip: str, patterns: List[str]) -> bool:
+
+def _ip_in_list(ip: str, patterns: list[str]) -> bool:
try:
import ipaddress
+
ip_obj = ipaddress.ip_address(ip)
- for pat in (patterns or []):
+ for pat in patterns or []:
p = (pat or '').strip()
if not p:
continue
@@ -60,17 +70,20 @@ def _ip_in_list(ip: str, patterns: List[str]) -> bool:
except Exception:
return False
-def _is_loopback(ip: Optional[str]) -> bool:
+
+def _is_loopback(ip: str | None) -> bool:
try:
if not ip:
return False
if ip in ('testserver', 'localhost'):
return True
import ipaddress
+
return ipaddress.ip_address(ip).is_loopback
except Exception:
return False
+
def enforce_api_ip_policy(request: Request, api: dict):
"""
Enforce per-API IP policy.
@@ -81,16 +94,36 @@ def enforce_api_ip_policy(request: Request, api: dict):
"""
try:
settings = get_cached_settings()
- trust_xff = bool(api.get('api_trust_x_forwarded_for')) if api.get('api_trust_x_forwarded_for') is not None else bool(settings.get('trust_x_forwarded_for'))
+ trust_xff = (
+ bool(api.get('api_trust_x_forwarded_for'))
+ if api.get('api_trust_x_forwarded_for') is not None
+ else bool(settings.get('trust_x_forwarded_for'))
+ )
client_ip = _get_client_ip(request, trust_xff)
if not client_ip:
return
try:
settings = get_cached_settings()
env_flag = os.getenv('LOCAL_HOST_IP_BYPASS')
- allow_local = (env_flag.lower() == 'true') if isinstance(env_flag, str) and env_flag.strip() != '' else bool(settings.get('allow_localhost_bypass'))
+ allow_local = (
+ (env_flag.lower() == 'true')
+ if isinstance(env_flag, str) and env_flag.strip() != ''
+ else bool(settings.get('allow_localhost_bypass'))
+ )
direct_ip = getattr(getattr(request, 'client', None), 'host', None)
- has_forward = any(request.headers.get(h) for h in ('x-forwarded-for','X-Forwarded-For','x-real-ip','X-Real-IP','cf-connecting-ip','CF-Connecting-IP','forwarded','Forwarded'))
+ has_forward = any(
+ request.headers.get(h)
+ for h in (
+ 'x-forwarded-for',
+ 'X-Forwarded-For',
+ 'x-real-ip',
+ 'X-Real-IP',
+ 'cf-connecting-ip',
+ 'CF-Connecting-IP',
+ 'forwarded',
+ 'Forwarded',
+ )
+ )
if allow_local and direct_ip and _is_loopback(direct_ip) and not has_forward:
return
except Exception:
@@ -100,14 +133,28 @@ def enforce_api_ip_policy(request: Request, api: dict):
bl = api.get('api_ip_blacklist') or []
if bl and _ip_in_list(client_ip, bl):
try:
- audit(request, actor=None, action='ip.api_deny', target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'), status='blocked', details={'reason': 'blacklisted', 'effective_ip': client_ip})
+ audit(
+ request,
+ actor=None,
+ action='ip.api_deny',
+ target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'),
+ status='blocked',
+ details={'reason': 'blacklisted', 'effective_ip': client_ip},
+ )
except Exception:
pass
raise HTTPException(status_code=403, detail='API011')
if mode == 'whitelist':
if not wl or not _ip_in_list(client_ip, wl):
try:
- audit(request, actor=None, action='ip.api_deny', target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'), status='blocked', details={'reason': 'not_in_whitelist', 'effective_ip': client_ip})
+ audit(
+ request,
+ actor=None,
+ action='ip.api_deny',
+ target=str(api.get('api_id') or api.get('api_name') or 'unknown_api'),
+ status='blocked',
+ details={'reason': 'not_in_whitelist', 'effective_ip': client_ip},
+ )
except Exception:
pass
raise HTTPException(status_code=403, detail='API010')
diff --git a/backend-services/utils/ip_rate_limiter.py b/backend-services/utils/ip_rate_limiter.py
index 06ebe7c..c21afae 100644
--- a/backend-services/utils/ip_rate_limiter.py
+++ b/backend-services/utils/ip_rate_limiter.py
@@ -6,14 +6,14 @@ and IP reputation scoring.
"""
import logging
-from typing import Optional, List, Set
-from datetime import datetime, timedelta
+import time
from dataclasses import dataclass
+from datetime import datetime
+
from fastapi import Request
-import time
-from utils.redis_client import RedisClient, get_redis_client
from utils.rate_limiter import RateLimiter, RateLimitResult
+from utils.redis_client import RedisClient, get_redis_client
logger = logging.getLogger(__name__)
@@ -21,14 +21,15 @@ logger = logging.getLogger(__name__)
@dataclass
class IPInfo:
"""Information about an IP address"""
+
ip: str
is_whitelisted: bool = False
is_blacklisted: bool = False
reputation_score: int = 100 # 0-100, lower is worse
request_count: int = 0
- last_seen: Optional[datetime] = None
- countries: Set[str] = None
-
+ last_seen: datetime | None = None
+ countries: set[str] = None
+
def __post_init__(self):
if self.countries is None:
self.countries = set()
@@ -38,24 +39,24 @@ class IPRateLimiter:
"""
IP-based rate limiting with whitelist/blacklist and reputation scoring
"""
-
- def __init__(self, redis_client: Optional[RedisClient] = None):
+
+ def __init__(self, redis_client: RedisClient | None = None):
"""Initialize IP rate limiter"""
self.redis = redis_client or get_redis_client()
self.rate_limiter = RateLimiter(redis_client)
-
+
# Default IP rate limits (can be overridden per IP)
self.default_ip_limit_per_minute = 60
self.default_ip_limit_per_hour = 1000
-
+
# Reputation thresholds
self.suspicious_threshold = 50 # Below this is suspicious
self.ban_threshold = 20 # Below this gets banned
-
+
def extract_client_ip(self, request: Request) -> str:
"""
Extract client IP from request, handling proxy headers
-
+
Priority order:
1. X-Forwarded-For (first IP in chain)
2. X-Real-IP
@@ -66,239 +67,229 @@ class IPRateLimiter:
if forwarded_for:
# Take the first IP in the chain (original client)
return forwarded_for.split(',')[0].strip()
-
+
# Check X-Real-IP
real_ip = request.headers.get('X-Real-IP')
if real_ip:
return real_ip.strip()
-
+
# Fallback to direct connection IP
if request.client and request.client.host:
return request.client.host
-
+
return 'unknown'
-
+
def is_whitelisted(self, ip: str) -> bool:
"""Check if IP is whitelisted"""
try:
return bool(self.redis.sismember('ip:whitelist', ip))
except Exception as e:
- logger.error(f"Error checking whitelist: {e}")
+ logger.error(f'Error checking whitelist: {e}')
return False
-
+
def is_blacklisted(self, ip: str) -> bool:
"""Check if IP is blacklisted"""
try:
return bool(self.redis.sismember('ip:blacklist', ip))
except Exception as e:
- logger.error(f"Error checking blacklist: {e}")
+ logger.error(f'Error checking blacklist: {e}')
return False
-
+
def add_to_whitelist(self, ip: str) -> bool:
"""Add IP to whitelist"""
try:
self.redis.sadd('ip:whitelist', ip)
- logger.info(f"Added IP to whitelist: {ip}")
+ logger.info(f'Added IP to whitelist: {ip}')
return True
except Exception as e:
- logger.error(f"Error adding to whitelist: {e}")
+ logger.error(f'Error adding to whitelist: {e}')
return False
-
- def add_to_blacklist(self, ip: str, duration_seconds: Optional[int] = None) -> bool:
+
+ def add_to_blacklist(self, ip: str, duration_seconds: int | None = None) -> bool:
"""
Add IP to blacklist
-
+
Args:
ip: IP address to blacklist
duration_seconds: Optional duration for temporary ban
"""
try:
self.redis.sadd('ip:blacklist', ip)
-
+
if duration_seconds:
# Set expiration for temporary ban
- ban_key = f"ip:ban:{ip}"
+ ban_key = f'ip:ban:{ip}'
self.redis.setex(ban_key, duration_seconds, '1')
- logger.info(f"Temporarily banned IP for {duration_seconds}s: {ip}")
+ logger.info(f'Temporarily banned IP for {duration_seconds}s: {ip}')
else:
- logger.info(f"Permanently blacklisted IP: {ip}")
-
+ logger.info(f'Permanently blacklisted IP: {ip}')
+
return True
except Exception as e:
- logger.error(f"Error adding to blacklist: {e}")
+ logger.error(f'Error adding to blacklist: {e}')
return False
-
+
def remove_from_whitelist(self, ip: str) -> bool:
"""Remove IP from whitelist"""
try:
self.redis.srem('ip:whitelist', ip)
- logger.info(f"Removed IP from whitelist: {ip}")
+ logger.info(f'Removed IP from whitelist: {ip}')
return True
except Exception as e:
- logger.error(f"Error removing from whitelist: {e}")
+ logger.error(f'Error removing from whitelist: {e}')
return False
-
+
def remove_from_blacklist(self, ip: str) -> bool:
"""Remove IP from blacklist"""
try:
self.redis.srem('ip:blacklist', ip)
- ban_key = f"ip:ban:{ip}"
+ ban_key = f'ip:ban:{ip}'
self.redis.delete(ban_key)
- logger.info(f"Removed IP from blacklist: {ip}")
+ logger.info(f'Removed IP from blacklist: {ip}')
return True
except Exception as e:
- logger.error(f"Error removing from blacklist: {e}")
+ logger.error(f'Error removing from blacklist: {e}')
return False
-
+
def get_reputation_score(self, ip: str) -> int:
"""
Get reputation score for IP (0-100)
-
+
Lower scores indicate worse reputation.
"""
try:
- score_key = f"ip:reputation:{ip}"
+ score_key = f'ip:reputation:{ip}'
score = self.redis.get(score_key)
return int(score) if score else 100
except Exception as e:
- logger.error(f"Error getting reputation score: {e}")
+ logger.error(f'Error getting reputation score: {e}')
return 100
-
+
def update_reputation_score(self, ip: str, delta: int) -> int:
"""
Update reputation score for IP
-
+
Args:
ip: IP address
delta: Change in score (positive or negative)
-
+
Returns:
New reputation score
"""
try:
- score_key = f"ip:reputation:{ip}"
+ score_key = f'ip:reputation:{ip}'
current_score = self.get_reputation_score(ip)
new_score = max(0, min(100, current_score + delta))
-
+
self.redis.setex(score_key, 86400 * 7, str(new_score)) # 7 day TTL
-
+
# Auto-ban if score too low
if new_score <= self.ban_threshold:
self.add_to_blacklist(ip, duration_seconds=3600) # 1 hour ban
- logger.warning(f"Auto-banned IP due to low reputation: {ip} (score: {new_score})")
-
+ logger.warning(f'Auto-banned IP due to low reputation: {ip} (score: {new_score})')
+
return new_score
except Exception as e:
- logger.error(f"Error updating reputation score: {e}")
+ logger.error(f'Error updating reputation score: {e}')
return 100
-
+
def track_request(self, ip: str) -> None:
"""Track request from IP for analytics"""
try:
# Increment request counter
- counter_key = f"ip:requests:{ip}"
+ counter_key = f'ip:requests:{ip}'
self.redis.incr(counter_key)
self.redis.expire(counter_key, 86400) # 24 hour window
-
+
# Update last seen
- last_seen_key = f"ip:last_seen:{ip}"
+ last_seen_key = f'ip:last_seen:{ip}'
self.redis.set(last_seen_key, datetime.now().isoformat())
self.redis.expire(last_seen_key, 86400 * 7) # 7 days
-
+
except Exception as e:
- logger.error(f"Error tracking request: {e}")
-
+ logger.error(f'Error tracking request: {e}')
+
def get_ip_info(self, ip: str) -> IPInfo:
"""Get comprehensive information about an IP"""
try:
- request_count_key = f"ip:requests:{ip}"
+ request_count_key = f'ip:requests:{ip}'
request_count = int(self.redis.get(request_count_key) or 0)
-
- last_seen_key = f"ip:last_seen:{ip}"
+
+ last_seen_key = f'ip:last_seen:{ip}'
last_seen_str = self.redis.get(last_seen_key)
last_seen = datetime.fromisoformat(last_seen_str) if last_seen_str else None
-
+
return IPInfo(
ip=ip,
is_whitelisted=self.is_whitelisted(ip),
is_blacklisted=self.is_blacklisted(ip),
reputation_score=self.get_reputation_score(ip),
request_count=request_count,
- last_seen=last_seen
+ last_seen=last_seen,
)
except Exception as e:
- logger.error(f"Error getting IP info: {e}")
+ logger.error(f'Error getting IP info: {e}')
return IPInfo(ip=ip)
-
+
def check_ip_rate_limit(
- self,
- ip: str,
- limit_per_minute: Optional[int] = None,
- limit_per_hour: Optional[int] = None
+ self, ip: str, limit_per_minute: int | None = None, limit_per_hour: int | None = None
) -> RateLimitResult:
"""
Check rate limit for specific IP
-
+
Args:
ip: IP address
limit_per_minute: Override default per-minute limit
limit_per_hour: Override default per-hour limit
-
+
Returns:
RateLimitResult indicating if request is allowed
"""
# Check whitelist (always allow)
if self.is_whitelisted(ip):
- return RateLimitResult(allowed=True, limit=999999, remaining=999999, reset_at=int(time.time()) + 60)
-
+ return RateLimitResult(
+ allowed=True, limit=999999, remaining=999999, reset_at=int(time.time()) + 60
+ )
+
# Check blacklist (always deny)
if self.is_blacklisted(ip):
return RateLimitResult(
- allowed=False,
- limit=0,
- remaining=0,
- reset_at=int(time.time()) + 60
+ allowed=False, limit=0, remaining=0, reset_at=int(time.time()) + 60
)
-
+
# Check reputation score
reputation = self.get_reputation_score(ip)
if reputation < self.suspicious_threshold:
# Reduce limits for suspicious IPs
limit_per_minute = int((limit_per_minute or self.default_ip_limit_per_minute) * 0.5)
limit_per_hour = int((limit_per_hour or self.default_ip_limit_per_hour) * 0.5)
- logger.warning(f"Reduced limits for suspicious IP: {ip} (reputation: {reputation})")
-
+ logger.warning(f'Reduced limits for suspicious IP: {ip} (reputation: {reputation})')
+
# Use defaults if not specified
limit_per_minute = limit_per_minute or self.default_ip_limit_per_minute
limit_per_hour = limit_per_hour or self.default_ip_limit_per_hour
-
+
# Check per-minute limit
- minute_key = f"ip:limit:minute:{ip}"
+ minute_key = f'ip:limit:minute:{ip}'
minute_count = int(self.redis.get(minute_key) or 0)
-
+
if minute_count >= limit_per_minute:
# Decrease reputation for rate limit violations
self.update_reputation_score(ip, -5)
return RateLimitResult(
- allowed=False,
- limit=limit_per_minute,
- remaining=0,
- reset_at=int(time.time()) + 60
+ allowed=False, limit=limit_per_minute, remaining=0, reset_at=int(time.time()) + 60
)
-
+
# Check per-hour limit
- hour_key = f"ip:limit:hour:{ip}"
+ hour_key = f'ip:limit:hour:{ip}'
hour_count = int(self.redis.get(hour_key) or 0)
-
+
if hour_count >= limit_per_hour:
self.update_reputation_score(ip, -5)
return RateLimitResult(
- allowed=False,
- limit=limit_per_hour,
- remaining=0,
- reset_at=int(time.time()) + 3600
+ allowed=False, limit=limit_per_hour, remaining=0, reset_at=int(time.time()) + 3600
)
-
+
# Increment counters
pipe = self.redis.pipeline()
pipe.incr(minute_key)
@@ -306,83 +297,83 @@ class IPRateLimiter:
pipe.incr(hour_key)
pipe.expire(hour_key, 3600)
pipe.execute()
-
+
# Track request
self.track_request(ip)
-
+
return RateLimitResult(
allowed=True,
limit=limit_per_minute,
remaining=limit_per_minute - minute_count - 1,
- reset_at=int(time.time()) + 60
+ reset_at=int(time.time()) + 60,
)
-
- def get_top_ips(self, limit: int = 10) -> List[tuple]:
+
+ def get_top_ips(self, limit: int = 10) -> list[tuple]:
"""
Get top IPs by request volume
-
+
Returns:
List of (ip, request_count) tuples
"""
try:
# Scan for all IP request counters
- pattern = "ip:requests:*"
+ pattern = 'ip:requests:*'
keys = []
-
+
cursor = 0
while True:
cursor, batch = self.redis.scan(cursor, match=pattern, count=100)
keys.extend(batch)
if cursor == 0:
break
-
+
# Get counts for each IP
ip_counts = []
for key in keys[:1000]: # Limit to prevent overwhelming
ip = key.replace('ip:requests:', '')
count = int(self.redis.get(key) or 0)
ip_counts.append((ip, count))
-
+
# Sort by count and return top N
ip_counts.sort(key=lambda x: x[1], reverse=True)
return ip_counts[:limit]
-
+
except Exception as e:
- logger.error(f"Error getting top IPs: {e}")
+ logger.error(f'Error getting top IPs: {e}')
return []
-
+
def detect_suspicious_activity(self, ip: str) -> bool:
"""
Detect suspicious activity patterns
-
+
Returns:
True if suspicious activity detected
"""
try:
# Check request rate
- minute_key = f"ip:limit:minute:{ip}"
+ minute_key = f'ip:limit:minute:{ip}'
minute_count = int(self.redis.get(minute_key) or 0)
-
+
# Suspicious if > 80% of limit
if minute_count > self.default_ip_limit_per_minute * 0.8:
- logger.warning(f"Suspicious activity detected for IP: {ip} (high request rate)")
+ logger.warning(f'Suspicious activity detected for IP: {ip} (high request rate)')
self.update_reputation_score(ip, -10)
return True
-
+
# Check reputation
reputation = self.get_reputation_score(ip)
if reputation < self.suspicious_threshold:
return True
-
+
return False
-
+
except Exception as e:
- logger.error(f"Error detecting suspicious activity: {e}")
+ logger.error(f'Error detecting suspicious activity: {e}')
return False
# Global instance
-_ip_rate_limiter: Optional[IPRateLimiter] = None
+_ip_rate_limiter: IPRateLimiter | None = None
def get_ip_rate_limiter() -> IPRateLimiter:
diff --git a/backend-services/utils/limit_throttle_util.py b/backend-services/utils/limit_throttle_util.py
index 6a77a87..b0202fa 100644
--- a/backend-services/utils/limit_throttle_util.py
+++ b/backend-services/utils/limit_throttle_util.py
@@ -1,18 +1,19 @@
-from fastapi import Request, HTTPException
import asyncio
-import time
import logging
import os
+import time
+from fastapi import HTTPException, Request
+
+from utils.async_db import db_find_one
from utils.auth_util import auth_required
from utils.database_async import user_collection
-from utils.async_db import db_find_one
-import asyncio
from utils.doorman_cache_util import doorman_cache
from utils.ip_policy_util import _get_client_ip
logger = logging.getLogger('doorman.gateway')
+
class InMemoryWindowCounter:
"""Simple in-memory counter with TTL semantics to mimic required Redis ops.
@@ -36,6 +37,7 @@ class InMemoryWindowCounter:
See: doorman.py app_lifespan() for multi-worker validation
"""
+
def __init__(self):
self._store = {}
@@ -45,7 +47,6 @@ class InMemoryWindowCounter:
if entry and entry['expires_at'] > now:
entry['count'] += 1
else:
-
entry = {'count': 1, 'expires_at': now + 1}
self._store[key] = entry
return entry['count']
@@ -57,8 +58,10 @@ class InMemoryWindowCounter:
entry['expires_at'] = now + int(ttl_seconds)
self._store[key] = entry
+
_fallback_counter = InMemoryWindowCounter()
+
def duration_to_seconds(duration: str) -> int:
mapping = {
'second': 1,
@@ -67,7 +70,7 @@ def duration_to_seconds(duration: str) -> int:
'day': 86400,
'week': 604800,
'month': 2592000,
- 'year': 31536000
+ 'year': 31536000,
}
if not duration:
return 60
@@ -75,13 +78,14 @@ def duration_to_seconds(duration: str) -> int:
duration = duration[:-1]
return mapping.get(duration.lower(), 60)
+
async def limit_and_throttle(request: Request):
"""Enforce user-level rate limiting and throttling.
-
+
**Rate Limiting Hierarchy:**
1. Tier-based limits (checked by TierRateLimitMiddleware first)
2. User-specific overrides (checked here)
-
+
This function provides user-specific rate/throttle settings that override
or supplement tier-based limits. The TierRateLimitMiddleware runs first
and enforces tier limits, then this function applies user-specific rules.
@@ -163,7 +167,9 @@ async def limit_and_throttle(request: Request):
await _fallback_counter.expire(throttle_key, throttle_window)
try:
if os.getenv('DOORMAN_TEST_MODE', 'false').lower() == 'true':
- logger.info(f'[throttle] key={throttle_key} count={throttle_count} qlimit={int(user.get("throttle_queue_limit") or 10)} window={throttle_window}s')
+ logger.info(
+ f'[throttle] key={throttle_key} count={throttle_count} qlimit={int(user.get("throttle_queue_limit") or 10)} window={throttle_window}s'
+ )
except Exception:
pass
throttle_queue_limit = int(user.get('throttle_queue_limit') or 10)
@@ -181,13 +187,18 @@ async def limit_and_throttle(request: Request):
throttle_wait *= duration_to_seconds(throttle_wait_duration)
dynamic_wait = throttle_wait * (throttle_count - throttle_limit)
try:
- import sys as _sys, os as _os
- if _os.getenv('DOORMAN_TEST_MODE', 'false').lower() == 'true' and _sys.version_info >= (3, 13):
+ import os as _os
+ import sys as _sys
+
+ if _os.getenv(
+ 'DOORMAN_TEST_MODE', 'false'
+ ).lower() == 'true' and _sys.version_info >= (3, 13):
dynamic_wait = max(dynamic_wait, 0.2)
except Exception:
pass
await asyncio.sleep(dynamic_wait)
+
def reset_counters():
"""Reset in-memory rate/throttle counters (used by tests and cache clears).
Has no effect when a real Redis client is configured.
@@ -197,6 +208,7 @@ def reset_counters():
except Exception:
pass
+
async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
"""IP-based rate limiting for endpoints that don't require authentication.
@@ -232,12 +244,7 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
try:
if os.getenv('LOGIN_IP_RATE_DISABLED', 'false').lower() == 'true':
now = int(time.time())
- return {
- 'limit': limit,
- 'remaining': limit,
- 'reset': now + window,
- 'window': window
- }
+ return {'limit': limit, 'remaining': limit, 'reset': now + window, 'window': window}
client_ip = _get_client_ip(request, trust_xff=True)
if not client_ip:
logger.warning('Unable to determine client IP for rate limiting, allowing request')
@@ -245,7 +252,7 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
'limit': limit,
'remaining': limit,
'reset': int(time.time()) + window,
- 'window': window
+ 'window': window,
}
now = int(time.time())
@@ -273,7 +280,7 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
'limit': limit,
'remaining': remaining,
'reset': reset_time,
- 'window': window
+ 'window': window,
}
if count > limit:
@@ -283,14 +290,14 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
detail={
'error_code': 'IP_RATE_LIMIT',
'message': f'Too many requests from your IP address. Please wait {retry_after} seconds before trying again. Limit: {limit} requests per {window} seconds.',
- 'retry_after': retry_after
+ 'retry_after': retry_after,
},
headers={
'Retry-After': str(retry_after),
'X-RateLimit-Limit': str(limit),
'X-RateLimit-Remaining': '0',
- 'X-RateLimit-Reset': str(reset_time)
- }
+ 'X-RateLimit-Reset': str(reset_time),
+ },
)
if count > (limit * 0.8):
@@ -306,5 +313,5 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
'limit': limit,
'remaining': limit,
'reset': int(time.time()) + window,
- 'window': window
+ 'window': window,
}
diff --git a/backend-services/utils/memory_dump_util.py b/backend-services/utils/memory_dump_util.py
index c870285..0f33731 100644
--- a/backend-services/utils/memory_dump_util.py
+++ b/backend-services/utils/memory_dump_util.py
@@ -2,30 +2,28 @@
Utilities to dump and restore in-memory database state with encryption.
"""
-import os
-from pathlib import Path
-import json
import base64
-from typing import Optional, Any
-from datetime import datetime, timezone
-from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+import json
+import os
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import Any
+
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from .database import database
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DUMP_PATH = os.getenv('MEM_DUMP_PATH', str(_PROJECT_ROOT / 'generated' / 'memory_dump.bin'))
+
def _derive_key(key_material: str, salt: bytes) -> bytes:
- hkdf = HKDF(
- algorithm=hashes.SHA256(),
- length=32,
- salt=salt,
- info=b'doorman-mem-dump-v1',
- )
+ hkdf = HKDF(algorithm=hashes.SHA256(), length=32, salt=salt, info=b'doorman-mem-dump-v1')
return hkdf.derive(key_material.encode('utf-8'))
+
def _encrypt_blob(plaintext: bytes, key_str: str) -> bytes:
if not key_str or len(key_str) < 8:
raise ValueError('MEM_ENCRYPTION_KEY must be set and at least 8 characters')
@@ -37,6 +35,7 @@ def _encrypt_blob(plaintext: bytes, key_str: str) -> bytes:
return b'DMP1' + salt + nonce + ct
+
def _decrypt_blob(blob: bytes, key_str: str) -> bytes:
if not key_str or len(key_str) < 8:
raise ValueError('MEM_ENCRYPTION_KEY must be set and at least 8 characters')
@@ -51,7 +50,8 @@ def _decrypt_blob(blob: bytes, key_str: str) -> bytes:
aesgcm = AESGCM(key)
return aesgcm.decrypt(nonce, ct, None)
-def _split_dir_and_stem(path_hint: Optional[str]) -> tuple[str, str]:
+
+def _split_dir_and_stem(path_hint: str | None) -> tuple[str, str]:
"""Return (directory, stem) for naming timestamped dump files.
- If hint is a directory (or endswith '/'), use it and default stem 'memory_dump'.
@@ -73,8 +73,10 @@ def _split_dir_and_stem(path_hint: Optional[str]) -> tuple[str, str]:
stem = stem or 'memory_dump'
return dump_dir, stem
+
BYTES_KEY_PREFIX = '__byteskey__:'
+
def _to_jsonable(obj: Any) -> Any:
"""Recursively convert arbitrary objects to JSON-serializable structures.
@@ -90,7 +92,6 @@ def _to_jsonable(obj: Any) -> Any:
if isinstance(obj, dict):
out = {}
for k, v in obj.items():
-
if isinstance(k, bytes):
sk = BYTES_KEY_PREFIX + base64.b64encode(k).decode('ascii')
elif isinstance(k, (str, int, float, bool)) or k is None:
@@ -112,15 +113,16 @@ def _to_jsonable(obj: Any) -> Any:
except Exception:
return None
+
def _json_default(o: Any) -> Any:
if isinstance(o, bytes):
return {'__type__': 'bytes', 'data': base64.b64encode(o).decode('ascii')}
try:
-
return str(o)
except Exception:
return None
+
def _from_jsonable(obj: Any) -> Any:
"""Inverse of _to_jsonable for the specific encodings we apply."""
if isinstance(obj, dict):
@@ -133,7 +135,7 @@ def _from_jsonable(obj: Any) -> Any:
for k, v in obj.items():
rk: Any = k
if isinstance(k, str) and k.startswith(BYTES_KEY_PREFIX):
- b64 = k[len(BYTES_KEY_PREFIX):]
+ b64 = k[len(BYTES_KEY_PREFIX) :]
try:
rk = base64.b64decode(b64)
except Exception:
@@ -144,23 +146,41 @@ def _from_jsonable(obj: Any) -> Any:
return [_from_jsonable(v) for v in obj]
return obj
+
def _sanitize_for_dump(data: Any) -> Any:
"""
Remove sensitive data before dumping to prevent secret exposure.
"""
SENSITIVE_KEYS = {
- 'password', 'secret', 'token', 'key', 'api_key',
- 'access_token', 'refresh_token', 'jwt', 'jwt_secret',
- 'csrf_token', 'session', 'cookie',
- 'credential', 'auth', 'authorization',
- 'ssn', 'credit_card', 'cvv', 'private_key',
- 'encryption_key', 'signing_key'
+ 'password',
+ 'secret',
+ 'token',
+ 'key',
+ 'api_key',
+ 'access_token',
+ 'refresh_token',
+ 'jwt',
+ 'jwt_secret',
+ 'csrf_token',
+ 'session',
+ 'cookie',
+ 'credential',
+ 'auth',
+ 'authorization',
+ 'ssn',
+ 'credit_card',
+ 'cvv',
+ 'private_key',
+ 'encryption_key',
+ 'signing_key',
}
+
def should_redact(key: str) -> bool:
if not isinstance(key, str):
return False
key_lower = key.lower()
return any(s in key_lower for s in SENSITIVE_KEYS)
+
def redact_value(obj: Any) -> Any:
if isinstance(obj, dict):
return {
@@ -175,20 +195,22 @@ def _sanitize_for_dump(data: Any) -> Any:
if cleaned.isalnum():
return '[REDACTED-TOKEN]'
return obj
+
return redact_value(data)
-def dump_memory_to_file(path: Optional[str] = None) -> str:
+
+def dump_memory_to_file(path: str | None = None) -> str:
if not database.memory_only:
raise RuntimeError('Memory dump is only available in memory-only mode')
dump_dir, stem = _split_dir_and_stem(path)
os.makedirs(dump_dir, exist_ok=True)
- ts = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')
+ ts = datetime.now(UTC).strftime('%Y%m%dT%H%M%SZ')
dump_path = os.path.join(dump_dir, f'{stem}-{ts}.bin')
raw_data = database.db.dump_data()
sanitized_data = _sanitize_for_dump(_to_jsonable(raw_data))
payload = {
'version': 1,
- 'created_at': datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'),
+ 'created_at': datetime.now(UTC).isoformat().replace('+00:00', 'Z'),
'sanitized': True,
'note': 'Sensitive fields (passwords, tokens, secrets) have been redacted',
'data': sanitized_data,
@@ -200,7 +222,8 @@ def dump_memory_to_file(path: Optional[str] = None) -> str:
f.write(blob)
return dump_path
-def restore_memory_from_file(path: Optional[str] = None) -> dict:
+
+def restore_memory_from_file(path: str | None = None) -> dict:
if not database.memory_only:
raise RuntimeError('Memory restore is only available in memory-only mode')
dump_path = path or DEFAULT_DUMP_PATH
@@ -215,34 +238,41 @@ def restore_memory_from_file(path: Optional[str] = None) -> dict:
data = _from_jsonable(payload.get('data', {}))
database.db.load_data(data)
try:
- from utils.database import user_collection
- from utils import password_util as _pw
import os as _os
+
+ from utils import password_util as _pw
+ from utils.database import user_collection
+
admin = user_collection.find_one({'username': 'admin'})
if admin is not None and not isinstance(admin.get('password'), (bytes, bytearray)):
pwd = _os.getenv('DOORMAN_ADMIN_PASSWORD')
if not pwd:
raise RuntimeError('DOORMAN_ADMIN_PASSWORD must be set in environment')
- user_collection.update_one({'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}})
+ user_collection.update_one(
+ {'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}}
+ )
except Exception:
pass
return {'version': payload.get('version', 1), 'created_at': payload.get('created_at')}
-def find_latest_dump_path(path_hint: Optional[str] = None) -> Optional[str]:
+
+def find_latest_dump_path(path_hint: str | None = None) -> str | None:
"""Return the most recent dump file path based on a hint.
- If `path_hint` is a file and exists, return it.
- If `path_hint` is a directory, search for .bin files and pick the newest.
- If no hint or not found, try DEFAULT_DUMP_PATH, or its directory for .bin files.
"""
- def newest_bin_in_dir(d: str, stem: Optional[str] = None) -> Optional[str]:
+
+ def newest_bin_in_dir(d: str, stem: str | None = None) -> str | None:
try:
if not os.path.isdir(d):
return None
candidates = [
os.path.join(d, f)
for f in os.listdir(d)
- if f.lower().endswith('.bin') and os.path.isfile(os.path.join(d, f))
+ if f.lower().endswith('.bin')
+ and os.path.isfile(os.path.join(d, f))
and (stem is None or f.startswith(stem + '-'))
]
if not candidates:
diff --git a/backend-services/utils/metrics_util.py b/backend-services/utils/metrics_util.py
index ad109bb..5af0b96 100644
--- a/backend-services/utils/metrics_util.py
+++ b/backend-services/utils/metrics_util.py
@@ -4,12 +4,13 @@ Records count, status code distribution, and response time stats, with per-minut
"""
from __future__ import annotations
+
+import json
+import os
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
-from typing import Deque, Dict, List, Optional
-import json
-import os
+
@dataclass
class MinuteBucket:
@@ -22,13 +23,21 @@ class MinuteBucket:
upstream_timeouts: int = 0
retries: int = 0
- status_counts: Dict[int, int] = field(default_factory=dict)
- api_counts: Dict[str, int] = field(default_factory=dict)
- api_error_counts: Dict[str, int] = field(default_factory=dict)
- user_counts: Dict[str, int] = field(default_factory=dict)
- latencies: Deque[float] = field(default_factory=deque)
+ status_counts: dict[int, int] = field(default_factory=dict)
+ api_counts: dict[str, int] = field(default_factory=dict)
+ api_error_counts: dict[str, int] = field(default_factory=dict)
+ user_counts: dict[str, int] = field(default_factory=dict)
+ latencies: deque[float] = field(default_factory=deque)
- def add(self, ms: float, status: int, username: Optional[str], api_key: Optional[str], bytes_in: int = 0, bytes_out: int = 0) -> None:
+ def add(
+ self,
+ ms: float,
+ status: int,
+ username: str | None,
+ api_key: str | None,
+ bytes_in: int = 0,
+ bytes_out: int = 0,
+ ) -> None:
self.count += 1
if status >= 400:
self.error_count += 1
@@ -68,7 +77,7 @@ class MinuteBucket:
except Exception:
pass
- def to_dict(self) -> Dict:
+ def to_dict(self) -> dict:
return {
'start_ts': self.start_ts,
'count': self.count,
@@ -85,7 +94,7 @@ class MinuteBucket:
}
@staticmethod
- def from_dict(d: Dict) -> 'MinuteBucket':
+ def from_dict(d: dict) -> MinuteBucket:
mb = MinuteBucket(
start_ts=int(d.get('start_ts', 0)),
count=int(d.get('count', 0)),
@@ -105,6 +114,7 @@ class MinuteBucket:
pass
return mb
+
class MetricsStore:
def __init__(self, max_minutes: int = 60 * 24 * 30):
self.total_requests: int = 0
@@ -113,10 +123,10 @@ class MetricsStore:
self.total_bytes_out: int = 0
self.total_upstream_timeouts: int = 0
self.total_retries: int = 0
- self.status_counts: Dict[int, int] = defaultdict(int)
- self.username_counts: Dict[str, int] = defaultdict(int)
- self.api_counts: Dict[str, int] = defaultdict(int)
- self._buckets: Deque[MinuteBucket] = deque()
+ self.status_counts: dict[int, int] = defaultdict(int)
+ self.username_counts: dict[str, int] = defaultdict(int)
+ self.api_counts: dict[str, int] = defaultdict(int)
+ self._buckets: deque[MinuteBucket] = deque()
self._max_minutes = max_minutes
@staticmethod
@@ -134,7 +144,15 @@ class MetricsStore:
self._buckets.popleft()
return mb
- def record(self, status: int, duration_ms: float, username: Optional[str] = None, api_key: Optional[str] = None, bytes_in: int = 0, bytes_out: int = 0) -> None:
+ def record(
+ self,
+ status: int,
+ duration_ms: float,
+ username: str | None = None,
+ api_key: str | None = None,
+ bytes_in: int = 0,
+ bytes_out: int = 0,
+ ) -> None:
now = time.time()
minute_start = self._minute_floor(now)
bucket = self._ensure_bucket(minute_start)
@@ -152,7 +170,7 @@ class MetricsStore:
if api_key:
self.api_counts[api_key] += 1
- def record_retry(self, api_key: Optional[str] = None) -> None:
+ def record_retry(self, api_key: str | None = None) -> None:
now = time.time()
minute_start = self._minute_floor(now)
bucket = self._ensure_bucket(minute_start)
@@ -162,7 +180,7 @@ class MetricsStore:
except Exception:
pass
- def record_upstream_timeout(self, api_key: Optional[str] = None) -> None:
+ def record_upstream_timeout(self, api_key: str | None = None) -> None:
now = time.time()
minute_start = self._minute_floor(now)
bucket = self._ensure_bucket(minute_start)
@@ -172,27 +190,24 @@ class MetricsStore:
except Exception:
pass
- def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> Dict:
-
- range_to_minutes = {
- '1h': 60,
- '24h': 60 * 24,
- '7d': 60 * 24 * 7,
- '30d': 60 * 24 * 30,
- }
+ def snapshot(self, range_key: str, group: str = 'minute', sort: str = 'asc') -> dict:
+ range_to_minutes = {'1h': 60, '24h': 60 * 24, '7d': 60 * 24 * 7, '30d': 60 * 24 * 30}
minutes = range_to_minutes.get(range_key, 60 * 24)
- buckets: List[MinuteBucket] = list(self._buckets)[-minutes:]
+ buckets: list[MinuteBucket] = list(self._buckets)[-minutes:]
series = []
if group == 'day':
from collections import defaultdict
- day_map: Dict[int, Dict[str, float]] = defaultdict(lambda: {
- 'count': 0,
- 'error_count': 0,
- 'total_ms': 0.0,
- 'bytes_in': 0,
- 'bytes_out': 0,
- })
+
+ day_map: dict[int, dict[str, float]] = defaultdict(
+ lambda: {
+ 'count': 0,
+ 'error_count': 0,
+ 'total_ms': 0.0,
+ 'bytes_in': 0,
+ 'bytes_out': 0,
+ }
+ )
for b in buckets:
day_ts = int((b.start_ts // 86400) * 86400)
d = day_map[day_ts]
@@ -203,15 +218,19 @@ class MetricsStore:
d['bytes_out'] += b.bytes_out
for day_ts, d in day_map.items():
avg_ms = (d['total_ms'] / d['count']) if d['count'] else 0.0
- series.append({
- 'timestamp': day_ts,
- 'count': int(d['count']),
- 'error_count': int(d['error_count']),
- 'avg_ms': avg_ms,
- 'bytes_in': int(d['bytes_in']),
- 'bytes_out': int(d['bytes_out']),
- 'error_rate': (int(d['error_count']) / int(d['count'])) if d['count'] else 0.0,
- })
+ series.append(
+ {
+ 'timestamp': day_ts,
+ 'count': int(d['count']),
+ 'error_count': int(d['error_count']),
+ 'avg_ms': avg_ms,
+ 'bytes_in': int(d['bytes_in']),
+ 'bytes_out': int(d['bytes_out']),
+ 'error_rate': (int(d['error_count']) / int(d['count']))
+ if d['count']
+ else 0.0,
+ }
+ )
else:
for b in buckets:
avg_ms = (b.total_ms / b.count) if b.count else 0.0
@@ -224,20 +243,22 @@ class MetricsStore:
p95 = float(arr[k])
except Exception:
p95 = 0.0
- series.append({
- 'timestamp': b.start_ts,
- 'count': b.count,
- 'error_count': b.error_count,
- 'avg_ms': avg_ms,
- 'p95_ms': p95,
- 'bytes_in': b.bytes_in,
- 'bytes_out': b.bytes_out,
- 'error_rate': (b.error_count / b.count) if b.count else 0.0,
- 'upstream_timeouts': b.upstream_timeouts,
- 'retries': b.retries,
- })
+ series.append(
+ {
+ 'timestamp': b.start_ts,
+ 'count': b.count,
+ 'error_count': b.error_count,
+ 'avg_ms': avg_ms,
+ 'p95_ms': p95,
+ 'bytes_in': b.bytes_in,
+ 'bytes_out': b.bytes_out,
+ 'error_rate': (b.error_count / b.count) if b.count else 0.0,
+ 'upstream_timeouts': b.upstream_timeouts,
+ 'retries': b.retries,
+ }
+ )
- reverse = (str(sort).lower() == 'desc')
+ reverse = str(sort).lower() == 'desc'
try:
series.sort(key=lambda x: x.get('timestamp', 0), reverse=reverse)
except Exception:
@@ -255,11 +276,13 @@ class MetricsStore:
'total_retries': self.total_retries,
'status_counts': status,
'series': series,
- 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[:10],
+ 'top_users': sorted(self.username_counts.items(), key=lambda kv: kv[1], reverse=True)[
+ :10
+ ],
'top_apis': sorted(self.api_counts.items(), key=lambda kv: kv[1], reverse=True)[:10],
}
- def to_dict(self) -> Dict:
+ def to_dict(self) -> dict:
return {
'total_requests': int(self.total_requests),
'total_ms': float(self.total_ms),
@@ -271,7 +294,7 @@ class MetricsStore:
'buckets': [b.to_dict() for b in list(self._buckets)],
}
- def load_dict(self, data: Dict) -> None:
+ def load_dict(self, data: dict) -> None:
try:
self.total_requests = int(data.get('total_requests', 0))
self.total_ms = float(data.get('total_ms', 0.0))
@@ -308,11 +331,12 @@ class MetricsStore:
try:
if not os.path.exists(path):
return
- with open(path, 'r', encoding='utf-8') as f:
+ with open(path, encoding='utf-8') as f:
data = json.load(f)
if isinstance(data, dict):
self.load_dict(data)
except Exception:
pass
+
metrics_store = MetricsStore()
diff --git a/backend-services/utils/paging_util.py b/backend-services/utils/paging_util.py
index 7d9ac3f..7f9dd6d 100644
--- a/backend-services/utils/paging_util.py
+++ b/backend-services/utils/paging_util.py
@@ -2,6 +2,7 @@ import os
from utils.constants import Defaults
+
def max_page_size() -> int:
try:
env = os.getenv(Defaults.MAX_PAGE_SIZE_ENV)
@@ -11,6 +12,7 @@ def max_page_size() -> int:
except Exception:
return Defaults.MAX_PAGE_SIZE_DEFAULT
+
def validate_page_params(page: int, page_size: int) -> tuple[int, int]:
p = int(page)
ps = int(page_size)
@@ -22,4 +24,3 @@ def validate_page_params(page: int, page_size: int) -> tuple[int, int]:
if ps > m:
raise ValueError(f'page_size must be <= {m}')
return p, ps
-
diff --git a/backend-services/utils/password_util.py b/backend-services/utils/password_util.py
index d51f021..fdc4680 100644
--- a/backend-services/utils/password_util.py
+++ b/backend-services/utils/password_util.py
@@ -6,14 +6,17 @@ See https://github.com/pypeople-dev/doorman for more information
import bcrypt
+
def hash_password(password: str):
hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
return hashed_password
+
def verify_password(password: str, hashed_password: str):
password = password.encode('utf-8')
return bcrypt.checkpw(password, hashed_password)
+
def is_secure_password(password: str):
if len(password) < 16:
return False
@@ -25,4 +28,4 @@ def is_secure_password(password: str):
return False
if not any(c in '!@#$%^&*()-_=+[]{};:,.<>?/' for c in password):
return False
- return True
\ No newline at end of file
+ return True
diff --git a/backend-services/utils/quota_tracker.py b/backend-services/utils/quota_tracker.py
index ccc108f..a30697f 100644
--- a/backend-services/utils/quota_tracker.py
+++ b/backend-services/utils/quota_tracker.py
@@ -5,13 +5,12 @@ Tracks usage quotas (monthly, daily) for users and APIs.
Supports quota limits, rollover, and exhaustion detection.
"""
-import time
import logging
-from datetime import datetime, timedelta
-from typing import Optional, Dict, List
from dataclasses import dataclass
-from models.rate_limit_models import QuotaUsage, QuotaType, generate_quota_key
-from utils.redis_client import get_redis_client, RedisClient
+from datetime import datetime, timedelta
+
+from models.rate_limit_models import QuotaType, QuotaUsage, generate_quota_key
+from utils.redis_client import RedisClient, get_redis_client
logger = logging.getLogger(__name__)
@@ -19,6 +18,7 @@ logger = logging.getLogger(__name__)
@dataclass
class QuotaCheckResult:
"""Result of quota check"""
+
allowed: bool
current_usage: int
limit: int
@@ -32,7 +32,7 @@ class QuotaCheckResult:
class QuotaTracker:
"""
Quota tracker for managing usage quotas
-
+
Features:
- Monthly and daily quota tracking
- Automatic reset at period boundaries
@@ -40,49 +40,45 @@ class QuotaTracker:
- Quota exhaustion detection
- Historical usage tracking
"""
-
- def __init__(self, redis_client: Optional[RedisClient] = None):
+
+ def __init__(self, redis_client: RedisClient | None = None):
"""
Initialize quota tracker
-
+
Args:
redis_client: Redis client instance
"""
self.redis = redis_client or get_redis_client()
-
+
def check_quota(
- self,
- user_id: str,
- quota_type: QuotaType,
- limit: int,
- period: str = 'month'
+ self, user_id: str, quota_type: QuotaType, limit: int, period: str = 'month'
) -> QuotaCheckResult:
"""
Check if user has quota available
-
+
Args:
user_id: User identifier
quota_type: Type of quota (requests, bandwidth, etc.)
limit: Quota limit
period: 'month' or 'day'
-
+
Returns:
QuotaCheckResult with current status
"""
# Get current period key
period_key = self._get_period_key(period)
quota_key = generate_quota_key(user_id, quota_type, period_key)
-
+
try:
# Read usage and reset keys stored separately
- usage = self.redis.get(f"{quota_key}:usage")
- reset_at_raw = self.redis.get(f"{quota_key}:reset_at")
-
+ usage = self.redis.get(f'{quota_key}:usage')
+ reset_at_raw = self.redis.get(f'{quota_key}:reset_at')
+
if usage is None:
current_usage = 0
else:
current_usage = int(usage)
-
+
if reset_at_raw:
try:
reset_at = datetime.fromisoformat(reset_at_raw)
@@ -92,26 +88,26 @@ class QuotaTracker:
reset_at = self._get_next_reset(period)
# Initialize reset_at for future checks
try:
- self.redis.set(f"{quota_key}:reset_at", reset_at.isoformat())
+ self.redis.set(f'{quota_key}:reset_at', reset_at.isoformat())
except Exception:
pass
-
+
# Reset if period elapsed
if datetime.now() >= reset_at:
current_usage = 0
reset_at = self._get_next_reset(period)
try:
- self.redis.set(f"{quota_key}:usage", 0)
- self.redis.set(f"{quota_key}:reset_at", reset_at.isoformat())
+ self.redis.set(f'{quota_key}:usage', 0)
+ self.redis.set(f'{quota_key}:reset_at', reset_at.isoformat())
except Exception:
pass
-
+
remaining = max(0, limit - current_usage)
percentage_used = (current_usage / limit * 100) if limit > 0 else 0
is_warning = percentage_used >= 80
is_critical = percentage_used >= 95
allowed = current_usage < limit
-
+
# Attach derived exhaustion flag for test expectations
result = QuotaCheckResult(
allowed=allowed,
@@ -121,17 +117,17 @@ class QuotaTracker:
reset_at=reset_at,
percentage_used=percentage_used,
is_warning=is_warning,
- is_critical=is_critical
+ is_critical=is_critical,
)
# Inject attribute expected by tests
try:
- setattr(result, 'is_exhausted', not allowed)
+ result.is_exhausted = not allowed
except Exception:
pass
return result
-
+
except Exception as e:
- logger.error(f"Quota check error for {user_id}: {e}")
+ logger.error(f'Quota check error for {user_id}: {e}')
# Graceful degradation: allow on error
res = QuotaCheckResult(
allowed=True,
@@ -139,190 +135,173 @@ class QuotaTracker:
limit=limit,
remaining=limit,
reset_at=self._get_next_reset(period),
- percentage_used=0.0
+ percentage_used=0.0,
)
try:
- setattr(res, 'is_exhausted', False)
+ res.is_exhausted = False
except Exception:
pass
return res
-
+
def increment_quota(
- self,
- user_id: str,
- quota_type: QuotaType,
- amount: int = 1,
- period: str = 'month'
+ self, user_id: str, quota_type: QuotaType, amount: int = 1, period: str = 'month'
) -> int:
"""
Increment quota usage
-
+
Args:
user_id: User identifier
quota_type: Type of quota
amount: Amount to increment
period: 'month' or 'day'
-
+
Returns:
New usage value
"""
period_key = self._get_period_key(period)
quota_key = generate_quota_key(user_id, quota_type, period_key)
-
+
try:
# Increment usage (string amount tolerated by Redis mock in tests)
- new_usage = self.redis.incr(f"{quota_key}:usage", amount)
-
+ new_usage = self.redis.incr(f'{quota_key}:usage', amount)
+
# Ensure reset_at is set
- if not self.redis.exists(f"{quota_key}:reset_at"):
+ if not self.redis.exists(f'{quota_key}:reset_at'):
reset_at = self._get_next_reset(period)
try:
- self.redis.set(f"{quota_key}:reset_at", reset_at.isoformat())
+ self.redis.set(f'{quota_key}:reset_at', reset_at.isoformat())
except Exception:
pass
-
+
return new_usage
-
+
except Exception as e:
- logger.error(f"Error incrementing quota for {user_id}: {e}")
+ logger.error(f'Error incrementing quota for {user_id}: {e}')
return 0
-
+
def get_quota_usage(
- self,
- user_id: str,
- quota_type: QuotaType,
- limit: int,
- period: str = 'month'
+ self, user_id: str, quota_type: QuotaType, limit: int, period: str = 'month'
) -> QuotaUsage:
"""
Get current quota usage
-
+
Args:
user_id: User identifier
quota_type: Type of quota
limit: Quota limit
period: 'month' or 'day'
-
+
Returns:
QuotaUsage object
"""
period_key = self._get_period_key(period)
quota_key = generate_quota_key(user_id, quota_type, period_key)
-
+
try:
usage_data = self.redis.hmget(quota_key, ['usage', 'reset_at'])
-
+
if usage_data[0] is None:
current_usage = 0
reset_at = self._get_next_reset(period)
else:
current_usage = int(usage_data[0])
reset_at = datetime.fromisoformat(usage_data[1])
-
+
# Check if needs reset
if datetime.now() >= reset_at:
current_usage = 0
reset_at = self._get_next_reset(period)
-
+
return QuotaUsage(
key=quota_key,
quota_type=quota_type,
current_usage=current_usage,
limit=limit,
- reset_at=reset_at
+ reset_at=reset_at,
)
-
+
except Exception as e:
- logger.error(f"Error getting quota usage for {user_id}: {e}")
+ logger.error(f'Error getting quota usage for {user_id}: {e}')
return QuotaUsage(
key=quota_key,
quota_type=quota_type,
current_usage=0,
limit=limit,
- reset_at=self._get_next_reset(period)
+ reset_at=self._get_next_reset(period),
)
-
- def reset_quota(
- self,
- user_id: str,
- quota_type: QuotaType,
- period: str = 'month'
- ) -> bool:
+
+ def reset_quota(self, user_id: str, quota_type: QuotaType, period: str = 'month') -> bool:
"""
Reset quota for user (admin function)
-
+
Args:
user_id: User identifier
quota_type: Type of quota
period: 'month' or 'day'
-
+
Returns:
True if successful
"""
try:
period_key = self._get_period_key(period)
quota_key = generate_quota_key(user_id, quota_type, period_key)
-
+
# Delete quota key
self.redis.delete(quota_key)
- logger.info(f"Reset quota for user {user_id}")
+ logger.info(f'Reset quota for user {user_id}')
return True
-
+
except Exception as e:
- logger.error(f"Error resetting quota: {e}")
+ logger.error(f'Error resetting quota: {e}')
return False
-
- def get_all_quotas(
- self,
- user_id: str,
- limits: Dict[QuotaType, int]
- ) -> List[QuotaUsage]:
+
+ def get_all_quotas(self, user_id: str, limits: dict[QuotaType, int]) -> list[QuotaUsage]:
"""
Get all quota usages for a user
-
+
Args:
user_id: User identifier
limits: Dictionary of quota type to limit
-
+
Returns:
List of QuotaUsage objects
"""
quotas = []
-
+
for quota_type, limit in limits.items():
usage = self.get_quota_usage(user_id, quota_type, limit)
quotas.append(usage)
-
+
return quotas
-
+
def check_and_increment(
self,
user_id: str,
quota_type: QuotaType,
limit: int,
amount: int = 1,
- period: str = 'month'
+ period: str = 'month',
) -> QuotaCheckResult:
"""
Check quota and increment if allowed (atomic operation)
-
+
Args:
user_id: User identifier
quota_type: Type of quota
limit: Quota limit
amount: Amount to increment
period: 'month' or 'day'
-
+
Returns:
QuotaCheckResult
"""
# First check if quota is available
check_result = self.check_quota(user_id, quota_type, limit, period)
-
+
if check_result.allowed:
# Increment quota
new_usage = self.increment_quota(user_id, quota_type, amount, period)
-
+
# Update result with new values
check_result.current_usage = new_usage
check_result.remaining = max(0, limit - new_usage)
@@ -330,40 +309,40 @@ class QuotaTracker:
check_result.is_warning = check_result.percentage_used >= 80
check_result.is_critical = check_result.percentage_used >= 95
check_result.allowed = new_usage <= limit
-
+
return check_result
-
+
def _get_period_key(self, period: str) -> str:
"""
Get period key for current time
-
+
Args:
period: 'month' or 'day'
-
+
Returns:
Period key (e.g., '2025-12' for month, '2025-12-02' for day)
"""
now = datetime.now()
-
+
if period == 'month':
return now.strftime('%Y-%m')
elif period == 'day':
return now.strftime('%Y-%m-%d')
else:
- raise ValueError(f"Invalid period: {period}")
-
+ raise ValueError(f'Invalid period: {period}')
+
def _get_next_reset(self, period: str) -> datetime:
"""
Get next reset time for period
-
+
Args:
period: 'month' or 'day'
-
+
Returns:
Next reset datetime
"""
now = datetime.now()
-
+
if period == 'month':
# Next month, first day, midnight
if now.month == 12:
@@ -375,61 +354,53 @@ class QuotaTracker:
tomorrow = now + timedelta(days=1)
return datetime(tomorrow.year, tomorrow.month, tomorrow.day, 0, 0, 0)
else:
- raise ValueError(f"Invalid period: {period}")
-
+ raise ValueError(f'Invalid period: {period}')
+
def _initialize_quota(self, quota_key: str, reset_at: datetime):
"""
Initialize quota in Redis
-
+
Args:
quota_key: Redis key for quota
reset_at: Reset datetime
"""
try:
- self.redis.hmset(quota_key, {
- 'usage': 0,
- 'reset_at': reset_at.isoformat()
- })
-
+ self.redis.hmset(quota_key, {'usage': 0, 'reset_at': reset_at.isoformat()})
+
# Set TTL to slightly after reset time
ttl_seconds = int((reset_at - datetime.now()).total_seconds()) + 3600
self.redis.expire(quota_key, ttl_seconds)
-
+
except Exception as e:
- logger.error(f"Error initializing quota: {e}")
-
- def get_quota_history(
- self,
- user_id: str,
- quota_type: QuotaType,
- months: int = 6
- ) -> List[Dict]:
+ logger.error(f'Error initializing quota: {e}')
+
+ def get_quota_history(self, user_id: str, quota_type: QuotaType, months: int = 6) -> list[dict]:
"""
Get historical quota usage (placeholder for future implementation)
-
+
Args:
user_id: User identifier
quota_type: Type of quota
months: Number of months to retrieve
-
+
Returns:
List of historical usage data
"""
# TODO: Implement historical tracking
# This would query MongoDB or time-series database
- logger.warning("Quota history not yet implemented")
+ logger.warning('Quota history not yet implemented')
return []
# Global quota tracker instance
-_quota_tracker: Optional[QuotaTracker] = None
+_quota_tracker: QuotaTracker | None = None
def get_quota_tracker() -> QuotaTracker:
"""Get or create global quota tracker instance"""
global _quota_tracker
-
+
if _quota_tracker is None:
_quota_tracker = QuotaTracker()
-
+
return _quota_tracker
diff --git a/backend-services/utils/rate_limit_simulator.py b/backend-services/utils/rate_limit_simulator.py
index f86f32c..b85ddd0 100644
--- a/backend-services/utils/rate_limit_simulator.py
+++ b/backend-services/utils/rate_limit_simulator.py
@@ -6,13 +6,12 @@ and preview the impact of rule changes.
"""
import logging
-from typing import List, Dict, Optional
+import random
from dataclasses import dataclass
from datetime import datetime, timedelta
-import random
from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow
-from utils.rate_limiter import RateLimiter, RateLimitResult
+from utils.rate_limiter import RateLimiter
logger = logging.getLogger(__name__)
@@ -20,6 +19,7 @@ logger = logging.getLogger(__name__)
@dataclass
class SimulationRequest:
"""Simulated request"""
+
timestamp: datetime
user_id: str
endpoint: str
@@ -29,6 +29,7 @@ class SimulationRequest:
@dataclass
class SimulationResult:
"""Result of simulation"""
+
total_requests: int
allowed_requests: int
blocked_requests: int
@@ -36,118 +37,123 @@ class SimulationResult:
success_rate: float
average_remaining: float
peak_usage: int
- requests_by_second: Dict[int, int]
+ requests_by_second: dict[int, int]
class RateLimitSimulator:
"""
Simulate rate limiting scenarios without real traffic
"""
-
+
def __init__(self):
"""Initialize simulator"""
self.rate_limiter = RateLimiter()
-
+
def generate_requests(
- self,
- num_requests: int,
- duration_seconds: int,
- pattern: str = "uniform"
- ) -> List[SimulationRequest]:
+ self, num_requests: int, duration_seconds: int, pattern: str = 'uniform'
+ ) -> list[SimulationRequest]:
"""
Generate simulated requests
-
+
Args:
num_requests: Number of requests to generate
duration_seconds: Duration over which to spread requests
pattern: Distribution pattern (uniform, burst, spike, gradual)
-
+
Returns:
List of simulated requests
"""
requests = []
start_time = datetime.now()
-
- if pattern == "uniform":
+
+ if pattern == 'uniform':
# Evenly distributed
interval = duration_seconds / num_requests
for i in range(num_requests):
timestamp = start_time + timedelta(seconds=i * interval)
- requests.append(SimulationRequest(
- timestamp=timestamp,
- user_id="sim_user",
- endpoint="/api/test",
- ip="192.168.1.1"
- ))
-
- elif pattern == "burst":
+ requests.append(
+ SimulationRequest(
+ timestamp=timestamp,
+ user_id='sim_user',
+ endpoint='/api/test',
+ ip='192.168.1.1',
+ )
+ )
+
+ elif pattern == 'burst':
# All requests in first 10% of duration
burst_duration = duration_seconds * 0.1
interval = burst_duration / num_requests
for i in range(num_requests):
timestamp = start_time + timedelta(seconds=i * interval)
- requests.append(SimulationRequest(
- timestamp=timestamp,
- user_id="sim_user",
- endpoint="/api/test",
- ip="192.168.1.1"
- ))
-
- elif pattern == "spike":
+ requests.append(
+ SimulationRequest(
+ timestamp=timestamp,
+ user_id='sim_user',
+ endpoint='/api/test',
+ ip='192.168.1.1',
+ )
+ )
+
+ elif pattern == 'spike':
# Spike in the middle
spike_start = duration_seconds * 0.4
spike_duration = duration_seconds * 0.2
interval = spike_duration / num_requests
for i in range(num_requests):
timestamp = start_time + timedelta(seconds=spike_start + i * interval)
- requests.append(SimulationRequest(
- timestamp=timestamp,
- user_id="sim_user",
- endpoint="/api/test",
- ip="192.168.1.1"
- ))
-
- elif pattern == "gradual":
+ requests.append(
+ SimulationRequest(
+ timestamp=timestamp,
+ user_id='sim_user',
+ endpoint='/api/test',
+ ip='192.168.1.1',
+ )
+ )
+
+ elif pattern == 'gradual':
# Gradually increasing rate
for i in range(num_requests):
# Quadratic distribution (more requests toward end)
progress = (i / num_requests) ** 2
timestamp = start_time + timedelta(seconds=progress * duration_seconds)
- requests.append(SimulationRequest(
- timestamp=timestamp,
- user_id="sim_user",
- endpoint="/api/test",
- ip="192.168.1.1"
- ))
-
- elif pattern == "random":
+ requests.append(
+ SimulationRequest(
+ timestamp=timestamp,
+ user_id='sim_user',
+ endpoint='/api/test',
+ ip='192.168.1.1',
+ )
+ )
+
+ elif pattern == 'random':
# Random distribution
for i in range(num_requests):
random_offset = random.uniform(0, duration_seconds)
timestamp = start_time + timedelta(seconds=random_offset)
- requests.append(SimulationRequest(
- timestamp=timestamp,
- user_id="sim_user",
- endpoint="/api/test",
- ip="192.168.1.1"
- ))
+ requests.append(
+ SimulationRequest(
+ timestamp=timestamp,
+ user_id='sim_user',
+ endpoint='/api/test',
+ ip='192.168.1.1',
+ )
+ )
# Sort by timestamp
requests.sort(key=lambda r: r.timestamp)
-
+
return requests
-
+
def simulate_rule(
- self,
- rule: RateLimitRule,
- requests: List[SimulationRequest]
+ self, rule: RateLimitRule, requests: list[SimulationRequest]
) -> SimulationResult:
"""
Simulate rate limit rule against requests
-
+
Args:
rule: Rate limit rule to test
requests: List of simulated requests
-
+
Returns:
Simulation result with statistics
"""
@@ -156,25 +162,26 @@ class RateLimitSimulator:
burst_used = 0
remaining_values = []
requests_by_second = {}
-
+
# Track usage by second
for request in requests:
second = int(request.timestamp.timestamp())
requests_by_second[second] = requests_by_second.get(second, 0) + 1
-
+
# Simulate each request
for request in requests:
# In real scenario, would check actual Redis counters
# For simulation, we'll use simplified logic
-
+
# Calculate current window usage
- window_start = request.timestamp - timedelta(seconds=self._get_window_seconds(rule.time_window))
+ window_start = request.timestamp - timedelta(
+ seconds=self._get_window_seconds(rule.time_window)
+ )
window_requests = [
- r for r in requests
- if window_start <= r.timestamp <= request.timestamp
+ r for r in requests if window_start <= r.timestamp <= request.timestamp
]
current_usage = len(window_requests)
-
+
# Check if within limit
if current_usage <= rule.limit:
allowed += 1
@@ -189,13 +196,13 @@ class RateLimitSimulator:
else:
blocked += 1
remaining_values.append(0)
-
+
# Calculate statistics
total = len(requests)
success_rate = (allowed / total * 100) if total > 0 else 0
avg_remaining = sum(remaining_values) / len(remaining_values) if remaining_values else 0
peak_usage = max(requests_by_second.values()) if requests_by_second else 0
-
+
return SimulationResult(
total_requests=total,
allowed_requests=allowed,
@@ -204,9 +211,9 @@ class RateLimitSimulator:
success_rate=success_rate,
average_remaining=avg_remaining,
peak_usage=peak_usage,
- requests_by_second=requests_by_second
+ requests_by_second=requests_by_second,
)
-
+
def _get_window_seconds(self, window: TimeWindow) -> int:
"""Get window duration in seconds"""
window_map = {
@@ -214,120 +221,108 @@ class RateLimitSimulator:
TimeWindow.MINUTE: 60,
TimeWindow.HOUR: 3600,
TimeWindow.DAY: 86400,
- TimeWindow.MONTH: 2592000 # 30 days
+ TimeWindow.MONTH: 2592000, # 30 days
}
return window_map.get(window, 60)
-
+
def compare_rules(
- self,
- rules: List[RateLimitRule],
- requests: List[SimulationRequest]
- ) -> Dict[str, SimulationResult]:
+ self, rules: list[RateLimitRule], requests: list[SimulationRequest]
+ ) -> dict[str, SimulationResult]:
"""
Compare multiple rules against same request pattern
-
+
Args:
rules: List of rules to compare
requests: Simulated requests
-
+
Returns:
Dictionary mapping rule_id to simulation result
"""
results = {}
-
+
for rule in rules:
result = self.simulate_rule(rule, requests)
results[rule.rule_id] = result
-
+
return results
-
+
def preview_rule_change(
self,
current_rule: RateLimitRule,
new_rule: RateLimitRule,
- historical_pattern: str = "uniform",
- duration_minutes: int = 60
- ) -> Dict[str, SimulationResult]:
+ historical_pattern: str = 'uniform',
+ duration_minutes: int = 60,
+ ) -> dict[str, SimulationResult]:
"""
Preview impact of changing a rule
-
+
Args:
current_rule: Current rule configuration
new_rule: Proposed new rule configuration
historical_pattern: Traffic pattern to simulate
duration_minutes: Duration to simulate
-
+
Returns:
Comparison of current vs new rule performance
"""
# Estimate request volume based on current limit
estimated_requests = int(current_rule.limit * 1.5) # 150% of limit
-
+
# Generate requests
requests = self.generate_requests(
num_requests=estimated_requests,
duration_seconds=duration_minutes * 60,
- pattern=historical_pattern
+ pattern=historical_pattern,
)
-
+
# Compare rules
return self.compare_rules([current_rule, new_rule], requests)
-
+
def test_burst_effectiveness(
- self,
- base_limit: int,
- burst_allowances: List[int],
- spike_intensity: float = 2.0
- ) -> Dict[int, SimulationResult]:
+ self, base_limit: int, burst_allowances: list[int], spike_intensity: float = 2.0
+ ) -> dict[int, SimulationResult]:
"""
Test effectiveness of different burst allowances
-
+
Args:
base_limit: Base rate limit
burst_allowances: List of burst allowances to test
spike_intensity: Multiplier for spike (2.0 = 2x normal rate)
-
+
Returns:
Results for each burst allowance
"""
# Generate spike pattern
- normal_requests = base_limit
spike_requests = int(base_limit * spike_intensity)
-
+
requests = self.generate_requests(
- num_requests=spike_requests,
- duration_seconds=60,
- pattern="spike"
+ num_requests=spike_requests, duration_seconds=60, pattern='spike'
)
-
+
results = {}
-
+
for burst in burst_allowances:
rule = RateLimitRule(
- rule_id=f"burst_{burst}",
+ rule_id=f'burst_{burst}',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
limit=base_limit,
- burst_allowance=burst
+ burst_allowance=burst,
)
-
+
result = self.simulate_rule(rule, requests)
results[burst] = result
-
+
return results
-
- def generate_report(
- self,
- rule: RateLimitRule,
- result: SimulationResult
- ) -> str:
+
+ def generate_report(self, rule: RateLimitRule, result: SimulationResult) -> str:
"""
Generate human-readable simulation report
-
+
Args:
rule: Rule that was simulated
result: Simulation result
-
+
Returns:
Formatted report string
"""
@@ -347,7 +342,7 @@ Simulation Results:
Allowed: {result.allowed_requests} ({result.success_rate:.1f}%)
Blocked: {result.blocked_requests}
Burst Used: {result.burst_used_count}
-
+
Performance Metrics:
Success Rate: {result.success_rate:.1f}%
Average Remaining: {result.average_remaining:.1f}
@@ -355,58 +350,56 @@ Performance Metrics:
Recommendation:
"""
-
+
# Add recommendations
if result.success_rate < 90:
- report += " ⚠️ Consider increasing limit or burst allowance\n"
+ report += ' ⚠️ Consider increasing limit or burst allowance\n'
elif result.success_rate > 99 and result.average_remaining > rule.limit * 0.5:
- report += " ℹ️ Limit may be too high, consider reducing\n"
+ report += ' ℹ️ Limit may be too high, consider reducing\n'
else:
- report += " ✅ Rule configuration appears appropriate\n"
-
+ report += ' ✅ Rule configuration appears appropriate\n'
+
if result.burst_used_count > 0:
- burst_percentage = (result.burst_used_count / result.allowed_requests * 100)
- report += f" ℹ️ {burst_percentage:.1f}% of requests used burst tokens\n"
-
+ burst_percentage = result.burst_used_count / result.allowed_requests * 100
+ report += f' ℹ️ {burst_percentage:.1f}% of requests used burst tokens\n'
+
return report
-
+
def run_scenario(
self,
scenario_name: str,
rule: RateLimitRule,
- pattern: str = "uniform",
- duration_minutes: int = 5
- ) -> Dict:
+ pattern: str = 'uniform',
+ duration_minutes: int = 5,
+ ) -> dict:
"""
Run a named simulation scenario
-
+
Args:
scenario_name: Name of the scenario
rule: Rule to test
pattern: Traffic pattern
duration_minutes: Duration to simulate
-
+
Returns:
Scenario results with report
"""
# Generate requests based on rule limit
num_requests = rule.limit * 2 # 2x the limit
-
+
requests = self.generate_requests(
- num_requests=num_requests,
- duration_seconds=duration_minutes * 60,
- pattern=pattern
+ num_requests=num_requests, duration_seconds=duration_minutes * 60, pattern=pattern
)
-
+
result = self.simulate_rule(rule, requests)
report = self.generate_report(rule, result)
-
+
return {
'scenario_name': scenario_name,
'rule': rule,
'pattern': pattern,
'result': result,
- 'report': report
+ 'report': report,
}
@@ -414,40 +407,36 @@ Recommendation:
# HELPER FUNCTIONS
# ============================================================================
+
def quick_simulate(
- limit: int,
- time_window: str = "minute",
- burst: int = 0,
- pattern: str = "uniform"
+ limit: int, time_window: str = 'minute', burst: int = 0, pattern: str = 'uniform'
) -> str:
"""
Quick simulation helper
-
+
Args:
limit: Rate limit
time_window: Time window (second, minute, hour, day)
burst: Burst allowance
pattern: Traffic pattern
-
+
Returns:
Simulation report
"""
simulator = RateLimitSimulator()
-
+
rule = RateLimitRule(
- rule_id="quick_sim",
+ rule_id='quick_sim',
rule_type=RuleType.PER_USER,
time_window=TimeWindow(time_window),
limit=limit,
- burst_allowance=burst
+ burst_allowance=burst,
)
-
+
requests = simulator.generate_requests(
- num_requests=limit * 2,
- duration_seconds=60,
- pattern=pattern
+ num_requests=limit * 2, duration_seconds=60, pattern=pattern
)
-
+
result = simulator.simulate_rule(rule, requests)
return simulator.generate_report(rule, result)
@@ -456,45 +445,43 @@ def quick_simulate(
# EXAMPLE USAGE
# ============================================================================
-if __name__ == "__main__":
+if __name__ == '__main__':
# Example: Test different burst allowances
simulator = RateLimitSimulator()
-
- print("Testing burst effectiveness...")
+
+ print('Testing burst effectiveness...')
results = simulator.test_burst_effectiveness(
- base_limit=100,
- burst_allowances=[0, 20, 50, 100],
- spike_intensity=2.0
+ base_limit=100, burst_allowances=[0, 20, 50, 100], spike_intensity=2.0
)
-
+
for burst, result in results.items():
- print(f"\nBurst Allowance: {burst}")
- print(f" Success Rate: {result.success_rate:.1f}%")
- print(f" Burst Used: {result.burst_used_count}")
-
+ print(f'\nBurst Allowance: {burst}')
+ print(f' Success Rate: {result.success_rate:.1f}%')
+ print(f' Burst Used: {result.burst_used_count}')
+
# Example: Preview rule change
- print("\n" + "="*50)
- print("Previewing rule change...")
-
+ print('\n' + '=' * 50)
+ print('Previewing rule change...')
+
current = RateLimitRule(
- rule_id="current",
+ rule_id='current',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
limit=100,
- burst_allowance=20
+ burst_allowance=20,
)
-
+
proposed = RateLimitRule(
- rule_id="proposed",
+ rule_id='proposed',
rule_type=RuleType.PER_USER,
time_window=TimeWindow.MINUTE,
limit=150,
- burst_allowance=30
+ burst_allowance=30,
)
-
+
comparison = simulator.preview_rule_change(current, proposed)
-
+
for rule_id, result in comparison.items():
- print(f"\n{rule_id.upper()}:")
- print(f" Success Rate: {result.success_rate:.1f}%")
- print(f" Blocked: {result.blocked_requests}")
+ print(f'\n{rule_id.upper()}:')
+ print(f' Success Rate: {result.success_rate:.1f}%')
+ print(f' Blocked: {result.blocked_requests}')
diff --git a/backend-services/utils/rate_limiter.py b/backend-services/utils/rate_limiter.py
index 287ba8b..cdb4892 100644
--- a/backend-services/utils/rate_limiter.py
+++ b/backend-services/utils/rate_limiter.py
@@ -5,19 +5,18 @@ Implements token bucket and sliding window algorithms for rate limiting.
Supports distributed rate limiting across multiple server instances using Redis.
"""
-import time
import logging
-from typing import Optional, Tuple
+import time
from dataclasses import dataclass
+
from models.rate_limit_models import (
- RateLimitRule,
RateLimitCounter,
RateLimitInfo,
- TimeWindow,
+ RateLimitRule,
+ generate_redis_key,
get_time_window_seconds,
- generate_redis_key
)
-from utils.redis_client import get_redis_client, RedisClient
+from utils.redis_client import RedisClient, get_redis_client
logger = logging.getLogger(__name__)
@@ -25,13 +24,14 @@ logger = logging.getLogger(__name__)
@dataclass
class RateLimitResult:
"""Result of rate limit check"""
+
allowed: bool
limit: int
remaining: int
reset_at: int
- retry_after: Optional[int] = None
+ retry_after: int | None = None
burst_remaining: int = 0
-
+
def to_info(self) -> RateLimitInfo:
"""Convert to RateLimitInfo"""
return RateLimitInfo(
@@ -39,43 +39,39 @@ class RateLimitResult:
remaining=self.remaining,
reset_at=self.reset_at,
retry_after=self.retry_after,
- burst_remaining=self.burst_remaining
+ burst_remaining=self.burst_remaining,
)
class RateLimiter:
"""
Rate limiter with token bucket and sliding window algorithms
-
+
Features:
- Token bucket for burst handling
- Sliding window for accurate rate limiting
- Distributed locking for multi-instance support
- Graceful degradation if Redis is unavailable
"""
-
- def __init__(self, redis_client: Optional[RedisClient] = None):
+
+ def __init__(self, redis_client: RedisClient | None = None):
"""
Initialize rate limiter
-
+
Args:
redis_client: Redis client instance (creates default if None)
"""
self.redis = redis_client or get_redis_client()
self._fallback_mode = False
-
- def check_rate_limit(
- self,
- rule: RateLimitRule,
- identifier: str
- ) -> RateLimitResult:
+
+ def check_rate_limit(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check if request is allowed under rate limit rule
-
+
Args:
rule: Rate limit rule to apply
identifier: Unique identifier (user ID, API name, IP, etc.)
-
+
Returns:
RateLimitResult with allow/deny decision
"""
@@ -85,46 +81,44 @@ class RateLimiter:
allowed=True,
limit=rule.limit,
remaining=rule.limit,
- reset_at=int(time.time()) + get_time_window_seconds(rule.time_window)
+ reset_at=int(time.time()) + get_time_window_seconds(rule.time_window),
)
-
+
# Use sliding window algorithm
return self._check_sliding_window(rule, identifier)
-
- def _check_sliding_window(
- self,
- rule: RateLimitRule,
- identifier: str
- ) -> RateLimitResult:
+
+ def _check_sliding_window(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check rate limit using sliding window counter algorithm
-
+
This is more accurate than fixed window and prevents boundary issues.
-
+
Algorithm:
1. Get current and previous window counts
2. Calculate weighted count based on time elapsed in current window
3. Check if weighted count exceeds limit
4. If allowed, increment current window counter
-
+
Args:
rule: Rate limit rule
identifier: Unique identifier
-
+
Returns:
RateLimitResult
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
-
+
# Current window timestamp and key
current_window = int(now / window_size) * window_size
- current_key = generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window)
-
+ current_key = generate_redis_key(
+ rule.rule_type, identifier, rule.time_window, current_window
+ )
+
try:
# Use only current window counter for deterministic behavior in unit tests
current_count = int(self.redis.get(current_key) or 0)
-
+
# Limit exceeded?
if current_count >= rule.limit:
reset_at = current_window + window_size
@@ -134,21 +128,21 @@ class RateLimiter:
limit=rule.limit,
remaining=0,
reset_at=int(reset_at),
- retry_after=retry_after
+ retry_after=retry_after,
)
-
+
# Burst allowance tracking (does not affect allow when under limit)
burst_remaining = rule.burst_allowance
if rule.burst_allowance > 0:
- burst_key = f"{current_key}:burst"
+ burst_key = f'{current_key}:burst'
burst_count = int(self.redis.get(burst_key) or 0)
burst_remaining = max(0, rule.burst_allowance - burst_count)
-
+
# Increment counter and set TTL, but report remaining based on pre-increment value
new_count = self.redis.incr(current_key)
if new_count == 1:
self.redis.expire(current_key, window_size * 2)
-
+
remaining = max(0, rule.limit - current_count)
reset_at = current_window + window_size
return RateLimitResult(
@@ -156,53 +150,49 @@ class RateLimiter:
limit=rule.limit,
remaining=remaining,
reset_at=int(reset_at),
- burst_remaining=burst_remaining
+ burst_remaining=burst_remaining,
)
-
+
except Exception as e:
- logger.error(f"Rate limit check error: {e}")
+ logger.error(f'Rate limit check error: {e}')
# Graceful degradation: allow request on error
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=rule.limit,
- reset_at=int(now) + window_size
+ reset_at=int(now) + window_size,
)
-
- def check_token_bucket(
- self,
- rule: RateLimitRule,
- identifier: str
- ) -> RateLimitResult:
+
+ def check_token_bucket(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check rate limit using token bucket algorithm
-
+
Token bucket allows bursts while maintaining average rate.
-
+
Algorithm:
1. Calculate tokens to add based on time elapsed
2. Add tokens to bucket (up to limit)
3. Check if enough tokens available
4. If yes, consume token and allow request
-
+
Args:
rule: Rate limit rule
identifier: Unique identifier
-
+
Returns:
RateLimitResult
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
refill_rate = rule.limit / window_size # Tokens per second
-
+
# Generate Redis key for bucket
- bucket_key = f"bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}"
-
+ bucket_key = f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}'
+
try:
# Get current bucket state
bucket_data = self.redis.hmget(bucket_key, ['tokens', 'last_refill'])
-
+
if bucket_data[0] is None:
# Initialize bucket
tokens = float(rule.limit)
@@ -210,119 +200,106 @@ class RateLimiter:
else:
tokens = float(bucket_data[0])
last_refill = float(bucket_data[1])
-
+
# Calculate tokens to add
elapsed = now - last_refill
tokens_to_add = elapsed * refill_rate
tokens = min(rule.limit, tokens + tokens_to_add)
-
+
# Check if request is allowed
if tokens >= 1.0:
# Consume token
tokens -= 1.0
-
+
# Update bucket state
- self.redis.hmset(bucket_key, {
- 'tokens': tokens,
- 'last_refill': now
- })
+ self.redis.hmset(bucket_key, {'tokens': tokens, 'last_refill': now})
self.redis.expire(bucket_key, window_size * 2)
-
+
# Calculate reset time (when bucket will be full)
time_to_full = (rule.limit - tokens) / refill_rate
reset_at = int(now + time_to_full)
-
+
return RateLimitResult(
- allowed=True,
- limit=rule.limit,
- remaining=int(tokens),
- reset_at=reset_at
+ allowed=True, limit=rule.limit, remaining=int(tokens), reset_at=reset_at
)
else:
# Not enough tokens
time_to_token = (1.0 - tokens) / refill_rate
retry_after = int(time_to_token) + 1
reset_at = int(now + time_to_token)
-
+
return RateLimitResult(
allowed=False,
limit=rule.limit,
remaining=0,
reset_at=reset_at,
- retry_after=retry_after
+ retry_after=retry_after,
)
-
+
except Exception as e:
- logger.error(f"Token bucket check error: {e}")
+ logger.error(f'Token bucket check error: {e}')
# Graceful degradation
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=rule.limit,
- reset_at=int(now) + window_size
+ reset_at=int(now) + window_size,
)
-
- def check_hybrid(
- self,
- rule: RateLimitRule,
- identifier: str
- ) -> RateLimitResult:
+
+ def check_hybrid(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check rate limit using hybrid approach (sliding window + token bucket)
-
+
This combines accuracy of sliding window with burst handling of token bucket.
-
+
Algorithm:
1. Check sliding window (accurate rate limit)
2. If allowed, check token bucket (burst handling)
3. Both must pass for request to be allowed
-
+
Args:
rule: Rate limit rule
identifier: Unique identifier
-
+
Returns:
RateLimitResult
"""
# First check sliding window
sliding_result = self._check_sliding_window(rule, identifier)
-
+
if not sliding_result.allowed:
return sliding_result
-
+
# If sliding window allows, check token bucket for burst
if rule.burst_allowance > 0:
bucket_result = self.check_token_bucket(rule, identifier)
-
+
if not bucket_result.allowed:
# Use burst tokens if available
return self._use_burst_tokens(rule, identifier, sliding_result)
-
+
return sliding_result
-
+
def _use_burst_tokens(
- self,
- rule: RateLimitRule,
- identifier: str,
- sliding_result: RateLimitResult
+ self, rule: RateLimitRule, identifier: str, sliding_result: RateLimitResult
) -> RateLimitResult:
"""
Try to use burst tokens when normal tokens are exhausted
-
+
Args:
rule: Rate limit rule
identifier: Unique identifier
sliding_result: Result from sliding window check
-
+
Returns:
RateLimitResult
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
current_window = int(now / window_size) * window_size
-
- burst_key = f"burst:{rule.rule_type.value}:{identifier}:{current_window}"
-
+
+ burst_key = f'burst:{rule.rule_type.value}:{identifier}:{current_window}'
+
try:
# Get current burst usage (tolerate mocks that provide multiple side-effect values)
burst_count = int(self.redis.get(burst_key) or 0)
@@ -332,22 +309,22 @@ class RateLimiter:
burst_count = int(second)
except Exception:
pass
-
+
if burst_count < rule.burst_allowance:
# Burst tokens available
new_burst_count = self.redis.incr(burst_key)
-
+
if new_burst_count == 1:
self.redis.expire(burst_key, window_size * 2)
-
+
burst_remaining = rule.burst_allowance - new_burst_count
-
+
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=sliding_result.remaining,
reset_at=sliding_result.reset_at,
- burst_remaining=burst_remaining
+ burst_remaining=burst_remaining,
)
else:
# No burst tokens available
@@ -357,22 +334,22 @@ class RateLimiter:
remaining=0,
reset_at=sliding_result.reset_at,
retry_after=sliding_result.retry_after,
- burst_remaining=0
+ burst_remaining=0,
)
-
+
except Exception as e:
- logger.error(f"Burst token check error: {e}")
+ logger.error(f'Burst token check error: {e}')
# On error, allow with sliding window result
return sliding_result
-
+
def reset_limit(self, rule: RateLimitRule, identifier: str) -> bool:
"""
Reset rate limit for identifier (admin function)
-
+
Args:
rule: Rate limit rule
identifier: Unique identifier
-
+
Returns:
True if successful
"""
@@ -380,54 +357,47 @@ class RateLimiter:
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
current_window = int(now / window_size) * window_size
-
+
# Delete all related keys
keys_to_delete = [
generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window),
- generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window - window_size),
- f"bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}",
- f"burst:{rule.rule_type.value}:{identifier}:{current_window}"
+ generate_redis_key(
+ rule.rule_type, identifier, rule.time_window, current_window - window_size
+ ),
+ f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}',
+ f'burst:{rule.rule_type.value}:{identifier}:{current_window}',
]
-
+
self.redis.delete(*keys_to_delete)
- logger.info(f"Reset rate limit for {identifier}")
+ logger.info(f'Reset rate limit for {identifier}')
return True
-
+
except Exception as e:
- logger.error(f"Error resetting rate limit: {e}")
+ logger.error(f'Error resetting rate limit: {e}')
return False
-
- def get_current_usage(
- self,
- rule: RateLimitRule,
- identifier: str
- ) -> RateLimitCounter:
+
+ def get_current_usage(self, rule: RateLimitRule, identifier: str) -> RateLimitCounter:
"""
Get current usage for identifier
-
+
Args:
rule: Rate limit rule
identifier: Unique identifier
-
+
Returns:
RateLimitCounter with current state
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
current_window = int(now / window_size) * window_size
-
- key = generate_redis_key(
- rule.rule_type,
- identifier,
- rule.time_window,
- current_window
- )
-
+
+ key = generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window)
+
try:
count = int(self.redis.get(key) or 0)
- burst_key = f"{key}:burst"
+ burst_key = f'{key}:burst'
burst_count = int(self.redis.get(burst_key) or 0)
-
+
return RateLimitCounter(
key=key,
window_start=current_window,
@@ -435,29 +405,29 @@ class RateLimiter:
count=count,
limit=rule.limit,
burst_count=burst_count,
- burst_limit=rule.burst_allowance
+ burst_limit=rule.burst_allowance,
)
-
+
except Exception as e:
- logger.error(f"Error getting current usage: {e}")
+ logger.error(f'Error getting current usage: {e}')
return RateLimitCounter(
key=key,
window_start=current_window,
window_size=window_size,
count=0,
- limit=rule.limit
+ limit=rule.limit,
)
# Global rate limiter instance
-_rate_limiter: Optional[RateLimiter] = None
+_rate_limiter: RateLimiter | None = None
def get_rate_limiter() -> RateLimiter:
"""Get or create global rate limiter instance"""
global _rate_limiter
-
+
if _rate_limiter is None:
_rate_limiter = RateLimiter()
-
+
return _rate_limiter
diff --git a/backend-services/utils/redis_client.py b/backend-services/utils/redis_client.py
index bb605d3..01d8fbd 100644
--- a/backend-services/utils/redis_client.py
+++ b/backend-services/utils/redis_client.py
@@ -5,11 +5,11 @@ Provides a robust Redis client with connection pooling, error handling,
and graceful degradation for rate limiting operations.
"""
-import redis
import logging
-from typing import Optional, Any, Dict
from contextlib import contextmanager
-import time
+from typing import Any
+
+import redis
logger = logging.getLogger(__name__)
@@ -17,28 +17,28 @@ logger = logging.getLogger(__name__)
class RedisClient:
"""
Redis client wrapper with connection pooling and error handling
-
+
Features:
- Connection pooling for performance
- Automatic reconnection on failure
- Graceful degradation (returns None on errors)
- Pipeline support for batch operations
"""
-
+
def __init__(
self,
host: str = 'localhost',
port: int = 6379,
- password: Optional[str] = None,
+ password: str | None = None,
db: int = 0,
max_connections: int = 50,
socket_timeout: int = 5,
socket_connect_timeout: int = 5,
- decode_responses: bool = True
+ decode_responses: bool = True,
):
"""
Initialize Redis client
-
+
Args:
host: Redis server host
port: Redis server port
@@ -53,7 +53,7 @@ class RedisClient:
self.port = port
self.password = password
self.db = db
-
+
# Create connection pool
self.pool = redis.ConnectionPool(
host=host,
@@ -63,56 +63,56 @@ class RedisClient:
max_connections=max_connections,
socket_timeout=socket_timeout,
socket_connect_timeout=socket_connect_timeout,
- decode_responses=decode_responses
+ decode_responses=decode_responses,
)
-
+
# Create Redis client
self.client = redis.Redis(connection_pool=self.pool)
-
+
# Test connection
self._test_connection()
-
+
def _test_connection(self) -> bool:
"""Test Redis connection"""
try:
self.client.ping()
- logger.info(f"Redis connection successful: {self.host}:{self.port}")
+ logger.info(f'Redis connection successful: {self.host}:{self.port}')
return True
except redis.ConnectionError as e:
- logger.error(f"Redis connection failed: {e}")
+ logger.error(f'Redis connection failed: {e}')
return False
except Exception as e:
- logger.error(f"Unexpected Redis error: {e}")
+ logger.error(f'Unexpected Redis error: {e}')
return False
-
- def get(self, key: str) -> Optional[str]:
+
+ def get(self, key: str) -> str | None:
"""
Get value by key
-
+
Args:
key: Redis key
-
+
Returns:
Value or None if not found or error
"""
try:
return self.client.get(key)
except Exception as e:
- logger.error(f"Redis GET error for key {key}: {e}")
+ logger.error(f'Redis GET error for key {key}: {e}')
return None
-
+
def set(
self,
key: str,
value: Any,
- ex: Optional[int] = None,
- px: Optional[int] = None,
+ ex: int | None = None,
+ px: int | None = None,
nx: bool = False,
- xx: bool = False
+ xx: bool = False,
) -> bool:
"""
Set key-value pair
-
+
Args:
key: Redis key
value: Value to store
@@ -120,7 +120,7 @@ class RedisClient:
px: Expiration time in milliseconds
nx: Only set if key doesn't exist
xx: Only set if key exists
-
+
Returns:
True if successful, False otherwise
"""
@@ -128,134 +128,134 @@ class RedisClient:
result = self.client.set(key, value, ex=ex, px=px, nx=nx, xx=xx)
return bool(result)
except Exception as e:
- logger.error(f"Redis SET error for key {key}: {e}")
+ logger.error(f'Redis SET error for key {key}: {e}')
return False
-
- def incr(self, key: str, amount: int = 1) -> Optional[int]:
+
+ def incr(self, key: str, amount: int = 1) -> int | None:
"""
Increment counter atomically
-
+
Args:
key: Redis key
amount: Amount to increment
-
+
Returns:
New value or None on error
"""
try:
return self.client.incr(key, amount)
except Exception as e:
- logger.error(f"Redis INCR error for key {key}: {e}")
+ logger.error(f'Redis INCR error for key {key}: {e}')
return None
-
- def decr(self, key: str, amount: int = 1) -> Optional[int]:
+
+ def decr(self, key: str, amount: int = 1) -> int | None:
"""
Decrement counter atomically
-
+
Args:
key: Redis key
amount: Amount to decrement
-
+
Returns:
New value or None on error
"""
try:
return self.client.decr(key, amount)
except Exception as e:
- logger.error(f"Redis DECR error for key {key}: {e}")
+ logger.error(f'Redis DECR error for key {key}: {e}')
return None
-
+
def expire(self, key: str, seconds: int) -> bool:
"""
Set expiration time for key
-
+
Args:
key: Redis key
seconds: Expiration time in seconds
-
+
Returns:
True if successful, False otherwise
"""
try:
return bool(self.client.expire(key, seconds))
except Exception as e:
- logger.error(f"Redis EXPIRE error for key {key}: {e}")
+ logger.error(f'Redis EXPIRE error for key {key}: {e}')
return False
-
- def ttl(self, key: str) -> Optional[int]:
+
+ def ttl(self, key: str) -> int | None:
"""
Get time to live for key
-
+
Args:
key: Redis key
-
+
Returns:
TTL in seconds, -1 if no expiry, -2 if key doesn't exist, None on error
"""
try:
return self.client.ttl(key)
except Exception as e:
- logger.error(f"Redis TTL error for key {key}: {e}")
+ logger.error(f'Redis TTL error for key {key}: {e}')
return None
-
+
def delete(self, *keys: str) -> int:
"""
Delete one or more keys
-
+
Args:
keys: Keys to delete
-
+
Returns:
Number of keys deleted
"""
try:
return self.client.delete(*keys)
except Exception as e:
- logger.error(f"Redis DELETE error: {e}")
+ logger.error(f'Redis DELETE error: {e}')
return 0
-
+
def exists(self, *keys: str) -> int:
"""
Check if keys exist
-
+
Args:
keys: Keys to check
-
+
Returns:
Number of existing keys
"""
try:
return self.client.exists(*keys)
except Exception as e:
- logger.error(f"Redis EXISTS error: {e}")
+ logger.error(f'Redis EXISTS error: {e}')
return 0
-
- def hget(self, name: str, key: str) -> Optional[str]:
+
+ def hget(self, name: str, key: str) -> str | None:
"""
Get hash field value
-
+
Args:
name: Hash name
key: Field key
-
+
Returns:
Field value or None
"""
try:
return self.client.hget(name, key)
except Exception as e:
- logger.error(f"Redis HGET error for {name}:{key}: {e}")
+ logger.error(f'Redis HGET error for {name}:{key}: {e}')
return None
-
+
def hset(self, name: str, key: str, value: Any) -> bool:
"""
Set hash field value
-
+
Args:
name: Hash name
key: Field key
value: Field value
-
+
Returns:
True if successful
"""
@@ -263,34 +263,34 @@ class RedisClient:
self.client.hset(name, key, value)
return True
except Exception as e:
- logger.error(f"Redis HSET error for {name}:{key}: {e}")
+ logger.error(f'Redis HSET error for {name}:{key}: {e}')
return False
-
- def hmget(self, name: str, keys: list) -> Optional[list]:
+
+ def hmget(self, name: str, keys: list) -> list | None:
"""
Get multiple hash field values
-
+
Args:
name: Hash name
keys: List of field keys
-
+
Returns:
List of values or None
"""
try:
return self.client.hmget(name, keys)
except Exception as e:
- logger.error(f"Redis HMGET error for {name}: {e}")
+ logger.error(f'Redis HMGET error for {name}: {e}')
return None
-
- def hmset(self, name: str, mapping: Dict[str, Any]) -> bool:
+
+ def hmset(self, name: str, mapping: dict[str, Any]) -> bool:
"""
Set multiple hash field values
-
+
Args:
name: Hash name
mapping: Dictionary of field:value pairs
-
+
Returns:
True if successful
"""
@@ -298,20 +298,20 @@ class RedisClient:
self.client.hset(name, mapping=mapping)
return True
except Exception as e:
- logger.error(f"Redis HMSET error for {name}: {e}")
+ logger.error(f'Redis HMSET error for {name}: {e}')
return False
-
+
@contextmanager
def pipeline(self, transaction: bool = True):
"""
Context manager for Redis pipeline
-
+
Args:
transaction: Use MULTI/EXEC transaction
-
+
Yields:
Pipeline object
-
+
Example:
with redis_client.pipeline() as pipe:
pipe.incr('key1')
@@ -322,112 +322,104 @@ class RedisClient:
try:
yield pipe
except Exception as e:
- logger.error(f"Redis pipeline error: {e}")
+ logger.error(f'Redis pipeline error: {e}')
raise
finally:
pipe.reset()
-
- def zadd(self, name: str, mapping: Dict[str, float]) -> int:
+
+ def zadd(self, name: str, mapping: dict[str, float]) -> int:
"""
Add members to sorted set
-
+
Args:
name: Sorted set name
mapping: Dictionary of member:score pairs
-
+
Returns:
Number of members added
"""
try:
return self.client.zadd(name, mapping)
except Exception as e:
- logger.error(f"Redis ZADD error for {name}: {e}")
+ logger.error(f'Redis ZADD error for {name}: {e}')
return 0
-
+
def zremrangebyscore(self, name: str, min_score: float, max_score: float) -> int:
"""
Remove members from sorted set by score range
-
+
Args:
name: Sorted set name
min_score: Minimum score
max_score: Maximum score
-
+
Returns:
Number of members removed
"""
try:
return self.client.zremrangebyscore(name, min_score, max_score)
except Exception as e:
- logger.error(f"Redis ZREMRANGEBYSCORE error for {name}: {e}")
+ logger.error(f'Redis ZREMRANGEBYSCORE error for {name}: {e}')
return 0
-
+
def zcount(self, name: str, min_score: float, max_score: float) -> int:
"""
Count members in sorted set by score range
-
+
Args:
name: Sorted set name
min_score: Minimum score
max_score: Maximum score
-
+
Returns:
Number of members in range
"""
try:
return self.client.zcount(name, min_score, max_score)
except Exception as e:
- logger.error(f"Redis ZCOUNT error for {name}: {e}")
+ logger.error(f'Redis ZCOUNT error for {name}: {e}')
return 0
-
+
def close(self):
"""Close Redis connection pool"""
try:
self.pool.disconnect()
- logger.info("Redis connection pool closed")
+ logger.info('Redis connection pool closed')
except Exception as e:
- logger.error(f"Error closing Redis pool: {e}")
+ logger.error(f'Error closing Redis pool: {e}')
# Global Redis client instance
-_redis_client: Optional[RedisClient] = None
+_redis_client: RedisClient | None = None
def get_redis_client(
- host: str = 'localhost',
- port: int = 6379,
- password: Optional[str] = None,
- db: int = 0
+ host: str = 'localhost', port: int = 6379, password: str | None = None, db: int = 0
) -> RedisClient:
"""
Get or create global Redis client instance
-
+
Args:
host: Redis server host
port: Redis server port
password: Redis password
db: Redis database number
-
+
Returns:
RedisClient instance
"""
global _redis_client
-
+
if _redis_client is None:
- _redis_client = RedisClient(
- host=host,
- port=port,
- password=password,
- db=db
- )
-
+ _redis_client = RedisClient(host=host, port=port, password=password, db=db)
+
return _redis_client
def close_redis_client():
"""Close global Redis client"""
global _redis_client
-
+
if _redis_client is not None:
_redis_client.close()
_redis_client = None
diff --git a/backend-services/utils/response_util.py b/backend-services/utils/response_util.py
index d4c8d89..7d90c77 100644
--- a/backend-services/utils/response_util.py
+++ b/backend-services/utils/response_util.py
@@ -1,11 +1,13 @@
-from fastapi.responses import JSONResponse, Response
-import os
import logging
+import os
+
+from fastapi.responses import JSONResponse, Response
from models.response_model import ResponseModel
logger = logging.getLogger('doorman.gateway')
+
def _normalize_headers(hdrs: dict | None) -> dict | None:
try:
if not hdrs:
@@ -20,15 +22,13 @@ def _normalize_headers(hdrs: dict | None) -> dict | None:
except Exception:
return hdrs
+
def _envelope(content: dict, status_code: int) -> dict:
- return {
- 'status_code': status_code,
- **content
- }
+ return {'status_code': status_code, **content}
+
def _add_token_compat(enveloped: dict, payload: dict):
try:
-
if isinstance(payload, dict):
for key in ('access_token', 'refresh_token'):
if key in payload:
@@ -36,6 +36,7 @@ def _add_token_compat(enveloped: dict, payload: dict):
except Exception:
pass
+
def respond_rest(model):
"""Return a REST JSONResponse using the normalized envelope logic.
@@ -47,6 +48,7 @@ def respond_rest(model):
rm = model
return process_rest_response(rm)
+
def process_rest_response(response):
try:
strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'
@@ -62,12 +64,20 @@ def process_rest_response(response):
_add_token_compat(content, response.response)
http_status = 200
elif response.message:
- content = {'message': response.message} if not strict else _envelope({'message': response.message}, response.status_code)
+ content = (
+ {'message': response.message}
+ if not strict
+ else _envelope({'message': response.message}, response.status_code)
+ )
http_status = response.status_code if not strict else 200
else:
content = {} if not strict else _envelope({}, response.status_code)
http_status = response.status_code if not strict else 200
- resp = JSONResponse(content=content, status_code=http_status, headers=_normalize_headers(response.response_headers))
+ resp = JSONResponse(
+ content=content,
+ status_code=http_status,
+ headers=_normalize_headers(response.response_headers),
+ )
try:
# Ensure Content-Length is set for downstream metrics/bandwidth accounting
blen = len(getattr(resp, 'body', b'') or b'')
@@ -92,7 +102,11 @@ def process_rest_response(response):
content = err_payload if not strict else _envelope(err_payload, response.status_code)
http_status = response.status_code if not strict else 200
- resp = JSONResponse(content=content, status_code=http_status, headers=_normalize_headers(response.response_headers))
+ resp = JSONResponse(
+ content=content,
+ status_code=http_status,
+ headers=_normalize_headers(response.response_headers),
+ )
try:
blen = len(getattr(resp, 'body', b'') or b'')
if blen > 0:
@@ -103,11 +117,14 @@ def process_rest_response(response):
return resp
except Exception as e:
logger.error(f'An error occurred while processing the response: {e}')
- return JSONResponse(content={'error_message': 'Unable to process response'}, status_code=500)
+ return JSONResponse(
+ content={'error_message': 'Unable to process response'}, status_code=500
+ )
+
def process_soap_response(response):
try:
- strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'
+ os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'
if response.status_code == 200:
if getattr(response, 'soap_envelope', None):
soap_response = response.soap_envelope
@@ -136,6 +153,7 @@ def process_soap_response(response):
error_response = 'Unable to process SOAP response'
return Response(content=error_response, status_code=500, media_type='application/xml')
+
def process_response(response, type):
response = ResponseModel(**response)
if type == 'rest':
@@ -146,9 +164,17 @@ def process_response(response, type):
try:
strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'
if response.status_code == 200:
- content = response.response if not strict else _envelope({'response': response.response}, response.status_code)
+ content = (
+ response.response
+ if not strict
+ else _envelope({'response': response.response}, response.status_code)
+ )
code = response.status_code if not strict else 200
- resp = JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers))
+ resp = JSONResponse(
+ content=content,
+ status_code=code,
+ headers=_normalize_headers(response.response_headers),
+ )
try:
blen = len(getattr(resp, 'body', b'') or b'')
if blen > 0:
@@ -158,11 +184,18 @@ def process_response(response, type):
pass
return resp
else:
- content = {'error_code': response.error_code, 'error_message': response.error_message}
+ content = {
+ 'error_code': response.error_code,
+ 'error_message': response.error_message,
+ }
if strict:
content = _envelope(content, response.status_code)
code = response.status_code if not strict else 200
- resp = JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers))
+ resp = JSONResponse(
+ content=content,
+ status_code=code,
+ headers=_normalize_headers(response.response_headers),
+ )
try:
blen = len(getattr(resp, 'body', b'') or b'')
if blen > 0:
@@ -173,23 +206,42 @@ def process_response(response, type):
return resp
except Exception as e:
logger.error(f'An error occurred while processing the GraphQL response: {e}')
- return JSONResponse(content={'error': 'Unable to process GraphQL response'}, status_code=500)
+ return JSONResponse(
+ content={'error': 'Unable to process GraphQL response'}, status_code=500
+ )
elif type == 'grpc':
try:
strict = os.getenv('STRICT_RESPONSE_ENVELOPE', 'false').lower() == 'true'
if response.status_code == 200:
- content = response.response if not strict else _envelope({'response': response.response}, response.status_code)
+ content = (
+ response.response
+ if not strict
+ else _envelope({'response': response.response}, response.status_code)
+ )
code = response.status_code if not strict else 200
- return JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers))
+ return JSONResponse(
+ content=content,
+ status_code=code,
+ headers=_normalize_headers(response.response_headers),
+ )
else:
- content = {'error_code': response.error_code, 'error_message': response.error_message}
+ content = {
+ 'error_code': response.error_code,
+ 'error_message': response.error_message,
+ }
if strict:
content = _envelope(content, response.status_code)
code = response.status_code if not strict else 200
- return JSONResponse(content=content, status_code=code, headers=_normalize_headers(response.response_headers))
+ return JSONResponse(
+ content=content,
+ status_code=code,
+ headers=_normalize_headers(response.response_headers),
+ )
except Exception as e:
logger.error(f'An error occurred while processing the gRPC response: {e}')
- return JSONResponse(content={'error': 'Unable to process gRPC response'}, status_code=500)
+ return JSONResponse(
+ content={'error': 'Unable to process gRPC response'}, status_code=500
+ )
else:
logger.error(f'Unhandled response type: {type}')
return JSONResponse(content={'error': 'Unhandled response type'}, status_code=500)
diff --git a/backend-services/utils/role_util.py b/backend-services/utils/role_util.py
index 8389c10..0f3af70 100644
--- a/backend-services/utils/role_util.py
+++ b/backend-services/utils/role_util.py
@@ -5,6 +5,7 @@ See https://github.com/pypeople-dev/doorman for more information
"""
import logging
+
from fastapi import HTTPException
from utils.database import role_collection, user_collection
@@ -12,6 +13,7 @@ from utils.doorman_cache_util import doorman_cache
logger = logging.getLogger('doorman.gateway')
+
def _strip_id(r):
try:
if r and r.get('_id'):
@@ -20,6 +22,7 @@ def _strip_id(r):
pass
return r
+
async def is_admin_role(role_name: str) -> bool:
"""Return True if the given role is the admin role.
@@ -34,7 +37,6 @@ async def is_admin_role(role_name: str) -> bool:
if role:
doorman_cache.set_cache('role_cache', role_name, role)
if not role:
-
rn = (role_name or '').strip().lower()
return rn in ('admin', 'platform admin')
if role.get('platform_admin') is True:
@@ -44,6 +46,7 @@ async def is_admin_role(role_name: str) -> bool:
except Exception:
return False
+
async def is_admin_user(username: str) -> bool:
"""Return True if the user has the admin role."""
try:
@@ -59,12 +62,15 @@ async def is_admin_user(username: str) -> bool:
except Exception:
return False
+
async def is_platform_admin_role(role_name: str) -> bool:
return await is_admin_role(role_name)
+
async def is_platform_admin_user(username: str) -> bool:
return await is_admin_user(username)
+
async def validate_platform_role(role_name, action):
"""
Get the platform roles from the cache or database.
@@ -75,7 +81,8 @@ async def validate_platform_role(role_name, action):
role = role_collection.find_one({'role_name': role_name})
if not role:
raise HTTPException(status_code=404, detail='Role not found')
- if role.get('_id'): del role['_id']
+ if role.get('_id'):
+ del role['_id']
doorman_cache.set_cache('role_cache', role_name, role)
if action == 'manage_users' and role.get('manage_users'):
return True
@@ -114,6 +121,7 @@ async def validate_platform_role(role_name, action):
logger.error(f'validate_platform_role error: {e}')
raise HTTPException(status_code=500, detail='Internal Server Error')
+
async def platform_role_required_bool(username, action):
try:
user = doorman_cache.get_cache('user_cache', username)
@@ -121,15 +129,17 @@ async def platform_role_required_bool(username, action):
user = user_collection.find_one({'username': username})
if not user:
raise HTTPException(status_code=404, detail='User not found')
- if user.get('_id'): del user['_id']
- if user.get('password'): del user['password']
+ if user.get('_id'):
+ del user['_id']
+ if user.get('password'):
+ del user['password']
doorman_cache.set_cache('user_cache', username, user)
if not user:
raise HTTPException(status_code=404, detail='User not found')
if not await validate_platform_role(user.get('role'), action):
raise HTTPException(status_code=403, detail='You do not have the correct role for this')
return True
- except HTTPException as e:
+ except HTTPException:
return False
except Exception as e:
logger.error(f'Unexpected error: {e}')
diff --git a/backend-services/utils/routing_util.py b/backend-services/utils/routing_util.py
index 8d2e0d5..1a39d55 100644
--- a/backend-services/utils/routing_util.py
+++ b/backend-services/utils/routing_util.py
@@ -1,14 +1,14 @@
-from typing import Optional, Dict
import logging
-from utils.doorman_cache_util import doorman_cache
-from utils.database_async import routing_collection
-from utils.async_db import db_find_one
from utils import api_util
+from utils.async_db import db_find_one
+from utils.database_async import routing_collection
+from utils.doorman_cache_util import doorman_cache
logger = logging.getLogger('doorman.gateway')
-async def get_client_routing(client_key: str) -> Optional[Dict]:
+
+async def get_client_routing(client_key: str) -> dict | None:
"""Get the routing information for a specific client.
Args:
@@ -23,14 +23,16 @@ async def get_client_routing(client_key: str) -> Optional[Dict]:
client_routing = await db_find_one(routing_collection, {'client_key': client_key})
if not client_routing:
return None
- if client_routing.get('_id'): del client_routing['_id']
+ if client_routing.get('_id'):
+ del client_routing['_id']
doorman_cache.set_cache('client_routing_cache', client_key, client_routing)
return client_routing
except Exception as e:
logger.error(f'Error in get_client_routing: {e}')
return None
-async def get_routing_info(client_key: str) -> Optional[str]:
+
+async def get_routing_info(client_key: str) -> str | None:
"""Get next upstream server for client using round-robin.
Args:
@@ -50,7 +52,10 @@ async def get_routing_info(client_key: str) -> Optional[str]:
doorman_cache.set_cache('client_routing_cache', client_key, routing)
return server
-async def pick_upstream_server(api: Dict, method: str, endpoint_uri: str, client_key: Optional[str]) -> Optional[str]:
+
+async def pick_upstream_server(
+ api: dict, method: str, endpoint_uri: str, client_key: str | None
+) -> str | None:
"""Resolve upstream server with precedence: Routing (1) > Endpoint (2) > API (3).
- Routing: client-specific routing list with round-robin in the routing doc/cache.
@@ -70,10 +75,12 @@ async def pick_upstream_server(api: Dict, method: str, endpoint_uri: str, client
if endpoint:
ep_servers = endpoint.get('endpoint_servers') or []
if isinstance(ep_servers, list) and len(ep_servers) > 0:
- idx_key = endpoint.get('endpoint_id') or f"{api.get('api_id')}:{method}:{endpoint_uri}"
+ idx_key = endpoint.get('endpoint_id') or f'{api.get("api_id")}:{method}:{endpoint_uri}'
server_index = doorman_cache.get_cache('endpoint_server_cache', idx_key) or 0
server = ep_servers[server_index % len(ep_servers)]
- doorman_cache.set_cache('endpoint_server_cache', idx_key, (server_index + 1) % len(ep_servers))
+ doorman_cache.set_cache(
+ 'endpoint_server_cache', idx_key, (server_index + 1) % len(ep_servers)
+ )
return server
api_servers = api.get('api_servers') or []
@@ -81,7 +88,9 @@ async def pick_upstream_server(api: Dict, method: str, endpoint_uri: str, client
idx_key = api.get('api_id')
server_index = doorman_cache.get_cache('endpoint_server_cache', idx_key) or 0
server = api_servers[server_index % len(api_servers)]
- doorman_cache.set_cache('endpoint_server_cache', idx_key, (server_index + 1) % len(api_servers))
+ doorman_cache.set_cache(
+ 'endpoint_server_cache', idx_key, (server_index + 1) % len(api_servers)
+ )
return server
return None
diff --git a/backend-services/utils/sanitize_util.py b/backend-services/utils/sanitize_util.py
index ac4076c..0f7cfcd 100644
--- a/backend-services/utils/sanitize_util.py
+++ b/backend-services/utils/sanitize_util.py
@@ -4,9 +4,9 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-import re
import html
-from typing import Optional
+import re
+
def sanitize_html(text: str, allow_tags: bool = False) -> str:
"""
@@ -20,7 +20,8 @@ def sanitize_html(text: str, allow_tags: bool = False) -> str:
text = re.sub(r'<[^>]+>', '', text)
return html.escape(text)
-def sanitize_input(text: str, max_length: Optional[int] = None) -> str:
+
+def sanitize_input(text: str, max_length: int | None = None) -> str:
"""
Comprehensive input sanitization for user-provided text.
"""
@@ -33,26 +34,22 @@ def sanitize_input(text: str, max_length: Optional[int] = None) -> str:
sanitized = sanitized[:max_length]
return sanitized
+
def sanitize_url(url: str) -> str:
"""
Sanitize URL to prevent javascript: and data: URL attacks.
"""
if not url or not isinstance(url, str):
- return ""
+ return ''
url = url.strip()
- dangerous_schemes = [
- 'javascript:',
- 'data:',
- 'vbscript:',
- 'file:',
- 'about:'
- ]
+ dangerous_schemes = ['javascript:', 'data:', 'vbscript:', 'file:', 'about:']
url_lower = url.lower()
for scheme in dangerous_schemes:
if url_lower.startswith(scheme):
- return ""
+ return ''
return url
+
def strip_control_characters(text: str) -> str:
"""
Remove control characters that might cause issues.
@@ -61,6 +58,7 @@ def strip_control_characters(text: str) -> str:
return text
return ''.join(char for char in text if ord(char) >= 32 or char in '\n\r\t')
+
def sanitize_username(username: str) -> str:
"""
Sanitize username to only allow safe characters.
@@ -71,6 +69,7 @@ def sanitize_username(username: str) -> str:
safe_pattern = re.compile(r'[^a-zA-Z0-9_\-\.@]')
return safe_pattern.sub('', username)
+
def sanitize_api_name(name: str) -> str:
"""
Sanitize API name to prevent injection attacks.
diff --git a/backend-services/utils/security_settings_util.py b/backend-services/utils/security_settings_util.py
index 2d8026e..acd4a23 100644
--- a/backend-services/utils/security_settings_util.py
+++ b/backend-services/utils/security_settings_util.py
@@ -3,19 +3,19 @@ Utilities to manage security-related settings and schedule auto-save of memory d
"""
import asyncio
+import logging
import os
from pathlib import Path
-from typing import Optional, Dict, Any
-import logging
+from typing import Any
from .database import database, db
from .memory_dump_util import dump_memory_to_file
logger = logging.getLogger('doorman.gateway')
-_CACHE: Dict[str, Any] = {}
-_AUTO_TASK: Optional[asyncio.Task] = None
-_STOP_EVENT: Optional[asyncio.Event] = None
+_CACHE: dict[str, Any] = {}
+_AUTO_TASK: asyncio.Task | None = None
+_STOP_EVENT: asyncio.Event | None = None
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_GEN_DIR = _PROJECT_ROOT / 'generated'
@@ -34,31 +34,35 @@ DEFAULTS = {
SETTINGS_FILE = os.getenv('SECURITY_SETTINGS_FILE', str(_GEN_DIR / 'security_settings.json'))
+
def _get_collection():
return db.settings if not database.memory_only else database.db.settings
-def _merge_settings(doc: Dict[str, Any]) -> Dict[str, Any]:
+
+def _merge_settings(doc: dict[str, Any]) -> dict[str, Any]:
merged = DEFAULTS.copy()
if doc:
merged.update({k: v for k, v in doc.items() if v is not None})
return merged
-def get_cached_settings() -> Dict[str, Any]:
+
+def get_cached_settings() -> dict[str, Any]:
global _CACHE
if not _CACHE:
-
_CACHE = DEFAULTS.copy()
return _CACHE
-def _load_from_file() -> Optional[Dict[str, Any]]:
+
+def _load_from_file() -> dict[str, Any] | None:
try:
if not os.path.exists(SETTINGS_FILE):
return None
- with open(SETTINGS_FILE, 'r', encoding='utf-8') as f:
+ with open(SETTINGS_FILE, encoding='utf-8') as f:
data = f.read().strip()
if not data:
return None
import json
+
obj = json.loads(data)
if isinstance(obj, dict):
@@ -67,39 +71,41 @@ def _load_from_file() -> Optional[Dict[str, Any]]:
logger.error('Failed to read settings file %s: %s', SETTINGS_FILE, e)
return None
-def _save_to_file(settings: Dict[str, Any]) -> None:
+
+def _save_to_file(settings: dict[str, Any]) -> None:
try:
os.makedirs(os.path.dirname(SETTINGS_FILE) or '.', exist_ok=True)
import json
+
with open(SETTINGS_FILE, 'w', encoding='utf-8') as f:
f.write(json.dumps(settings, separators=(',', ':')))
except Exception as e:
logger.error('Failed to write settings file %s: %s', SETTINGS_FILE, e)
-async def load_settings() -> Dict[str, Any]:
+
+async def load_settings() -> dict[str, Any]:
coll = _get_collection()
doc = coll.find_one({'type': 'security_settings'})
if not doc and database.memory_only:
file_obj = _load_from_file()
if file_obj:
-
try:
to_set = _merge_settings(file_obj)
coll.update_one({'type': 'security_settings'}, {'$set': to_set})
doc = to_set
except Exception:
-
doc = file_obj
settings = _merge_settings(doc or {})
_CACHE.update(settings)
return settings
-async def save_settings(partial: Dict[str, Any]) -> Dict[str, Any]:
+
+async def save_settings(partial: dict[str, Any]) -> dict[str, Any]:
coll = _get_collection()
current = _merge_settings(coll.find_one({'type': 'security_settings'}) or {})
current.update({k: v for k, v in partial.items() if v is not None})
- result = coll.update_one({'type': 'security_settings'}, {'$set': current},)
+ result = coll.update_one({'type': 'security_settings'}, {'$set': current})
try:
modified = getattr(result, 'modified_count', 0)
@@ -113,6 +119,7 @@ async def save_settings(partial: Dict[str, Any]) -> Dict[str, Any]:
await restart_auto_save_task()
return current
+
async def _auto_save_loop(stop_event: asyncio.Event):
while not stop_event.is_set():
try:
@@ -127,12 +134,13 @@ async def _auto_save_loop(stop_event: asyncio.Event):
logger.error('Auto-save memory dump failed: %s', e)
await asyncio.wait_for(stop_event.wait(), timeout=max(freq, 60) if freq > 0 else 60)
- except asyncio.TimeoutError:
+ except TimeoutError:
continue
except Exception as e:
logger.error('Auto-save loop error: %s', e)
await asyncio.sleep(60)
+
async def start_auto_save_task():
global _AUTO_TASK, _STOP_EVENT
if _AUTO_TASK and not _AUTO_TASK.done():
@@ -141,6 +149,7 @@ async def start_auto_save_task():
_AUTO_TASK = asyncio.create_task(_auto_save_loop(_STOP_EVENT))
logger.info('Security auto-save task started')
+
async def stop_auto_save_task():
global _AUTO_TASK, _STOP_EVENT
if _STOP_EVENT:
@@ -153,6 +162,7 @@ async def stop_auto_save_task():
_AUTO_TASK = None
_STOP_EVENT = None
+
async def restart_auto_save_task():
await stop_auto_save_task()
await start_auto_save_task()
diff --git a/backend-services/utils/subscription_util.py b/backend-services/utils/subscription_util.py
index fa4a174..2cdeaa4 100644
--- a/backend-services/utils/subscription_util.py
+++ b/backend-services/utils/subscription_util.py
@@ -4,17 +4,19 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""
-from fastapi import HTTPException, Depends, Request
-from jose import jwt, JWTError
import logging
-from utils.doorman_cache_util import doorman_cache
+from fastapi import HTTPException, Request
+from jose import JWTError
+
+from utils.async_db import db_find_one
+from utils.auth_util import auth_required
from utils.database_async import subscriptions_collection
-from utils.async_db import db_find_one, db_update_one
-from utils.auth_util import SECRET_KEY, ALGORITHM, auth_required
+from utils.doorman_cache_util import doorman_cache
logger = logging.getLogger('doorman.gateway')
+
async def subscription_required(request: Request):
try:
payload = await auth_required(request)
@@ -24,11 +26,11 @@ async def subscription_required(request: Request):
full_path = request.url.path
if full_path.startswith('/api/rest/'):
prefix = '/api/rest/'
- path = full_path[len(prefix):]
+ path = full_path[len(prefix) :]
api_and_version = '/'.join(path.split('/')[:2])
elif full_path.startswith('/api/soap/'):
prefix = '/api/soap/'
- path = full_path[len(prefix):]
+ path = full_path[len(prefix) :]
api_and_version = '/'.join(path.split('/')[:2])
elif full_path.startswith('/api/graphql/'):
api_name = full_path.replace('/api/graphql/', '')
@@ -45,8 +47,14 @@ async def subscription_required(request: Request):
api_and_version = '/'.join(segs[2:4])
else:
api_and_version = '/'.join(segs[:2])
- user_subscriptions = doorman_cache.get_cache('user_subscription_cache', username) or await db_find_one(subscriptions_collection, {'username': username})
- subscriptions = user_subscriptions.get('apis') if user_subscriptions and 'apis' in user_subscriptions else None
+ user_subscriptions = doorman_cache.get_cache(
+ 'user_subscription_cache', username
+ ) or await db_find_one(subscriptions_collection, {'username': username})
+ subscriptions = (
+ user_subscriptions.get('apis')
+ if user_subscriptions and 'apis' in user_subscriptions
+ else None
+ )
if not subscriptions or api_and_version not in subscriptions:
logger.info(f'User {username} attempted access to {api_and_version}')
raise HTTPException(status_code=403, detail='You are not subscribed to this resource')
diff --git a/backend-services/utils/validation_util.py b/backend-services/utils/validation_util.py
index bd45da3..e8ef861 100644
--- a/backend-services/utils/validation_util.py
+++ b/backend-services/utils/validation_util.py
@@ -4,26 +4,31 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-from typing import Dict, Any, Optional, Callable
-from fastapi import HTTPException
-import json
import re
-from datetime import datetime
import uuid
+from collections.abc import Callable
+from datetime import datetime
+from typing import Any
+
+from fastapi import HTTPException
+
try:
from defusedxml import ElementTree as ET
+
_DEFUSED = True
except Exception:
import xml.etree.ElementTree as ET
+
_DEFUSED = False
-from graphql import parse, GraphQLError
import grpc
+from graphql import GraphQLError, parse
from models.field_validation_model import FieldValidation
from models.validation_schema_model import ValidationSchema
-from utils.doorman_cache_util import doorman_cache
-from utils.database_async import endpoint_validation_collection
from utils.async_db import db_find_one
+from utils.database_async import endpoint_validation_collection
+from utils.doorman_cache_util import doorman_cache
+
class ValidationError(Exception):
def __init__(self, message: str, field_path: str):
@@ -31,6 +36,7 @@ class ValidationError(Exception):
self.field_path = field_path
super().__init__(self.message)
+
class ValidationUtil:
def __init__(self):
self.type_validators = {
@@ -38,16 +44,16 @@ class ValidationUtil:
'number': self._validate_number,
'boolean': self._validate_boolean,
'array': self._validate_array,
- 'object': self._validate_object
+ 'object': self._validate_object,
}
self.format_validators = {
'email': self._validate_email,
'url': self._validate_url,
'date': self._validate_date,
'datetime': self._validate_datetime,
- 'uuid': self._validate_uuid
+ 'uuid': self._validate_uuid,
}
- self.custom_validators: Dict[str, Callable] = {}
+ self.custom_validators: dict[str, Callable] = {}
# When defusedxml is unavailable, apply a basic pre-parse guard against DOCTYPE/ENTITY.
def _reject_unsafe_xml(self, xml_text: str) -> None:
@@ -62,10 +68,12 @@ class ValidationUtil:
if ' None:
+ def register_custom_validator(
+ self, name: str, validator: Callable[[Any, FieldValidation], None]
+ ) -> None:
self.custom_validators[name] = validator
- async def get_validation_schema(self, endpoint_id: str) -> Optional[ValidationSchema]:
+ async def get_validation_schema(self, endpoint_id: str) -> ValidationSchema | None:
"""Return the ValidationSchema for an endpoint_id if configured.
Looks up the in-memory cache first, then falls back to the DB collection.
@@ -75,7 +83,9 @@ class ValidationUtil:
"""
validation_doc = doorman_cache.get_cache('endpoint_validation_cache', endpoint_id)
if not validation_doc:
- validation_doc = await db_find_one(endpoint_validation_collection, {'endpoint_id': endpoint_id})
+ validation_doc = await db_find_one(
+ endpoint_validation_collection, {'endpoint_id': endpoint_id}
+ )
if validation_doc:
try:
vdoc = dict(validation_doc)
@@ -91,14 +101,20 @@ class ValidationUtil:
raw = validation_doc.get('validation_schema')
if not raw:
return None
- mapping = raw.get('validation_schema') if isinstance(raw, dict) and 'validation_schema' in raw else raw
+ mapping = (
+ raw.get('validation_schema')
+ if isinstance(raw, dict) and 'validation_schema' in raw
+ else raw
+ )
if not isinstance(mapping, dict):
return None
schema = ValidationSchema(validation_schema=mapping)
self._validate_schema_paths(schema.validation_schema)
return schema
- def _validate_schema_paths(self, schema: Dict[str, FieldValidation], parent_path: str = '') -> None:
+ def _validate_schema_paths(
+ self, schema: dict[str, FieldValidation], parent_path: str = ''
+ ) -> None:
for field_path, validation in schema.items():
full_path = f'{parent_path}.{field_path}' if parent_path else field_path
if not self._is_valid_field_path(full_path):
@@ -160,7 +176,9 @@ class ValidationUtil:
if field_validation.required and field_path not in value:
raise ValidationError(f'Required field {field_path} is missing', path)
if field_path in value:
- self._validate_value(value[field_path], field_validation, f'{path}.{field_path}')
+ self._validate_value(
+ value[field_path], field_validation, f'{path}.{field_path}'
+ )
def _validate_value(self, value: Any, validation: FieldValidation, field_path: str) -> None:
if validation.required and value is None:
@@ -175,7 +193,6 @@ class ValidationUtil:
try:
self.custom_validators[validation.custom_validator](value, validation)
except ValidationError as e:
-
raise ValidationError(e.message, field_path)
def _validate_email(self, value: str, validation: FieldValidation, path: str) -> None:
@@ -184,29 +201,32 @@ class ValidationUtil:
raise ValidationError('Invalid email format', path)
def _validate_url(self, value: str, validation: FieldValidation, path: str) -> None:
- url_pattern = r'^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$'
+ url_pattern = (
+ r'^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
+ r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$'
+ )
if not re.match(url_pattern, value):
raise ValidationError('Invalid URL format', path)
def _validate_date(self, value: str, validation: FieldValidation, path: str) -> None:
try:
datetime.strptime(value, '%Y-%m-%d')
- except ValueError:
- raise ValidationError('Invalid date format (YYYY-MM-DD)', path)
+ except ValueError as e:
+ raise ValidationError('Invalid date format (YYYY-MM-DD)', path) from e
def _validate_datetime(self, value: str, validation: FieldValidation, path: str) -> None:
try:
datetime.fromisoformat(value.replace('Z', '+00:00'))
- except ValueError:
- raise ValidationError('Invalid datetime format (ISO 8601)', path)
+ except ValueError as e:
+ raise ValidationError('Invalid datetime format (ISO 8601)', path) from e
def _validate_uuid(self, value: str, validation: FieldValidation, path: str) -> None:
try:
uuid.UUID(value)
- except ValueError:
- raise ValidationError('Invalid UUID format', path)
+ except ValueError as e:
+ raise ValidationError('Invalid UUID format', path) from e
- async def validate_rest_request(self, endpoint_id: str, request_data: Dict[str, Any]) -> None:
+ async def validate_rest_request(self, endpoint_id: str, request_data: dict[str, Any]) -> None:
schema = await self.get_validation_schema(endpoint_id)
if not schema:
return
@@ -216,8 +236,11 @@ class ValidationUtil:
self._validate_value(value, validation, field_path)
except ValidationError as e:
import logging
- logging.getLogger('doorman.gateway').error(f'Validation failed for {field_path}: {e}')
- raise HTTPException(status_code=400, detail=str(e))
+
+ logging.getLogger('doorman.gateway').error(
+ f'Validation failed for {field_path}: {e}'
+ )
+ raise HTTPException(status_code=400, detail=str(e)) from e
async def validate_soap_request(self, endpoint_id: str, soap_envelope: str) -> None:
schema = await self.get_validation_schema(endpoint_id)
@@ -235,9 +258,10 @@ class ValidationUtil:
value = self._get_nested_value(request_data, field_path)
self._validate_value(value, validation, field_path)
except ValidationError as e:
- raise HTTPException(status_code=400, detail=str(e))
- except ET.ParseError:
- raise HTTPException(status_code=400, detail='Invalid SOAP envelope')
+ raise HTTPException(status_code=400, detail=str(e)) from e
+ except ET.ParseError as e:
+ raise HTTPException(status_code=400, detail='Invalid SOAP envelope') from e
+
async def validate_grpc_request(self, endpoint_id: str, request: Any) -> None:
schema = await self.get_validation_schema(endpoint_id)
if not schema:
@@ -248,9 +272,11 @@ class ValidationUtil:
value = self._get_nested_value(request_data, field_path)
self._validate_value(value, validation, field_path)
except ValidationError as e:
- raise grpc.RpcError(grpc.StatusCode.INVALID_ARGUMENT, str(e))
+ raise grpc.RpcError(grpc.StatusCode.INVALID_ARGUMENT, str(e)) from e
- async def validate_graphql_request(self, endpoint_id: str, query: str, variables: Dict[str, Any]) -> None:
+ async def validate_graphql_request(
+ self, endpoint_id: str, query: str, variables: dict[str, Any]
+ ) -> None:
schema = await self.get_validation_schema(endpoint_id)
if not schema:
return
@@ -261,19 +287,22 @@ class ValidationUtil:
for field_path, validation in schema.validation_schema.items():
if field_path.startswith(operation_name):
try:
- value = self._get_nested_value(variables, field_path[len(operation_name)+1:])
+ value = self._get_nested_value(
+ variables, field_path[len(operation_name) + 1 :]
+ )
self._validate_value(value, validation, field_path)
except ValidationError as e:
- raise HTTPException(status_code=400, detail=str(e))
+ raise HTTPException(status_code=400, detail=str(e)) from e
except GraphQLError as e:
- raise HTTPException(status_code=400, detail=str(e))
+ raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
- raise HTTPException(status_code=400, detail=str(e))
- def _extract_operation_name(self, query: str) -> Optional[str]:
+ raise HTTPException(status_code=400, detail=str(e)) from e
+
+ def _extract_operation_name(self, query: str) -> str | None:
match = re.search(r'(?:query|mutation)\s+(\w+)', query)
return match.group(1) if match else None
- def _get_nested_value(self, data: Dict[str, Any], field_path: str) -> Any:
+ def _get_nested_value(self, data: dict[str, Any], field_path: str) -> Any:
parts = field_path.split('.')
current = data
for part in parts:
@@ -298,7 +327,7 @@ class ValidationUtil:
return tag.split('}', 1)[1]
return tag
- def _xml_to_dict(self, element: Any) -> Dict[str, Any]:
+ def _xml_to_dict(self, element: Any) -> dict[str, Any]:
result = {}
for child in element:
key = self._strip_ns(child.tag)
@@ -308,7 +337,7 @@ class ValidationUtil:
result[key] = child.text
return result
- def _protobuf_to_dict(self, message: Any) -> Dict[str, Any]:
+ def _protobuf_to_dict(self, message: Any) -> dict[str, Any]:
result = {}
for field in message.DESCRIPTOR.fields:
value = getattr(message, field.name)
@@ -321,4 +350,5 @@ class ValidationUtil:
result[field.name] = value
return result
+
validation_util = ValidationUtil()
diff --git a/backend-services/utils/vault_encryption_util.py b/backend-services/utils/vault_encryption_util.py
index ac78642..e1ba370 100644
--- a/backend-services/utils/vault_encryption_util.py
+++ b/backend-services/utils/vault_encryption_util.py
@@ -4,14 +4,15 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""
-import os
+import base64
import hashlib
+import logging
+import os
+
from cryptography.fernet import Fernet
+from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
-from cryptography.hazmat.backends import default_backend
-import base64
-import logging
logger = logging.getLogger('doorman.gateway')
@@ -19,31 +20,31 @@ logger = logging.getLogger('doorman.gateway')
def _derive_key_from_components(email: str, username: str, vault_key: str) -> bytes:
"""
Derive a Fernet-compatible encryption key from email, username, and vault key.
-
+
Args:
email: User's email address
username: User's username
vault_key: VAULT_KEY from environment
-
+
Returns:
bytes: 32-byte key suitable for Fernet encryption
"""
# Combine all components to create a unique salt
- combined = f"{email}:{username}:{vault_key}"
+ combined = f'{email}:{username}:{vault_key}'
salt = hashlib.sha256(combined.encode()).digest()
-
+
# Use PBKDF2 to derive a key
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
- backend=default_backend()
+ backend=default_backend(),
)
-
+
# Derive key from the vault_key
key = kdf.derive(vault_key.encode())
-
+
# Encode to base64 for Fernet compatibility
return base64.urlsafe_b64encode(key)
@@ -51,15 +52,15 @@ def _derive_key_from_components(email: str, username: str, vault_key: str) -> by
def encrypt_vault_value(value: str, email: str, username: str) -> str:
"""
Encrypt a vault value using email, username, and VAULT_KEY from environment.
-
+
Args:
value: The plaintext value to encrypt
email: User's email address
username: User's username
-
+
Returns:
str: Base64-encoded encrypted value
-
+
Raises:
RuntimeError: If VAULT_KEY is not configured
ValueError: If encryption fails
@@ -67,36 +68,36 @@ def encrypt_vault_value(value: str, email: str, username: str) -> str:
vault_key = os.getenv('VAULT_KEY')
if not vault_key:
raise RuntimeError('VAULT_KEY is not configured in environment variables')
-
+
try:
# Derive encryption key
encryption_key = _derive_key_from_components(email, username, vault_key)
-
+
# Create Fernet cipher
cipher = Fernet(encryption_key)
-
+
# Encrypt the value
encrypted_bytes = cipher.encrypt(value.encode('utf-8'))
-
+
# Return as base64 string
return encrypted_bytes.decode('utf-8')
except Exception as e:
logger.error(f'Encryption failed: {str(e)}')
- raise ValueError(f'Failed to encrypt vault value: {str(e)}')
+ raise ValueError(f'Failed to encrypt vault value: {str(e)}') from e
def decrypt_vault_value(encrypted_value: str, email: str, username: str) -> str:
"""
Decrypt a vault value using email, username, and VAULT_KEY from environment.
-
+
Args:
encrypted_value: The base64-encoded encrypted value
email: User's email address
username: User's username
-
+
Returns:
str: Decrypted plaintext value
-
+
Raises:
RuntimeError: If VAULT_KEY is not configured
ValueError: If decryption fails
@@ -104,28 +105,28 @@ def decrypt_vault_value(encrypted_value: str, email: str, username: str) -> str:
vault_key = os.getenv('VAULT_KEY')
if not vault_key:
raise RuntimeError('VAULT_KEY is not configured in environment variables')
-
+
try:
# Derive encryption key
encryption_key = _derive_key_from_components(email, username, vault_key)
-
+
# Create Fernet cipher
cipher = Fernet(encryption_key)
-
+
# Decrypt the value
decrypted_bytes = cipher.decrypt(encrypted_value.encode('utf-8'))
-
+
# Return as string
return decrypted_bytes.decode('utf-8')
except Exception as e:
logger.error(f'Decryption failed: {str(e)}')
- raise ValueError(f'Failed to decrypt vault value: {str(e)}')
+ raise ValueError(f'Failed to decrypt vault value: {str(e)}') from e
def is_vault_configured() -> bool:
"""
Check if VAULT_KEY is configured in environment variables.
-
+
Returns:
bool: True if VAULT_KEY is set, False otherwise
"""
diff --git a/pyproject.toml b/pyproject.toml
index 00db760..17ab294 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,7 @@ select = [
"E", "F", "I", "UP", "B", "W"
]
ignore = [
+ "B904" # Exception chaining is a style preference, not a functional requirement
]
[tool.ruff.lint.isort]