mirror of https://github.com/apidoorman/doorman.git (synced 2026-05-12 11:58:25 -05:00)

code cleanup
@@ -59,8 +59,8 @@ REDIS_DB=0

# Memory Dump Config (memory-only mode)
# Base path/stem for encrypted in-memory database dumps (.bin). Timestamp is appended.
# Example produces files like generated/memory_dump-YYYYMMDDTHHMMSSZ.bin
MEM_DUMP_PATH=generated/memory_dump.bin
# Example produces files like backend-services/generated/memory_dump-YYYYMMDDTHHMMSSZ.bin
MEM_DUMP_PATH=backend-services/generated/memory_dump.bin

# Authorization Config
JWT_SECRET_KEY=please-change-me # REQUIRED: app will fail to start without this

+22 -156
@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from datetime import datetime, timedelta
from logging.handlers import RotatingFileHandler
from fastapi import FastAPI, Request
@@ -31,16 +30,13 @@ import uuid
import shutil
from pathlib import Path

# Compatibility guard: ensure aiohttp is a Python 3.13–compatible version before
# downstream modules import it (e.g., gateway_service). This avoids a cryptic
# regex error inside older aiohttp builds on 3.13.
try:
    if sys.version_info >= (3, 13):
        try:
            from importlib.metadata import version, PackageNotFoundError  # type: ignore
        except Exception:  # pragma: no cover
            version = None  # type: ignore
            PackageNotFoundError = Exception  # type: ignore
        from importlib.metadata import version, PackageNotFoundError
        except Exception:
            version = None
            PackageNotFoundError = Exception
        if version is not None:
            try:
                v = version('aiohttp')
@@ -59,7 +55,6 @@ try:
except Exception:
    pass
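The guard's core move is a metadata lookup plus a tuple comparison on the parsed version string. A self-contained sketch of that check — the `(3, 9)` threshold is illustrative, since the diff does not show which aiohttp version the gateway actually requires:

```python
import sys
from importlib.metadata import version, PackageNotFoundError

def aiohttp_ok_for_py313(minimum: tuple[int, int] = (3, 9)) -> bool:
    """On Python 3.13+, report whether the installed aiohttp is new enough."""
    if sys.version_info < (3, 13):
        return True                  # older interpreters: no guard needed
    try:
        raw = version('aiohttp')     # e.g. '3.9.5'
    except PackageNotFoundError:
        return True                  # not installed: nothing to guard
    major, minor = (int(p) for p in raw.split('.')[:2])
    return (major, minor) >= minimum
```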
# Internal imports
from models.response_model import ResponseModel
from utils.cache_manager_util import cache_manager
from utils.auth_blacklist import purge_expired_tokens
@@ -95,7 +90,6 @@ from utils.ip_policy_util import _get_client_ip as _policy_get_client_ip, _ip_in

load_dotenv()

# Normalize generated/ location and migrate any legacy files
try:
    _migrate_generated_directory()
except Exception:
@@ -117,7 +111,6 @@ def _migrate_generated_directory() -> None:
    if src == dst:
        return
    if not src.exists() or not src.is_dir():
        # Nothing to migrate; ensure dst exists
        dst.mkdir(exist_ok=True)
        gateway_logger.info(f"Generated dir: {dst} (no migration needed)")
        return
@@ -138,7 +131,6 @@ def _migrate_generated_directory() -> None:
        except Exception:
            continue
        moved_count += 1
    # Attempt to remove the now-empty src tree
    try:
        shutil.rmtree(src)
    except Exception:
@@ -154,18 +146,16 @@ async def validate_database_connections():
    """Validate database connections on startup with retry logic"""
    gateway_logger.info("Validating database connections...")

    # Test MongoDB
    max_retries = 3
    for attempt in range(max_retries):
        try:
            from utils.database import user_collection
            # Simple query to verify connection
            await user_collection.find_one({})
            gateway_logger.info("✓ MongoDB connection verified")
            break
        except Exception as e:
            if attempt < max_retries - 1:
                wait = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                wait = 2 ** attempt
                gateway_logger.warning(
                    f"MongoDB connection attempt {attempt + 1}/{max_retries} failed: {e}"
                )
@@ -177,7 +167,6 @@ async def validate_database_connections():
                    f"Cannot connect to MongoDB: {e}"
                ) from e

    # Test Redis (if configured)
    redis_host = os.getenv('REDIS_HOST')
    mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM')
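The removed comment documented the retry cadence (1s, 2s, 4s for `2 ** attempt`); the behavior itself is unchanged. A minimal sketch of the same startup pattern — exponential backoff around an async connectivity probe, where `probe` and `max_retries` are illustrative names, not Doorman's API:

```python
import asyncio

async def connect_with_backoff(probe, max_retries: int = 3) -> None:
    """Retry an async connectivity probe with exponential backoff."""
    for attempt in range(max_retries):
        try:
            await probe()          # e.g. a cheap find_one({}) against MongoDB
            return
        except Exception as exc:
            if attempt == max_retries - 1:
                raise RuntimeError(f'Cannot connect: {exc}') from exc
            await asyncio.sleep(2 ** attempt)   # 1s, 2s, 4s for attempts 0..2
```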
@@ -249,10 +238,8 @@ async def app_lifespan(app: FastAPI):
    if not os.getenv('JWT_SECRET_KEY'):
        raise RuntimeError('JWT_SECRET_KEY is not configured. Set it before starting the server.')

    # Production environment validation
    try:
        if os.getenv('ENV', '').lower() == 'production':
            # Validate HTTPS
            https_only = os.getenv('HTTPS_ONLY', 'false').lower() == 'true'
            https_enabled = os.getenv('HTTPS_ENABLED', 'false').lower() == 'true'
            if not (https_only or https_enabled):
@@ -260,7 +247,6 @@ async def app_lifespan(app: FastAPI):
                    'In production (ENV=production), you must enable HTTPS_ONLY or HTTPS_ENABLED to enforce Secure cookies.'
                )

            # Validate SSL certificates exist before starting
            if https_only or https_enabled:
                cert = os.getenv('SSL_CERTFILE')
                key = os.getenv('SSL_KEYFILE')
@@ -273,7 +259,6 @@ async def app_lifespan(app: FastAPI):
                if key and not os.path.exists(key):
                    raise RuntimeError(f'SSL private key not found: {key}')

            # Validate JWT secret is not default
            jwt_secret = os.getenv('JWT_SECRET_KEY', '')
            if jwt_secret in ('please-change-me', 'test-secret-key', 'test-secret-key-please-change', ''):
                raise RuntimeError(
@@ -281,7 +266,6 @@ async def app_lifespan(app: FastAPI):
                    'Generate a strong random secret (32+ characters).'
                )

            # Validate Redis for HA deployments (shared token revocation and rate limiting)
            mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM').upper()
            if mem_or_external == 'MEM':
                num_threads = int(os.getenv('THREADS', 1))
@@ -296,7 +280,6 @@ async def app_lifespan(app: FastAPI):
                        'Single-node only. For multi-node HA, use REDIS or EXTERNAL mode.'
                    )
            else:
                # Verify Redis is actually configured
                redis_host = os.getenv('REDIS_HOST')
                if not redis_host:
                    raise RuntimeError(
@@ -304,7 +287,6 @@ async def app_lifespan(app: FastAPI):
                        'Redis is essential for shared token revocation and rate limiting in HA deployments.'
                    )

            # Validate CORS security
            if os.getenv('CORS_STRICT', 'false').lower() != 'true':
                raise RuntimeError(
                    'In production (ENV=production), CORS_STRICT must be true. '
@@ -318,7 +300,6 @@ async def app_lifespan(app: FastAPI):
                    'Set ALLOWED_ORIGINS to specific domain(s): https://yourdomain.com'
                )

            # Validate TOKEN_ENCRYPTION_KEY for API key encryption
            token_encryption_key = os.getenv('TOKEN_ENCRYPTION_KEY', '')
            if not token_encryption_key or len(token_encryption_key) < 32:
                gateway_logger.warning(
@@ -326,7 +307,6 @@ async def app_lifespan(app: FastAPI):
                    'API keys will not be encrypted at rest. Highly recommended for production security.'
                )

            # Validate encryption keys if memory dumps are used
            if mem_or_external == 'MEM':
                mem_encryption_key = os.getenv('MEM_ENCRYPTION_KEY', '')
                if not mem_encryption_key or len(mem_encryption_key) < 32:
@@ -336,41 +316,39 @@ async def app_lifespan(app: FastAPI):
                        'Generate a strong random key: openssl rand -hex 32'
                    )
    except Exception as e:
        # Re-raise all RuntimeErrors (validation failures should stop startup)
        raise
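The checks above all follow one shape: read an env var, compare against a safe default, and fail fast with an actionable message. A condensed sketch of that fail-fast pattern — variable names mirror the diff, but the helper itself is illustrative:

```python
import os

def validate_production_env() -> None:
    """Fail fast when production is configured with unsafe defaults."""
    if os.getenv('ENV', '').lower() != 'production':
        return
    if os.getenv('JWT_SECRET_KEY', '') in ('', 'please-change-me'):
        raise RuntimeError('JWT_SECRET_KEY must be a strong random secret (32+ chars).')
    if os.getenv('CORS_STRICT', 'false').lower() != 'true':
        raise RuntimeError('CORS_STRICT must be true in production.')
    if not (os.getenv('HTTPS_ONLY', 'false').lower() == 'true'
            or os.getenv('HTTPS_ENABLED', 'false').lower() == 'true'):
        raise RuntimeError('Enable HTTPS_ONLY or HTTPS_ENABLED to enforce Secure cookies.')
```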
    # Configure Redis connection with authentication
    mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM').upper()
    redis_host = os.getenv('REDIS_HOST')
    redis_port = os.getenv('REDIS_PORT')
    redis_db = os.getenv('REDIS_DB')
    redis_password = os.getenv('REDIS_PASSWORD', '')

    # Warn if Redis is used without authentication in production/HA modes
    mem_or_external = os.getenv('MEM_OR_EXTERNAL', 'MEM').upper()
    if mem_or_external in ('REDIS', 'EXTERNAL') and not redis_password:
        gateway_logger.warning(
            'Redis password not set; connection may be unauthenticated. '
            'Set REDIS_PASSWORD environment variable to secure Redis access.'
        )

    # Build Redis URL with authentication if password is provided
    if redis_password:
        redis_url = f'redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}'
    if mem_or_external in ('REDIS', 'EXTERNAL'):
        if not redis_password:
            gateway_logger.warning(
                'Redis password not set; connection may be unauthenticated. '
                'Set REDIS_PASSWORD environment variable to secure Redis access.'
            )
        host = redis_host or 'localhost'
        port = redis_port or '6379'
        db = redis_db or '0'
        if redis_password:
            redis_url = f'redis://:{redis_password}@{host}:{port}/{db}'
        else:
            redis_url = f'redis://{host}:{port}/{db}'
        app.state.redis = Redis.from_url(redis_url, decode_responses=True)
    else:
        redis_url = f'redis://{redis_host}:{redis_port}/{redis_db}'

        app.state.redis = Redis.from_url(redis_url, decode_responses=True)
        app.state.redis = None
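The cleanup consolidates URL building: defaults are applied first, then the password decides between an authenticated and an anonymous URL, and memory mode skips Redis entirely (`app.state.redis = None`). A small standalone sketch of that construction, with illustrative names:

```python
def build_redis_url(host: str | None, port: str | None,
                    db: str | None, password: str = '') -> str:
    """Build a redis:// URL, embedding the password only when one is set."""
    host = host or 'localhost'
    port = port or '6379'
    db = db or '0'
    if password:
        return f'redis://:{password}@{host}:{port}/{db}'
    return f'redis://{host}:{port}/{db}'

assert build_redis_url(None, None, None) == 'redis://localhost:6379/0'
assert build_redis_url('cache', '6380', '1', 's3cret') == 'redis://:s3cret@cache:6380/1'
```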
    app.state._purger_task = asyncio.create_task(automatic_purger(1800))

    # Restore persisted metrics (if available)
    METRICS_FILE = os.path.join(LOGS_DIR, 'metrics.json')
    try:
        metrics_store.load_from_file(METRICS_FILE)
    except Exception as e:
        gateway_logger.debug(f'Metrics restore skipped: {e}')

    # Start periodic metrics saver
    async def _metrics_autosave(interval_s: int = 60):
        while True:
            try:
@@ -396,14 +374,12 @@ async def app_lifespan(app: FastAPI):
        if bool(settings.get('trust_x_forwarded_for')) and not (settings.get('xff_trusted_proxies') or []):
            gateway_logger.warning('Security: trust_x_forwarded_for enabled but xff_trusted_proxies is empty; header spoofing risk. Configure trusted proxy IPs/CIDRs.')

            # Production validation: enforce trusted proxy configuration
            if os.getenv('ENV', '').lower() == 'production':
                raise RuntimeError(
                    'Production deployment with trust_x_forwarded_for requires xff_trusted_proxies '
                    'to prevent IP spoofing. Configure trusted proxy IPs/CIDRs via /platform/security endpoint.'
                )
    except Exception as e:
        # Re-raise RuntimeErrors (production validation failures should stop startup)
        if isinstance(e, RuntimeError):
            raise
        gateway_logger.debug(f'Startup security checks skipped: {e}')
@@ -466,7 +442,6 @@ async def app_lifespan(app: FastAPI):

        pass

    # SIGHUP handler for configuration hot reload
    try:
        if hasattr(signal, 'SIGHUP'):
            loop = asyncio.get_event_loop()
@@ -475,10 +450,8 @@ async def app_lifespan(app: FastAPI):
                try:
                    gateway_logger.info('SIGHUP received: reloading configuration...')

                    # Reload hot config
                    hot_config.reload()

                    # Update log level if changed
                    log_level = hot_config.get('LOG_LEVEL', 'INFO')
                    try:
                        numeric_level = getattr(logging, log_level.upper(), logging.INFO)
@@ -494,7 +467,6 @@ async def app_lifespan(app: FastAPI):
            loop.add_signal_handler(signal.SIGHUP, lambda: asyncio.create_task(_sighup_reload()))
            gateway_logger.info('SIGHUP handler registered for configuration hot reload')
    except (NotImplementedError, AttributeError):
        # Windows doesn't support SIGHUP
        gateway_logger.debug('SIGHUP not supported on this platform')
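For readers unfamiliar with the hot-reload hook: `loop.add_signal_handler` is the asyncio-native way to react to SIGHUP without interrupting in-flight requests. A minimal, self-contained sketch, where `reload_config` is a stand-in for Doorman's `hot_config.reload()`:

```python
import asyncio
import signal

def install_sighup_reload(loop: asyncio.AbstractEventLoop, reload_config) -> None:
    """Register a SIGHUP handler that reloads config on the event loop."""
    if not hasattr(signal, 'SIGHUP'):       # e.g. Windows
        return
    async def _reload() -> None:
        reload_config()                     # swap configuration in place
    try:
        loop.add_signal_handler(signal.SIGHUP,
                                lambda: asyncio.ensure_future(_reload()))
    except (NotImplementedError, AttributeError):
        pass                                # platforms without signal support
```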
    try:
@@ -537,7 +509,6 @@ async def app_lifespan(app: FastAPI):
    except Exception as e:
        gateway_logger.error(f"Error closing HTTP client: {e}")

    # Persist metrics on shutdown
    try:
        METRICS_FILE = os.path.join(LOGS_DIR, 'metrics.json')
        metrics_store.save_to_file(METRICS_FILE)
@@ -545,7 +516,6 @@ async def app_lifespan(app: FastAPI):
        pass

    gateway_logger.info("Graceful shutdown complete")
    # Stop autosave task
    try:
        t = getattr(app.state, '_metrics_save_task', None)
        if t:
@@ -553,7 +523,6 @@ async def app_lifespan(app: FastAPI):
    except Exception:
        pass

    # Close shared HTTP client pool if enabled
    try:
        from services.gateway_service import GatewayService as _GS
        if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() != 'false':
@@ -586,8 +555,6 @@ doorman = FastAPI(
https_only = os.getenv('HTTPS_ONLY', 'false').lower() == 'true'
domain = os.getenv('COOKIE_DOMAIN', 'localhost')

# Replace global CORSMiddleware with path-aware CORS handling:
# - Platform routes (/platform/*): preserve env-based behavior for now
# - API gateway routes (/api/*): CORS controlled per-API in gateway routes/services

def _env_cors_config():
@@ -627,8 +594,6 @@ def _env_cors_config():

@doorman.middleware('http')
async def platform_cors(request: Request, call_next):
    # When ASGI-level CORS is disabled (e.g., Python 3.13 CI toggle), handle
    # platform CORS at the request middleware layer for compatibility.
    try:
        if os.getenv('DISABLE_PLATFORM_CORS_ASGI', 'false').lower() in ('1','true','yes','on'):
            path = str(request.url.path)
@@ -641,7 +606,6 @@ async def platform_cors(request: Request, call_next):
                or ('*' in (cfg.get('origins') or []) and not strict)
            )

            # Handle preflight
            if request.method.upper() == 'OPTIONS':
                from fastapi.responses import Response as _Resp
                headers = {}
@@ -651,13 +615,11 @@ async def platform_cors(request: Request, call_next):
                    headers['Access-Control-Allow-Methods'] = ', '.join(cfg['methods'])
                    headers['Access-Control-Allow-Headers'] = ', '.join(cfg['headers'])
                    headers['Access-Control-Allow-Credentials'] = 'true' if cfg['credentials'] else 'false'
                # Preserve inbound request id if present
                rid = request.headers.get('x-request-id') or request.headers.get('X-Request-ID')
                if rid:
                    headers['request_id'] = rid
                return _Resp(status_code=204, headers=headers)

            # Normal request path: call downstream then inject headers
            response = await call_next(request)
            try:
                response.headers.setdefault('Access-Control-Allow-Credentials', 'true' if cfg['credentials'] else 'false')
@@ -668,12 +630,9 @@ async def platform_cors(request: Request, call_next):
                pass
            return response
    except Exception:
        # Fall back to app
        pass
    # Default path: let ASGI-level middleware handle CORS
    return await call_next(request)

# Body size limit middleware (protects against both Content-Length and Transfer-Encoding: chunked)
MAX_BODY_SIZE = int(os.getenv('MAX_BODY_SIZE_BYTES', 1_048_576))

def _get_max_body_size() -> int:
@@ -699,7 +658,6 @@ class LimitedStreamReader:
        self.over_limit = False

    async def __call__(self):
        # If already over the limit, immediately end the request body for the app
        if self.over_limit:
            return {'type': 'http.request', 'body': b'', 'more_body': False}

@@ -710,7 +668,6 @@ class LimitedStreamReader:
        self.bytes_received += len(body)

        if self.bytes_received > self.max_size:
            # Mark as over-limit and end the request body stream for the app
            self.over_limit = True
            return {'type': 'http.request', 'body': b'', 'more_body': False}
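LimitedStreamReader works by impersonating the ASGI `receive` callable: every chunk is counted before the app sees it, and once the budget is spent it reports an empty, final body. A compact standalone sketch of the same idea — an assumed shape, not the exact Doorman class:

```python
class CountingReceive:
    """Wrap an ASGI receive callable and cut the body off past max_size bytes."""

    def __init__(self, receive, max_size: int):
        self.receive = receive
        self.max_size = max_size
        self.bytes_received = 0
        self.over_limit = False

    async def __call__(self) -> dict:
        if self.over_limit:
            return {'type': 'http.request', 'body': b'', 'more_body': False}
        message = await self.receive()
        body = message.get('body', b'') if message.get('type') == 'http.request' else b''
        self.bytes_received += len(body)
        if self.bytes_received > self.max_size:
            self.over_limit = True      # caller checks this flag and answers 413
            return {'type': 'http.request', 'body': b'', 'more_body': False}
        return message
```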
@@ -735,16 +692,10 @@ async def body_size_limit(request: Request, call_next):
    - /api/grpc/*: Enforce on gRPC JSON payloads
    """
    try:
        # Allow hard-disable for environments where ASGI/transport interactions
        # are problematic (e.g., CI, certain Python/Starlette combos)
        if os.getenv('DISABLE_BODY_SIZE_LIMIT', 'false').lower() in ('1','true','yes','on'):
            return await call_next(request)
        path = str(request.url.path)

        # Note: We no longer bypass general /platform/* routes here.
        # Enforcement applies to platform routes too (tests expect protection).
        # Allow excluding known-safe platform paths from size enforcement to
        # avoid transport/middleware edge-cases on certain runtimes.
        try:
            raw_excludes = os.getenv('BODY_LIMIT_EXCLUDE_PATHS', '')
            if raw_excludes:
@@ -754,28 +705,18 @@ async def body_size_limit(request: Request, call_next):
        except Exception:
            pass

        # Hard-coded bypass for platform monitor endpoints to avoid transport
        # edge-cases on some Starlette/AnyIO/Python combos (esp. 3.13) where
        # upstream short-circuits (e.g., IP filter) can surface as transport
        # errors. Size enforcement is not relevant for these probes.
        if path.startswith('/platform/monitor/'):
            return await call_next(request)

        # Skip enforcement for security settings to avoid Python 3.13 + Starlette
        # middleware interaction bug where endpoint returns successfully but
        # call_next raises anyio.EndOfStream. Auth + RBAC already protect this.
        if path == '/platform/security/settings':
            try:
                return await call_next(request)
            except Exception as e:
                # On Python 3.13, anyio.EndOfStream can be raised even when endpoint
                # completes successfully. Catch and return generic 500 to unblock tests.
                msg = str(e)
                if 'EndOfStream' in msg or 'No response returned' in msg:
                    try:
                        from models.response_model import ResponseModel as _RM
                        from utils.response_util import process_response as _pr
                        # Return a different error code so tests can identify this edge case
                        return _pr(_RM(
                            status_code=200,
                            message='Settings updated (middleware bypass)'
@@ -784,15 +725,12 @@ async def body_size_limit(request: Request, call_next):
                        pass
                raise

        # Determine if this path should be protected
        should_enforce = False
        default_limit = _get_max_body_size()
        limit = default_limit

        # Strictly enforce on auth route (prevent auth DoS)
        if path.startswith('/platform/authorization'):
            should_enforce = True
        # Enforce on all /api/* routes with per-type overrides
        elif path.startswith('/api/soap/'):
            should_enforce = True
            limit = int(os.getenv('MAX_BODY_SIZE_BYTES_SOAP', default_limit))
@@ -806,17 +744,13 @@ async def body_size_limit(request: Request, call_next):
            should_enforce = True
            limit = int(os.getenv('MAX_BODY_SIZE_BYTES_REST', default_limit))
        elif path.startswith('/api/'):
            # Catch-all for other /api/* routes
            should_enforce = True
        elif path.startswith('/platform/'):
            # Protect all platform routes (tests expect platform routes are protected)
            should_enforce = True

        # Skip if this path is not protected
        if not should_enforce:
            return await call_next(request)

        # Check Content-Length header first (fast path for non-chunked requests)
        cl = request.headers.get('content-length')
        transfer_encoding = request.headers.get('transfer-encoding', '').lower()

@@ -824,7 +758,6 @@ async def body_size_limit(request: Request, call_next):
            try:
                content_length = int(cl)
                if content_length > limit:
                    # Log for security monitoring
                    try:
                        from utils.audit_util import audit
                        audit(
@@ -849,43 +782,22 @@ async def body_size_limit(request: Request, call_next):
                        error_message=f'Request entity too large (max: {limit} bytes)'
                    ).dict(), 'rest')
            except (ValueError, TypeError):
                # Invalid Content-Length header - treat as potentially malicious
                pass

        # Handle Transfer-Encoding: chunked or missing Content-Length
        # Wrap the receive channel with size-limited reader
        if 'chunked' in transfer_encoding or not cl:
            # Optional hardening: If both chunked and Content-Length are present,
            # block immediately only when STRICT_CHUNKED_CL=true (off by default).
            # Always block when both chunked transfer and Content-Length appear
            # for mutating methods to prevent CL spoofing and ensure chunked
            # precedence. This avoids handler-level parsing of large bodies.
            # For chunked requests, ignore any Content-Length and rely on
            # streaming enforcement when wrapping is allowed. When wrapping is
            # disabled (e.g., via env or platform path), enforcement falls back
            # to Content-Length checks only.
            # Check if method typically has a body
            if request.method in ('POST', 'PUT', 'PATCH'):
                # On some Starlette/AnyIO versions (notably with Python 3.13),
                # swapping the low-level receive callable can cause middleware
                # stacks to raise "No response returned". To stay compatible,
                # only wrap streaming receive for API routes; for platform
                # routes rely on Content-Length enforcement above.
                wrap_allowed = True
                try:
                    env_flag = os.getenv('DISABLE_PLATFORM_CHUNKED_WRAP')
                    if isinstance(env_flag, str) and env_flag.strip() != '':
                        if env_flag.strip().lower() in ('1','true','yes','on'):
                            wrap_allowed = False
                    # Always allow streaming enforcement on login endpoint to
                    # guarantee size limits even under platform wrapping toggles
                    if str(path) == '/platform/authorization':
                        wrap_allowed = True
                except Exception:
                    pass

                if wrap_allowed:
                    # Replace request receive with limited reader (safe on API routes)
                    original_receive = request.receive
                    limited_reader = LimitedStreamReader(original_receive, limit)
                    request._receive = limited_reader
@@ -893,10 +805,8 @@ async def body_size_limit(request: Request, call_next):
                try:
                    response = await call_next(request)

                    # Check if limit was exceeded during streaming (only if we wrapped)
                    try:
                        if wrap_allowed and (limited_reader.over_limit or limited_reader.bytes_received > limit):
                            # Log for security monitoring
                            try:
                                from utils.audit_util import audit
                                audit(
@@ -925,7 +835,6 @@ async def body_size_limit(request: Request, call_next):

                    return response
                except Exception as e:
                    # If stream reading failed due to size limit (only if wrapped), return 413
                    try:
                        if wrap_allowed and (limited_reader.over_limit or limited_reader.bytes_received > limit):
                            return process_response(ResponseModel(
@@ -939,10 +848,6 @@ async def body_size_limit(request: Request, call_next):

        return await call_next(request)
    except Exception as e:
        # Be defensive: certain Starlette/AnyIO edge-cases (esp. on Python 3.13)
        # can raise EndOfStream/"No response returned" from deeper middleware
        # stacks. Propagating leaves the client hanging in tests. Instead,
        # return a well-formed 500 so the pipeline completes deterministically.
        try:
            from models.response_model import ResponseModel as _RM
            from utils.response_util import process_response as _pr
@@ -953,16 +858,13 @@ async def body_size_limit(request: Request, call_next):
        msg = str(e)
        gateway_logger.error(f'Body size limit middleware error: {msg}', exc_info=True)

        # Only swallow known transport errors; otherwise, re-raise
        swallow = False
        try:
            # RuntimeError("No response returned.") from Starlette
            if isinstance(e, RuntimeError) and 'No response returned' in msg:
                swallow = True
            else:
                # anyio.EndOfStream
                try:
                    import anyio  # type: ignore
                    import anyio
                    if isinstance(e, getattr(anyio, 'EndOfStream', tuple())):
                        swallow = True
                except Exception:
@@ -980,10 +882,8 @@ async def body_size_limit(request: Request, call_next):
        except Exception:
            pass

        # Fallback: re-raise unknown exceptions
        raise
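The fast path above rejects oversized uploads before any body bytes are read: parse the declared Content-Length, compare it against the per-route limit, and answer 413. A hedged standalone sketch of that header check (streaming enforcement is the receive wrapper shown earlier):

```python
def content_length_exceeds(headers: dict[str, str], limit: int) -> bool:
    """Return True when a declared Content-Length is over the limit.

    Chunked requests carry no Content-Length and must be policed while
    streaming instead; an unparsable header is treated as not declared.
    """
    cl = headers.get('content-length')
    if cl is None:
        return False
    try:
        return int(cl) > limit
    except (ValueError, TypeError):
        return False    # malformed header: fall through to streaming checks

assert content_length_exceeds({'content-length': '2048'}, 1024)
assert not content_length_exceeds({}, 1024)
```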
class PlatformCORSMiddleware:
    """ASGI-level CORS for /platform/* routes to avoid BaseHTTPMiddleware pitfalls.

@@ -994,7 +894,6 @@ class PlatformCORSMiddleware:
        self.app = app

    async def __call__(self, scope, receive, send):
        # Allow bypass via env if needed for CI stability
        try:
            if os.getenv('DISABLE_PLATFORM_CORS_ASGI', 'false').lower() in ('1','true','yes','on'):
                return await self.app(scope, receive, send)
@@ -1008,7 +907,6 @@ class PlatformCORSMiddleware:
                return await self.app(scope, receive, send)

            cfg = _env_cors_config()
            # Decode headers into a dict (lower-cased)
            hdrs = {}
            try:
                for k, v in (scope.get('headers') or []):
@@ -1026,7 +924,6 @@ class PlatformCORSMiddleware:
                    headers.append((b'access-control-allow-methods', ', '.join(cfg['methods']).encode('latin1')))
                    headers.append((b'access-control-allow-headers', ', '.join(cfg['headers']).encode('latin1')))
                    headers.append((b'access-control-allow-credentials', b'true' if cfg['credentials'] else b'false'))
                # Preserve incoming request id if present
                rid = hdrs.get('x-request-id')
                if rid:
                    headers.append((b'request_id', rid.encode('latin1')))
@@ -1049,13 +946,10 @@ class PlatformCORSMiddleware:

            return await self.app(scope, receive, send_wrapper)
        except Exception:
            # In case of unexpected error, fall back to the underlying app
            return await self.app(scope, receive, send)

# Register the ASGI middleware as outermost to reduce interaction issues
doorman.add_middleware(PlatformCORSMiddleware)
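The `send_wrapper` referenced above is the standard ASGI trick for header injection: intercept the `http.response.start` message and append CORS headers before passing it through. A minimal sketch of such a wrapper — illustrative, not the exact Doorman implementation:

```python
def make_cors_send_wrapper(send, origin: bytes):
    """Wrap an ASGI send callable, appending CORS headers at response start."""
    async def send_wrapper(message: dict) -> None:
        if message.get('type') == 'http.response.start':
            headers = list(message.get('headers') or [])
            headers.append((b'access-control-allow-origin', origin))
            headers.append((b'vary', b'Origin'))
            message = {**message, 'headers': headers}
        await send(message)    # everything else passes through untouched
    return send_wrapper
```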
# Request ID middleware: accept incoming X-Request-ID or generate one.
@doorman.middleware('http')
async def request_id_middleware(request: Request, call_next):
    try:
@@ -1072,7 +966,6 @@ async def request_id_middleware(request: Request, call_next):
    except Exception:
        pass

    # Set correlation ID for async task tracking
    try:
        from utils.correlation_util import set_correlation_id
        set_correlation_id(rid)
@@ -1088,7 +981,6 @@ async def request_id_middleware(request: Request, call_next):
        pass
    response = await call_next(request)
    try:
        # Always preserve/propagate the inbound Request ID
        response.headers['X-Request-ID'] = rid
        response.headers['request_id'] = rid
    except Exception as e:
@@ -1098,7 +990,6 @@ async def request_id_middleware(request: Request, call_next):
        gateway_logger.error(f'Request ID middleware error: {str(e)}', exc_info=True)
        raise

# Security headers (including HSTS when HTTPS is used)
@doorman.middleware('http')
async def security_headers(request: Request, call_next):
    response = await call_next(request)
@@ -1136,10 +1027,8 @@ to console so production environments (e.g., ECS/EKS/Lambda) still capture logs.
Respects LOG_FORMAT=json|plain.
"""

# Resolve logs directory: env override or default next to this file
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_env_logs_dir = os.getenv('LOGS_DIR')
# Default to backend-services/platform-logs
LOGS_DIR = os.path.abspath(_env_logs_dir) if _env_logs_dir else os.path.join(BASE_DIR, 'platform-logs')

# Build formatters
@@ -1192,39 +1081,30 @@ def configure_logger(logger_name):
    """

    PATTERNS = [
        # Authorization header (redact entire value: scheme + token)
        re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),

        # API key headers (redact entire value)
        re.compile(r'(?i)(x-api-key\s*[:=]\s*)([^;\r\n]+)'),
        re.compile(r'(?i)(api[_-]?key\s*[:=]\s*)([^;\r\n]+)'),
        re.compile(r'(?i)(api[_-]?secret\s*[:=]\s*)([^;\r\n]+)'),

        # Access and refresh tokens
        re.compile(r'(?i)(access[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
        re.compile(r'(?i)(refresh[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
        re.compile(r'(?i)(token\s*["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9_\-\.]{20,})(["\']?)'),

        # Passwords and secrets
        re.compile(r'(?i)(password\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n]+)(["\']?)'),
        re.compile(r'(?i)(secret\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
        re.compile(r'(?i)(client[_-]?secret\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),

        # Cookies and Set-Cookie: redact entire value
        re.compile(r'(?i)(cookie\s*[:=]\s*)([^;\r\n]+)'),
        re.compile(r'(?i)(set-cookie\s*[:=]\s*)([^;\r\n]+)'),

        # CSRF tokens
        re.compile(r'(?i)(x-csrf-token\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),
        re.compile(r'(?i)(csrf[_-]?token\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),

        # JWT tokens (eyJ... format)
        re.compile(r'\b(eyJ[a-zA-Z0-9_\-]+\.eyJ[a-zA-Z0-9_\-]+\.[a-zA-Z0-9_\-]+)\b'),

        # Session IDs
        re.compile(r'(?i)(session[_-]?id\s*["\']?\s*[:=]\s*["\']?)([^"\';\r\n\s]+)(["\']?)'),

        # Private keys (PEM format detection)
        re.compile(r'(-----BEGIN[A-Z\s]+PRIVATE KEY-----)(.*?)(-----END[A-Z\s]+PRIVATE KEY-----)', re.DOTALL),
    ]
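These patterns are applied from a logging filter that substitutes [REDACTED] for the captured secret group before a record reaches any handler. A minimal self-contained version of that technique, using two illustrative regexes rather than Doorman's full list:

```python
import logging
import re

class RedactionFilter(logging.Filter):
    """Replace secret values with [REDACTED] before records hit handlers."""

    PATTERNS = [
        re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
        re.compile(r'(?i)(x-api-key\s*[:=]\s*)([^;\r\n]+)'),
    ]

    def filter(self, record: logging.LogRecord) -> bool:
        msg = str(record.msg)
        for pat in self.PATTERNS:
            msg = pat.sub(r'\1[REDACTED]', msg)   # keep the prefix, drop the value
        record.msg = msg
        return True    # never drop the record, only rewrite it

logger = logging.getLogger('demo')
logger.addHandler(logging.StreamHandler())
logger.addFilter(RedactionFilter())
logger.warning('authorization: Bearer abc123')    # logs 'authorization: [REDACTED]'
```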
@@ -1235,10 +1115,8 @@ def configure_logger(logger_name):

            for pat in self.PATTERNS:
                if pat.groups == 3 and pat.flags & re.DOTALL:
                    # PEM private key pattern
                    red = pat.sub(r'\1[REDACTED]\3', red)
                elif pat.groups >= 2:
                    # Header patterns with prefix, value, and optional suffix
                    red = pat.sub(lambda m: (
                        m.group(1) +
                        '[REDACTED]' +
@@ -1249,7 +1127,6 @@ def configure_logger(logger_name):

            if red != msg:
                record.msg = red
                # Also update record.args if present
                if hasattr(record, 'args') and record.args:
                    try:
                        if isinstance(record.args, dict):
@@ -1273,15 +1150,12 @@ def configure_logger(logger_name):
    logger.addHandler(_file_handler)
    return logger

# Configure main loggers
gateway_logger = configure_logger('doorman.gateway')
logging_logger = configure_logger('doorman.logging')

# Dedicated audit trail logger (separate file handler)
audit_logger = logging.getLogger('doorman.audit')
audit_logger.setLevel(logging.INFO)
audit_logger.propagate = False
# Remove existing handlers
for h in audit_logger.handlers[:]:
    audit_logger.removeHandler(h)
try:
@@ -1293,7 +1167,6 @@ try:
        encoding='utf-8'
    )
    _audit_file.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    # Reuse the same redaction filters as gateway logger
    try:
        for eh in gateway_logger.handlers:
            for f in getattr(eh, 'filters', []):
@@ -1306,7 +1179,6 @@ except Exception as _e:
    console = logging.StreamHandler(stream=sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(JSONFormatter() if _fmt_is_json else logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    # Reuse the same redaction filters as gateway logger
    try:
        for eh in gateway_logger.handlers:
            for f in getattr(eh, 'filters', []):
@@ -1325,8 +1197,6 @@ class Settings(BaseSettings):

@doorman.middleware('http')
async def ip_filter_middleware(request: Request, call_next):
    try:
        # Exempt security settings endpoint from IP filtering to prevent chicken-and-egg
        # where admins can't update settings if their IP is blocked. Endpoint has auth + RBAC.
        path = str(request.url.path)
        if path == '/platform/security/settings':
            return await call_next(request)
@@ -1430,7 +1300,6 @@ async def metrics_middleware(request: Request, call_next):
        if username:
            from utils.bandwidth_util import add_usage, _get_user
            u = _get_user(username)
            # Track usage when limit is set unless explicitly disabled
            if u and u.get('bandwidth_limit_bytes') and u.get('bandwidth_limit_enabled') is not False:
                add_usage(username, int(bytes_in) + int(clen), u.get('bandwidth_limit_window') or 'day')
    except Exception:
@@ -1445,8 +1314,6 @@ async def automatic_purger(interval_seconds):
        await purge_expired_tokens()
        gateway_logger.info('Expired JWTs purged from blacklist.')

## Startup/shutdown handled by lifespan above

@doorman.exception_handler(JWTError)
async def jwt_exception_handler(exc: JWTError):
    return process_response(ResponseModel(
@@ -1488,7 +1355,6 @@ doorman.include_router(dashboard_router, prefix='/platform/dashboard', tags=['Da
doorman.include_router(memory_router, prefix='/platform', tags=['Memory'])
doorman.include_router(security_router, prefix='/platform', tags=['Security'])
doorman.include_router(monitor_router, prefix='/platform', tags=['Monitor'])
# Expose token management under both legacy and new prefixes
doorman.include_router(credit_router, prefix='/platform/credit', tags=['Credit'])
doorman.include_router(demo_router, prefix='/platform/demo', tags=['Demo'])
doorman.include_router(config_router, prefix='/platform', tags=['Config'])
@@ -2,10 +2,8 @@ import os

BASE_URL = os.getenv('DOORMAN_BASE_URL', 'http://localhost:5001').rstrip('/')
ADMIN_EMAIL = os.getenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev')
# For live tests, read from env or check parent .env file; default for dev
ADMIN_PASSWORD = os.getenv('DOORMAN_ADMIN_PASSWORD')
if not ADMIN_PASSWORD:
    # Try to read from parent .env file
    env_file = os.path.join(os.path.dirname(__file__), '..', '.env')
    if os.path.exists(env_file):
        with open(env_file) as f:
@@ -26,6 +24,5 @@ def require_env():
        missing.append('DOORMAN_BASE_URL')
    if not ADMIN_EMAIL:
        missing.append('DOORMAN_ADMIN_EMAIL')
    # Password defaults to a dev value; warn but do not fail hard
    if missing:
        raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")

@@ -4,7 +4,6 @@ import time
import pytest
import requests

# Ensure backend packages are importable when running from live-tests dir
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from config import BASE_URL, ADMIN_EMAIL, ADMIN_PASSWORD, require_env, STRICT_HEALTH

@@ -114,4 +114,3 @@ def start_soap_echo_server():

    return _ThreadedHTTPServer(Handler).start()

# Optional servers (GraphQL, gRPC) are set up inside tests conditionally to avoid hard deps here.

@@ -7,7 +7,6 @@ import pytest
from servers import start_rest_echo_server, start_soap_echo_server
from config import ENABLE_GRAPHQL, ENABLE_GRPC


def _find_port():
    s = socket.socket()
    s.bind(('127.0.0.1', 0))
@@ -15,7 +14,6 @@ def _find_port():
    s.close()
    return p


def test_bulk_public_rest_crud(client):
    srv = start_rest_echo_server()
    try:
@@ -54,7 +52,6 @@ def test_bulk_public_rest_crud(client):
    finally:
        srv.stop()


def test_bulk_public_soap_crud(client):
    srv = start_soap_echo_server()
    try:
@@ -99,7 +96,6 @@ def test_bulk_public_soap_crud(client):
    finally:
        srv.stop()


@pytest.mark.skipif(not ENABLE_GRAPHQL, reason='GraphQL disabled')
def test_bulk_public_graphql_crud(client):
    try:
@@ -196,7 +192,6 @@ def test_bulk_public_graphql_crud(client):
    except Exception:
        pass


import os as _os
_RUN_LIVE = _os.getenv('DOORMAN_RUN_LIVE', '0') in ('1','true','True')
@pytest.mark.skipif(not _RUN_LIVE, reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')
@@ -230,10 +225,9 @@ message DeleteReply { bool ok = 1; }

    base = client.base_url.rstrip('/')
    ts = int(time.time())
    for i in range(0): # disabled
    for i in range(0):
        api_name = f'pub-grpc-{ts}-{i}'
        api_version = 'v1'
        # Protobuf package identifiers must be valid identifiers: no dashes, etc.
        pkg = f'{api_name}_{api_version}'.replace('-', '_')
        with tempfile.TemporaryDirectory() as td:
            tmp = pathlib.Path(td)
@@ -277,7 +271,6 @@ message DeleteReply { bool ok = 1; }
            assert r.status_code in (200, 201), r.text
            r = client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': api_version, 'endpoint_method': 'POST', 'endpoint_uri': '/grpc', 'endpoint_description': 'grpc'})
            assert r.status_code in (200, 201), r.text
            # call CRUD via gateway
            url = f"{base}/api/grpc/{api_name}"
            hdr = {'X-API-Version': api_version}
            pass

@@ -5,7 +5,6 @@ import pytest

from config import ENABLE_GRPC


def _find_port() -> int:
    s = socket.socket()
    s.bind(('127.0.0.1', 0))
@@ -13,7 +12,6 @@ def _find_port() -> int:
    s.close()
    return p


@pytest.mark.skipif(not ENABLE_GRPC, reason='gRPC disabled')
def test_public_grpc_with_proto_upload(client):
    try:
@@ -45,12 +43,10 @@ message DeleteReply { bool ok = 1; }

    base = client.base_url.rstrip('/')
    ts = int(time.time())
    # Avoid hyphens in api_name to ensure valid proto package identifiers
    api_name = f'grpcdemo{ts}'
    api_version = 'v1'
    pkg = f'{api_name}_{api_version}'

    # Build and start a matching gRPC server for the uploaded proto
    with tempfile.TemporaryDirectory() as td:
        tmp = pathlib.Path(td)
        (tmp / 'svc.proto').write_text(PROTO.replace('{pkg}', pkg))
@@ -84,12 +80,10 @@ message DeleteReply { bool ok = 1; }
        time.sleep(0.2)

        try:
            # Upload proto to gateway so it generates matching stubs
            files = {'file': ('svc.proto', PROTO.replace('{pkg}', pkg), 'application/octet-stream')}
            r_up = client.post(f'/platform/proto/{api_name}/{api_version}', files=files)
            assert r_up.status_code in (200, 201), r_up.text

            # Create API and endpoint pointing at our test gRPC server
            r_api = client.post('/platform/api', json={
                'api_name': api_name,
                'api_version': api_version,
@@ -100,7 +94,6 @@ message DeleteReply { bool ok = 1; }
                'api_type': 'REST',
                'active': True,
                'api_public': True,
                # Explicit package for clarity; gateway will prefer API config
                'api_grpc_package': pkg
            })
            assert r_api.status_code in (200, 201), r_api.text
@@ -114,7 +107,6 @@ message DeleteReply { bool ok = 1; }
            })
            assert r_ep.status_code in (200, 201), r_ep.text

            # Exercise CRUD via gateway
            url = f"{base}/api/grpc/{api_name}"
            hdr = {'X-API-Version': api_version}
            assert requests.post(url, json={'method': 'Resource.Create', 'message': {'name': 'A'}}, headers=hdr).status_code == 200

@@ -78,7 +78,6 @@ def test_graphql_gateway_basic_flow(client):
    r = client.post(f'/api/graphql/{api_name}', json=q, headers={'X-API-Version': api_version})
    assert r.status_code == 200, r.text
    data = r.json().get('response', r.json())
    # GraphQL response is nested under 'data' key
    if isinstance(data, dict) and 'data' in data:
        data = data['data']
    assert data.get('hello') == 'Hello, Doorman!'

@@ -4,7 +4,6 @@ from types import SimpleNamespace

from utils.ip_policy_util import _ip_in_list, _get_client_ip, enforce_api_ip_policy

# Override autouse integration fixture with a no-op so we don't require a live backend
@pytest.fixture(autouse=True, scope='session')
def ensure_session_and_relaxed_limits():
    yield
@@ -1,56 +1,45 @@
import time
import pytest


@pytest.mark.order(-10)
def test_redis_outage_during_requests(client):
    # Warm up a platform endpoint that touches cache minimally
    r = client.get('/platform/authorization/status')
    assert r.status_code in (200, 204)

    # Trigger redis outage for a short duration
    r = client.post('/platform/tools/chaos/toggle', json={'backend': 'redis', 'enabled': True, 'duration_ms': 1500})
    assert r.status_code == 200

    t0 = time.time()
    # During outage: app should not block; responses should come back quickly
    r1 = client.get('/platform/authorization/status')
    dt1 = time.time() - t0
    assert dt1 < 2.0, f'request blocked too long during redis outage: {dt1}s'
    assert r1.status_code in (200, 204, 500, 503)

    # Wait for auto-recover
    time.sleep(2.0)
    r2 = client.get('/platform/authorization/status')
    assert r2.status_code in (200, 204)

    # Check error budget burn recorded
    s = client.get('/platform/tools/chaos/stats')
    assert s.status_code == 200
    js = s.json()
    data = js.get('response', js)
    assert isinstance(data.get('error_budget_burn'), int)


@pytest.mark.order(-9)
def test_mongo_outage_during_requests(client):
    # Ensure a DB-backed endpoint is hit (user profile)
    t0 = time.time()
    r0 = client.get('/platform/user/me')
    assert r0.status_code in (200, 204)

    # Simulate mongo outage and immediately hit the same endpoint
    r = client.post('/platform/tools/chaos/toggle', json={'backend': 'mongo', 'enabled': True, 'duration_ms': 1500})
    assert r.status_code == 200

    t1 = time.time()
    r1 = client.get('/platform/user/me')
    dt1 = time.time() - t1
    # Do not block the event loop excessively; return fast with error if needed
    assert dt1 < 2.0, f'request blocked too long during mongo outage: {dt1}s'
    assert r1.status_code in (200, 400, 401, 403, 404, 500)

    # After recovery window
    time.sleep(2.0)
    r2 = client.get('/platform/user/me')
    assert r2.status_code in (200, 204)
@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


def test_api_cors_allow_origins_allow_methods_headers_credentials_expose_live(client):
    import time
    api_name = f'corslive-{int(time.time())}'
@@ -27,12 +26,10 @@ def test_api_cors_allow_origins_allow_methods_headers_credentials_expose_live(cl
    })
    client.post('/platform/endpoint', json={'api_name': api_name, 'api_version': ver, 'endpoint_method': 'GET', 'endpoint_uri': '/q', 'endpoint_description': 'q'})
    client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': api_name, 'api_version': ver})
    # Preflight
    r = client.options(f'/api/rest/{api_name}/{ver}/q', headers={'Origin': 'http://ok.example', 'Access-Control-Request-Method': 'GET', 'Access-Control-Request-Headers': 'X-CSRF-Token'})
    assert r.status_code == 204
    assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example'
    assert 'GET' in (r.headers.get('Access-Control-Allow-Methods') or '')
    # Actual
    r2 = client.get(f'/api/rest/{api_name}/{ver}/q', headers={'Origin': 'http://ok.example'})
    assert r2.status_code in (200, 404)
    assert r2.headers.get('Access-Control-Allow-Origin') == 'http://ok.example'

@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


def test_bandwidth_limit_enforced_and_window_resets_live(client):
    name, ver = 'bwlive', 'v1'
    client.post('/platform/api', json={

@@ -1,12 +1,10 @@
import os
import pytest

# Enable by running with DOORMAN_RUN_LIVE=1
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


async def _setup(client, name='gllive', ver='v1'):
    await client.post('/platform/api', json={
        'api_name': name,
@@ -28,7 +26,6 @@ async def _setup(client, name='gllive', ver='v1'):
    })
    return name, ver


@pytest.mark.asyncio
async def test_graphql_client_fallback_to_httpx_live(monkeypatch, authed_client):
    import services.gateway_service as gs
@@ -54,7 +51,6 @@ async def test_graphql_client_fallback_to_httpx_live(monkeypatch, authed_client)
    r = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ ping }', 'variables': {}})
    assert r.status_code == 200 and r.json().get('ok') is True


@pytest.mark.asyncio
async def test_graphql_errors_live_strict_and_loose(monkeypatch, authed_client):
    import services.gateway_service as gs
@@ -77,11 +73,9 @@ async def test_graphql_errors_live_strict_and_loose(monkeypatch, authed_client):
            return FakeHTTPResp({'errors': [{'message': 'boom'}]})
    monkeypatch.setattr(gs, 'Client', Dummy)
    monkeypatch.setattr(gs.httpx, 'AsyncClient', H)
    # Loose
    monkeypatch.delenv('STRICT_RESPONSE_ENVELOPE', raising=False)
    r1 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ err }', 'variables': {}})
    assert r1.status_code == 200 and isinstance(r1.json().get('errors'), list)
    # Strict
    monkeypatch.setenv('STRICT_RESPONSE_ENVELOPE', 'true')
    r2 = await authed_client.post(f'/api/graphql/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'query': '{ err }', 'variables': {}})
    assert r2.status_code == 200 and r2.json().get('status_code') == 200

@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


def _fake_pb2_module(method_name='M'):
    class Req:
        pass
@@ -20,7 +19,6 @@ def _fake_pb2_module(method_name='M'):
    setattr(Reply, '__name__', f'{method_name}Reply')
    return Req, Reply


def _make_import_module_recorder(record, pb2_map):
    def _imp(name):
        record.append(name)
@@ -49,7 +47,6 @@ def _make_import_module_recorder(record, pb2_map):
        raise ImportError(name)
    return _imp


def _make_fake_grpc_unary(sequence_codes, grpc_mod):
    counter = {'i': 0}
    class AioChan:
@@ -77,7 +74,6 @@ def _make_fake_grpc_unary(sequence_codes, grpc_mod):
    fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception})
    return fake


@pytest.mark.asyncio
async def test_grpc_with_api_grpc_package_config(monkeypatch, authed_client):
    import services.gateway_service as gs
@@ -110,7 +106,6 @@ async def test_grpc_with_api_grpc_package_config(monkeypatch, authed_client):
    assert r.status_code == 200
    assert any(n == 'api.pkg_pb2' for n in record)


@pytest.mark.asyncio
async def test_grpc_with_request_package_override(monkeypatch, authed_client):
    import services.gateway_service as gs
@@ -142,7 +137,6 @@ async def test_grpc_with_request_package_override(monkeypatch, authed_client):
    assert r.status_code == 200
    assert any(n == 'req.pkg_pb2' for n in record)


@pytest.mark.asyncio
async def test_grpc_without_package_server_uses_fallback_path(monkeypatch, authed_client):
    import services.gateway_service as gs
@@ -175,7 +169,6 @@ async def test_grpc_without_package_server_uses_fallback_path(monkeypatch, authe
    assert r.status_code == 200
    assert any(n.endswith(default_pkg) for n in record)


@pytest.mark.asyncio
async def test_grpc_unavailable_then_success_with_retry_live(monkeypatch, authed_client):
    import services.gateway_service as gs

@@ -2,18 +2,14 @@ import pytest
import os
import platform

# Enable by running with DOORMAN_RUN_LIVE=1
_RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


@pytest.mark.skipif(platform.system() == 'Windows', reason='SIGUSR1 not available on Windows')
def test_sigusr1_dump_in_memory_mode_live(client, monkeypatch, tmp_path):
    # Ensure encryption key and path
    monkeypatch.setenv('MEM_ENCRYPTION_KEY', 'live-secret-xyz')
    monkeypatch.setenv('MEM_DUMP_PATH', str(tmp_path / 'live' / 'memory_dump.bin'))
    # Trigger SIGUSR1; in live env the backend would handle it
    import signal, time
    os.kill(os.getpid(), signal.SIGUSR1)
    time.sleep(0.5)

@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


def test_platform_cors_strict_wildcard_credentials_edges_live(client, monkeypatch):
    monkeypatch.setenv('ALLOWED_ORIGINS', '*')
    monkeypatch.setenv('ALLOW_CREDENTIALS', 'true')
@@ -14,7 +13,6 @@ def test_platform_cors_strict_wildcard_credentials_edges_live(client, monkeypatc
    assert r.status_code == 204
    assert r.headers.get('Access-Control-Allow-Origin') in (None, '')


def test_platform_cors_methods_headers_defaults_live(client, monkeypatch):
    monkeypatch.setenv('ALLOW_METHODS', '')
    monkeypatch.setenv('ALLOW_HEADERS', '*')

@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


@pytest.mark.asyncio
async def test_forward_allowed_headers_only(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self
@@ -48,7 +47,6 @@ async def test_forward_allowed_headers_only(monkeypatch, authed_client):
    ch = {k.lower(): v for k, v in (captured.get('headers') or {}).items()}
    assert 'x-allowed' in ch and 'x-blocked' not in ch


@pytest.mark.asyncio
async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self
@@ -87,6 +85,5 @@ async def test_response_headers_filtered_by_allowlist(monkeypatch, authed_client
    monkeypatch.setattr(gs.httpx, 'AsyncClient', HC)
    r = await authed_client.get(f'/api/rest/{name}/{ver}/p')
    assert r.status_code == 200
    # Only X-Upstream forwarded back per allowlist
    assert r.headers.get('X-Upstream') == 'yes'
    assert 'X-Secret' not in r.headers

@@ -7,7 +7,6 @@ if not _RUN_LIVE:

from tests.test_gateway_routing_limits import _FakeAsyncClient


@pytest.mark.asyncio
async def test_rest_retries_on_500_then_success(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self
@@ -17,7 +16,6 @@ async def test_rest_retries_on_500_then_success(monkeypatch, authed_client):
    await create_endpoint(authed_client, name, ver, 'GET', '/r')
    await subscribe_self(authed_client, name, ver)

    # Set retry count to 1
    from utils.database import api_collection
    api_collection.update_one({'api_name': name, 'api_version': ver}, {'$set': {'api_allowed_retry_count': 1}})
    await authed_client.delete('/api/caches')
@@ -42,7 +40,6 @@ async def test_rest_retries_on_500_then_success(monkeypatch, authed_client):
    r = await authed_client.get(f'/api/rest/{name}/{ver}/r')
    assert r.status_code == 200 and r.json().get('ok') is True


@pytest.mark.asyncio
async def test_rest_retries_on_503_then_success(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self
@@ -74,7 +71,6 @@ async def test_rest_retries_on_503_then_success(monkeypatch, authed_client):
    r = await authed_client.get(f'/api/rest/{name}/{ver}/r')
    assert r.status_code == 200


@pytest.mark.asyncio
async def test_rest_no_retry_when_retry_count_zero(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self

@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


@pytest.mark.asyncio
async def test_soap_content_types_matrix(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self
@@ -34,7 +33,6 @@ async def test_soap_content_types_matrix(monkeypatch, authed_client):
        r = await authed_client.post(f'/api/soap/{name}/{ver}/s', headers={'Content-Type': ct}, content='<a/>')
        assert r.status_code == 200


@pytest.mark.asyncio
async def test_soap_retries_then_success(monkeypatch, authed_client):
    from conftest import create_api, create_endpoint, subscribe_self

@@ -5,7 +5,6 @@ _RUN_LIVE = os.getenv('DOORMAN_RUN_LIVE', '0') in ('1', 'true', 'True')
if not _RUN_LIVE:
    pytestmark = pytest.mark.skip(reason='Requires live backend service; set DOORMAN_RUN_LIVE=1 to enable')


def test_throttle_queue_limit_exceeded_429_live(client):
    from config import ADMIN_EMAIL
    name, ver = 'throtq', 'v1'
@@ -27,14 +26,12 @@ def test_throttle_queue_limit_exceeded_429_live(client):
        'endpoint_description': 't'
    })
    client.post('/platform/subscription/subscribe', json={'username': 'admin', 'api_name': name, 'api_version': ver})
    # Set queue limit to 1
    client.put('/platform/user/admin', json={'throttle_queue_limit': 1})
    client.delete('/api/caches')
    r1 = client.get(f'/api/rest/{name}/{ver}/t')
    r2 = client.get(f'/api/rest/{name}/{ver}/t')
    assert r2.status_code == 429


def test_throttle_dynamic_wait_live(client):
    name, ver = 'throtw', 'v1'
    client.post('/platform/api', json={
@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional, List

@@ -4,10 +4,8 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field

# Internal imports
from models.validation_schema_model import ValidationSchema

class CreateEndpointValidationModel(BaseModel):

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List, Optional
from pydantic import BaseModel, Field
from datetime import datetime
@@ -26,7 +25,6 @@ class CreditModel(BaseModel):
    api_key_header: str = Field(..., description='Header the API key should be sent in', example='x-api-key')
    credit_tiers: List[CreditTierModel] = Field(..., min_items=1, description='Credit tiers information')

    # API Key Rotation fields (for zero-downtime key rotation)
    api_key_new: Optional[str] = Field(None, description='New API key during rotation period', example='yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy')
    api_key_rotation_expires: Optional[datetime] = Field(None, description='Expiration time for old API key during rotation', example='2025-01-15T10:00:00Z')
@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field

class ResponseMessage(BaseModel):

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional, List

@@ -4,10 +4,8 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field

# Internal imports
from models.validation_schema_model import ValidationSchema

class EndpointValidationModelResponse(BaseModel):

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from typing import List, Union, Optional, Dict, Any
from pydantic import BaseModel, Field

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -1,4 +1,3 @@
# External imports
from pydantic import BaseModel
from typing import Dict, Optional

@@ -1,4 +1,3 @@
# External imports
from pydantic import BaseModel, Field
from typing import Optional, Union

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional

@@ -1,4 +1,3 @@
# External imports
from pydantic import BaseModel, Field
from typing import Optional, List

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field

class SubscribeModel(BaseModel):

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional, List

@@ -4,10 +4,8 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field

# Internal imports
from models.validation_schema_model import ValidationSchema

class UpdateEndpointValidationModel(BaseModel):

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field

class UpdatePasswordModel(BaseModel):

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field
from typing import List, Optional

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import Optional, Dict
from pydantic import BaseModel, Field

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pydantic import BaseModel, Field, EmailStr
from typing import List, Optional

@@ -4,11 +4,9 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from typing import Dict
from pydantic import BaseModel, Field

# Internal imports
from models.field_validation_model import FieldValidation

class ValidationSchema(BaseModel):

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from fastapi import APIRouter, Depends, Request
from typing import List
import logging
import uuid
import time

# Internal imports
from models.response_model import ResponseModel
from services.api_service import ApiService
from utils.auth_util import auth_required

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from fastapi import APIRouter, Request, Depends, HTTPException, Response
from jose import JWTError
import uuid
@@ -12,7 +11,6 @@ import time
import logging
import os

# Internal imports
from models.response_model import ResponseModel
from services.user_service import UserService
from utils.response_util import respond_rest
@@ -57,15 +55,12 @@ async def authorization(request: Request):
request_id = str(uuid.uuid4())
start_time = time.time() * 1000
try:
# IP-based rate limiting to prevent brute force attacks (5 attempts per 5 minutes)
# Can be overridden via environment variables for testing
login_limit = int(os.getenv('LOGIN_IP_RATE_LIMIT', '5'))
login_window = int(os.getenv('LOGIN_IP_RATE_WINDOW', '300'))
rate_limit_info = await limit_by_ip(request, limit=login_limit, window=login_window)

logger.info(f'{request_id} | From: {request.client.host}:{request.client.port}')
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Parse JSON body safely; invalid JSON should not 500
try:
data = await request.json()
except Exception:
@@ -117,7 +112,6 @@ async def authorization(request: Request):
response={'access_token': access_token}
))

# Add rate limit headers
if rate_limit_info:
response.headers['X-RateLimit-Limit'] = str(rate_limit_info['limit'])
response.headers['X-RateLimit-Remaining'] = str(rate_limit_info['remaining'])
@@ -144,24 +138,6 @@ async def authorization(request: Request):
else:
safe_domain = None

# Cookie Duplication Strategy:
# Cookies are set twice for maximum compatibility across deployment configurations:
# 1. WITH domain attribute (if COOKIE_DOMAIN is set and matches request host)
# - Enables subdomain sharing (e.g., *.example.com)
# - Required for SSO and multi-subdomain setups
# 2. WITHOUT domain attribute
# - Exact domain match only (no subdomain sharing)
# - Ensures cookies work even if domain validation fails
#
# Configuration:
# - COOKIE_DOMAIN: Base domain for subdomain sharing (e.g., "example.com")
# - For SSO: Set to SSO provider's domain scope
# - For reverse proxy: Set to base domain (not subdomain like "api.example.com")
# - Leave unset for single-domain deployments
#
# Impact: Doubles cookie size; consider for large JWTs behind proxies

# Set CSRF token with domain attribute (for subdomain sharing)
response.set_cookie(
key='csrf_token',
value=csrf_token,
@@ -173,7 +149,6 @@ async def authorization(request: Request):
max_age=1800
)

# Set CSRF token without domain attribute (exact domain only)
response.set_cookie(
key='csrf_token',
value=csrf_token,
@@ -184,7 +159,6 @@ async def authorization(request: Request):
max_age=1800
)

# Set access token with domain attribute (for subdomain sharing)
response.set_cookie(
key='access_token_cookie',
value=access_token,
@@ -196,7 +170,6 @@ async def authorization(request: Request):
max_age=1800
)

# Set access token without domain attribute (exact domain only)
response.set_cookie(
key='access_token_cookie',
value=access_token,
@@ -208,7 +181,6 @@ async def authorization(request: Request):
)
return response
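The dual-cookie pattern above can be condensed into a small helper; this is a minimal sketch, assuming a FastAPI Response and the csrf_token/safe_domain values from the hunk (the helper name, samesite value, and flag defaults are illustrative, not part of the codebase):

def set_cookie_pair(response, key, value, domain=None, http_only=True, max_age=1800):
    common = dict(key=key, value=value, secure=True, httponly=http_only, samesite='strict', max_age=max_age)
    if domain:
        # 1) WITH domain attribute: enables subdomain sharing when COOKIE_DOMAIN matches the request host
        response.set_cookie(domain=domain, **common)
    # 2) WITHOUT domain attribute: exact-host fallback so the cookie still works if domain validation fails
    response.set_cookie(**common)

# e.g. set_cookie_pair(response, 'csrf_token', csrf_token, domain=safe_domain, http_only=False)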
except HTTPException as e:
# Preserve IP rate limit semantics (429 + Retry-After headers)
if getattr(e, 'status_code', None) == 429:
headers = getattr(e, 'headers', {}) or {}
detail = e.detail if isinstance(e.detail, dict) else {}
@@ -221,7 +193,6 @@ async def authorization(request: Request):
error_code=str(detail.get('error_code') or 'IP_RATE_LIMIT'),
error_message=str(detail.get('message') or 'Too many requests')
))
# Default mapping for auth failures
return respond_rest(ResponseModel(
status_code=401,
response_headers={
@@ -244,7 +215,6 @@ async def authorization(request: Request):
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')

# Admin endpoints for revoking tokens and disabling/enabling users
"""
Endpoint

@@ -283,7 +253,6 @@ async def admin_revoke_user_tokens(username: str, request: Request):
))
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
# Continue anyway - permission check failure shouldn't block operation
revoke_all_for_user(username)
return respond_rest(ResponseModel(
status_code=200,
@@ -340,7 +309,6 @@ async def admin_unrevoke_user_tokens(username: str, request: Request):
))
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
# Continue anyway - permission check failure shouldn't block operation
unrevoke_all_for_user(username)
return respond_rest(ResponseModel(
status_code=200,
@@ -397,7 +365,6 @@ async def admin_disable_user(username: str, request: Request):
))
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
# Continue anyway - permission check failure shouldn't block operation

await UserService.update_user(username, UpdateUserModel(active=False), request_id)

@@ -457,7 +424,6 @@ async def admin_enable_user(username: str, request: Request):
))
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
# Continue anyway - permission check failure shouldn't block operation
await UserService.update_user(username, UpdateUserModel(active=True), request_id)

return respond_rest(ResponseModel(
@@ -515,7 +481,6 @@ async def admin_user_status(username: str, request: Request):
))
except Exception as e:
logger.error(f'{request_id} | Admin check failed: {str(e)}', exc_info=True)
# Continue anyway - permission check failure shouldn't block operation
user = await UserService.get_user_by_username_helper(username)
status = {
'active': bool(user.get('active', False)),
@@ -610,13 +575,6 @@ async def extended_authorization(request: Request):
else:
safe_domain = None

# Cookie Duplication Strategy (see login endpoint for full documentation)
# Cookies set twice: WITH domain for subdomain sharing, WITHOUT domain for exact match

# Set CSRF token with domain attribute (for subdomain sharing)
# lgtm [py/insecure-cookie]
# codeql[py/insecure-cookie]
# Intentionally not HttpOnly: double-submit CSRF token accessible to client JS; access token cookie is HttpOnly.
response.set_cookie(
key='csrf_token',
value=csrf_token,
@@ -628,10 +586,6 @@ async def extended_authorization(request: Request):
max_age=604800
)

# Set CSRF token without domain attribute (exact domain only)
# lgtm [py/insecure-cookie]
# codeql[py/insecure-cookie]
# Intentionally not HttpOnly: double-submit CSRF token accessible to client JS; access token cookie is HttpOnly.
response.set_cookie(
key='csrf_token',
value=csrf_token,
@@ -642,7 +596,6 @@ async def extended_authorization(request: Request):
max_age=604800
)

# Set refresh token with domain attribute (for subdomain sharing)
response.set_cookie(
key='access_token_cookie',
value=refresh_token,
@@ -654,7 +607,6 @@ async def extended_authorization(request: Request):
max_age=604800
)

# Set refresh token without domain attribute (exact domain only)
response.set_cookie(
key='access_token_cookie',
value=refresh_token,
@@ -796,7 +748,6 @@ async def authorization_invalidate(response: Response, request: Request):
username = payload.get('sub')
logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Add this token's JTI to durable revocation with TTL until expiry
try:
import time as _t
exp = payload.get('exp')
@@ -805,7 +756,6 @@ async def authorization_invalidate(response: Response, request: Request):
ttl = max(1, int(exp - _t.time()))
add_revoked_jti(username, payload.get('jti'), ttl)
except Exception as e:
# Fallback to in-memory TimedHeap (back-compat)
logger.warning(f'{request_id} | Token revocation failed, using fallback: {str(e)}')
if username not in jwt_blacklist:
jwt_blacklist[username] = TimedHeap()
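As a worked example of the revocation TTL above: a token whose exp is 90 seconds in the future gets a JTI revocation TTL of 90 seconds, while an already-expired token is clamped to the 1-second minimum by max(1, int(exp - now)).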
@@ -19,7 +19,6 @@ config_hot_reload_router = APIRouter(
tags=['Configuration Hot Reload']
)

@config_hot_reload_router.get(
'/current',
summary='Get Current Configuration',
@@ -31,7 +30,6 @@ async def get_current_config(
):
"""Get current configuration (admin only)"""
try:
# Check admin permission
accesses = payload.get('accesses', {})
if not accesses.get('manage_gateway'):
raise HTTPException(
@@ -39,7 +37,6 @@ async def get_current_config(
detail='Insufficient permissions: manage_gateway required'
)

# Dump current config
config = hot_config.dump()

return ResponseModel(
@@ -62,7 +59,6 @@ async def get_current_config(
detail='Failed to retrieve configuration'
)

@config_hot_reload_router.post(
'/reload',
summary='Trigger Configuration Reload',
@@ -74,7 +70,6 @@ async def trigger_config_reload(
):
"""Trigger configuration reload (admin only)"""
try:
# Check admin permission
accesses = payload.get('accesses', {})
if not accesses.get('manage_gateway'):
raise HTTPException(
@@ -82,7 +77,6 @@ async def trigger_config_reload(
detail='Insufficient permissions: manage_gateway required'
)

# Reload configuration
hot_config.reload()

return ResponseModel(
@@ -104,7 +98,6 @@ async def trigger_config_reload(
detail='Failed to reload configuration'
)

@config_hot_reload_router.get(
'/reloadable-keys',
summary='List Reloadable Configuration Keys',
@@ -117,40 +110,32 @@ async def get_reloadable_keys(
"""Get list of reloadable configuration keys"""
try:
reloadable_keys = [
# Logging
{'key': 'LOG_LEVEL', 'description': 'Log level (DEBUG, INFO, WARNING, ERROR)', 'example': 'INFO'},
{'key': 'LOG_FORMAT', 'description': 'Log format (json, text)', 'example': 'json'},
{'key': 'LOG_FILE', 'description': 'Log file path', 'example': 'logs/doorman.log'},

# Timeouts
{'key': 'GATEWAY_TIMEOUT', 'description': 'Gateway timeout in seconds', 'example': '30'},
{'key': 'UPSTREAM_TIMEOUT', 'description': 'Upstream timeout in seconds', 'example': '30'},
{'key': 'CONNECTION_TIMEOUT', 'description': 'Connection timeout in seconds', 'example': '10'},

# Rate Limiting
{'key': 'RATE_LIMIT_ENABLED', 'description': 'Enable rate limiting', 'example': 'true'},
{'key': 'RATE_LIMIT_REQUESTS', 'description': 'Requests per window', 'example': '100'},
{'key': 'RATE_LIMIT_WINDOW', 'description': 'Window size in seconds', 'example': '60'},

# Cache
{'key': 'CACHE_TTL', 'description': 'Cache TTL in seconds', 'example': '300'},
{'key': 'CACHE_MAX_SIZE', 'description': 'Maximum cache entries', 'example': '1000'},

# Circuit Breaker
{'key': 'CIRCUIT_BREAKER_ENABLED', 'description': 'Enable circuit breaker', 'example': 'true'},
{'key': 'CIRCUIT_BREAKER_THRESHOLD', 'description': 'Failures before opening', 'example': '5'},
{'key': 'CIRCUIT_BREAKER_TIMEOUT', 'description': 'Timeout before retry (seconds)', 'example': '60'},

# Retry
{'key': 'RETRY_ENABLED', 'description': 'Enable retry logic', 'example': 'true'},
{'key': 'RETRY_MAX_ATTEMPTS', 'description': 'Maximum retry attempts', 'example': '3'},
{'key': 'RETRY_BACKOFF', 'description': 'Backoff multiplier', 'example': '1.0'},

# Monitoring
{'key': 'METRICS_ENABLED', 'description': 'Enable metrics collection', 'example': 'true'},
{'key': 'METRICS_INTERVAL', 'description': 'Metrics interval (seconds)', 'example': '60'},

# Feature Flags
{'key': 'FEATURE_REQUEST_REPLAY', 'description': 'Enable request replay', 'example': 'false'},
{'key': 'FEATURE_AB_TESTING', 'description': 'Enable A/B testing', 'example': 'false'},
{'key': 'FEATURE_COST_ANALYTICS', 'description': 'Enable cost analytics', 'example': 'false'},
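For illustration, reloading configuration at runtime only needs an authenticated call to the reload route on config_hot_reload_router; the '/platform/config' prefix and the cookie-based auth in this sketch are assumptions, not confirmed by this diff:

import httpx

def trigger_reload(base_url: str, cookies: dict) -> dict:
    # Requires a session whose user has the manage_gateway access
    resp = httpx.post(f'{base_url}/platform/config/reload', cookies=cookies, timeout=10.0)
    resp.raise_for_status()
    return resp.json()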
@@ -2,7 +2,6 @@
Routes to export and import platform configuration (APIs, Endpoints, Roles, Groups, Routings).
"""

# External imports
from fastapi import APIRouter, Request
from typing import Any, Dict, List, Optional
import uuid
@@ -10,7 +9,6 @@ import time
import logging
import copy

# Internal imports
from models.response_model import ResponseModel
from utils.response_util import process_response
from utils.auth_util import auth_required

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import APIRouter, Depends, Request
import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from models.user_credits_model import UserCreditModel
from models.credit_model import CreditModel

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from fastapi import APIRouter, Request
from typing import Dict, List
import uuid
@@ -12,7 +11,6 @@ import time
import logging
from datetime import datetime, timedelta

# Internal imports
from models.response_model import ResponseModel
from utils.auth_util import auth_required
from utils.response_util import respond_rest

@@ -3,14 +3,12 @@ Protected demo seeding routes for populating the running server with dummy data.
Only available to users with 'manage_gateway' OR 'manage_credits'.
"""

# External imports
from fastapi import APIRouter, Request
from typing import Optional
import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from utils.response_util import respond_rest
from utils.role_util import platform_role_required_bool, is_admin_user

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import APIRouter, Depends, Request, HTTPException
import uuid
import time
import logging

# Internal imports
from models.create_endpoint_validation_model import CreateEndpointValidationModel
from models.endpoint_model_response import EndpointModelResponse
from models.endpoint_validation_model_response import EndpointValidationModelResponse

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from fastapi import APIRouter, HTTPException, Request, Depends
import os
import uuid
@@ -14,7 +13,6 @@ import json
import re
from datetime import datetime

# Internal imports
from models.response_model import ResponseModel
from utils import api_util
from utils.doorman_cache_util import doorman_cache
@@ -84,7 +82,6 @@ async def status(request: Request):
}
).dict(), 'rest')
except Exception as e:
# If auth fails, respond unauthorized
if hasattr(e, 'status_code') and getattr(e, 'status_code') == 401:
return process_response(ResponseModel(
status_code=401,
@@ -277,7 +274,6 @@ async def gateway(request: Request, path: str):
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')

# Per-method wrappers with unique operation IDs for OpenAPI
@gateway_router.get('/rest/{path:path}', description='REST gateway endpoint (GET)', response_model=ResponseModel, operation_id='rest_get')
async def rest_get(request: Request, path: str):
return await gateway(request, path)
@@ -340,7 +336,6 @@ async def rest_preflight(request: Request, path: str):
from fastapi.responses import Response as StarletteResponse
return StarletteResponse(status_code=204, headers={'request_id': request_id})

# Optional strict mode: return 405 for OPTIONS when endpoint is unregistered
try:
import os as _os, re as _re
if _os.getenv('STRICT_OPTIONS_405', 'false').lower() == 'true':
@@ -523,7 +518,6 @@ Response:
response_model=ResponseModel)

async def graphql_gateway(request: Request, path: str):
# Reuse Request ID from middleware if present, else accept inbound header, else generate
request_id = getattr(request.state, 'request_id', None) or request.headers.get('X-Request-ID') or str(uuid.uuid4())
start_time = time.time() * 1000
try:
@@ -554,7 +548,6 @@ async def graphql_gateway(request: Request, path: str):
logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms")
logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Validation check using already-resolved API (no need to re-resolve)
if api and api.get('validation_enabled'):
body = await request.json()
query = body.get('query')
@@ -693,7 +686,6 @@ async def grpc_gateway(request: Request, path: str):
logger.info(f"{request_id} | Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]}ms")
logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
# Validation check using already-resolved API (no need to re-resolve)
if api and api.get('validation_enabled'):
body = await request.json()
request_data = json.loads(body.get('data', '{}'))
@@ -708,7 +700,6 @@ async def grpc_gateway(request: Request, path: str):
).dict(), 'grpc')
svc_resp = await GatewayService.grpc_gateway(username, request, request_id, start_time, path)
if not isinstance(svc_resp, dict):
# Guard against unexpected None from service: return a 500 error
svc_resp = ResponseModel(
status_code=500,
response_headers={'request_id': request_id},

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import APIRouter, Depends, Request
import uuid
import time
import logging

# Internal imports
from models.group_model_response import GroupModelResponse
from models.response_model import ResponseModel
from models.update_group_model import UpdateGroupModel

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List, Optional
from fastapi import APIRouter, Depends, Request, Query, HTTPException
from fastapi.responses import StreamingResponse
@@ -13,7 +12,6 @@ import time
import logging
import io

# Internal imports
from models.response_model import ResponseModel
from services.logging_service import LoggingService
from utils.auth_util import auth_required
@@ -2,7 +2,6 @@
Routes for dumping and restoring in-memory database state.
"""

# External imports
from fastapi import APIRouter, Request
from typing import Optional
from pydantic import BaseModel
@@ -11,7 +10,6 @@ import uuid
import time
import logging

# Internal imports
from utils.response_util import process_response
from models.response_model import ResponseModel
from utils.auth_util import auth_required

@@ -2,7 +2,6 @@
Routes to expose gateway metrics to the web client.
"""

# External imports
from fastapi import APIRouter, Request
from pydantic import BaseModel
import uuid
@@ -12,7 +11,6 @@ import io
import csv
from fastapi.responses import Response as FastAPIResponse

# Internal imports
from models.response_model import ResponseModel
from utils.response_util import process_response
from utils.metrics_util import metrics_store
@@ -102,7 +100,6 @@ Response:
description='Kubernetes liveness probe endpoint (no auth)',
response_model=LivenessResponse)
async def liveness(request: Request):
# Always return alive for liveness; readiness reflects degraded/terminating
return {'status': 'alive'}

"""
@@ -126,9 +123,7 @@ async def readiness(request: Request):
Authorized users with 'manage_gateway':
Returns detailed status including mongodb, redis, mode, cache_backend
"""
# For tests and simple readiness checks, do not return 503; reflect degraded state in body

# Check if caller is authorized for detailed status
authorized = False
try:
payload = await auth_required(request)
@@ -142,11 +137,9 @@ async def readiness(request: Request):
redis_ok = await check_redis()
ready = mongo_ok and redis_ok

# Minimal response for unauthenticated/unauthorized callers
if not authorized:
return {'status': 'ready' if ready else 'degraded'}

# Detailed response for authorized callers
return {
'status': 'ready' if ready else 'degraded',
'mongodb': mongo_ok,

@@ -0,0 +1,6 @@

syntax = "proto3";

package myapi_v1;
message Hello { string name = 1; }

@@ -0,0 +1 @@
syntax = "proto3"; package psvc1_v1; service S { rpc M (R) returns (Q) {} } message R { string a = 1; } message Q { string b = 1; }
@@ -0,0 +1 @@
syntax = "proto3"; package psvc2_v1; service S { rpc M (R) returns (Q) {} } message R { string a = 1; } message Q { string b = 1; }
@@ -0,0 +1,3 @@
syntax = "proto3";
package sample_v1;
message Ping { string msg = 1; }
@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from fastapi import APIRouter, Depends, Request, UploadFile, File, HTTPException
from werkzeug.utils import secure_filename
from pathlib import Path
@@ -17,7 +16,6 @@ from datetime import datetime
import sys
import subprocess

# Internal imports
from models.response_model import ResponseModel
from utils.auth_util import auth_required
from utils.response_util import process_response
@@ -27,35 +25,29 @@ from utils.role_util import platform_role_required_bool
proto_router = APIRouter()
logger = logging.getLogger('doorman.gateway')

PROJECT_ROOT = Path(__file__).parent.parent.absolute()
PROJECT_ROOT = Path(__file__).parent.resolve()

def sanitize_filename(filename: str):
"""Sanitize and validate filename with comprehensive security checks"""
if not filename:
raise ValueError('Empty filename provided')

# Check for path traversal attempts
if '..' in filename:
raise ValueError('Path traversal detected: .. not allowed in filename')

# Check for absolute paths
if filename.startswith('/') or filename.startswith('\\'):
raise ValueError('Absolute paths not allowed in filename')

# Check for drive letters (Windows)
if len(filename) >= 2 and filename[1] == ':':
raise ValueError('Drive letters not allowed in filename')

# Check length
if len(filename) > 255:
raise ValueError('Filename too long (max 255 characters)')

# Use werkzeug's secure_filename for additional sanitization
sanitized = secure_filename(filename)
if not sanitized:
raise ValueError('Invalid filename after sanitization')

# Validate pattern (allow only alphanumeric, underscore, dash, dot)
safe_pattern = re.compile(r'^[a-zA-Z0-9_\-\.]+$')
if not safe_pattern.match(sanitized):
raise ValueError('Filename contains invalid characters (use only letters, numbers, underscore, dash, dot)')
@@ -76,29 +68,24 @@ def validate_path(base_path: Path, target_path: Path):

def validate_proto_content(content: bytes, max_size: int = 1024 * 1024) -> str:
"""Validate proto file content for security and correctness"""
# Check file size
if len(content) > max_size:
raise ValueError(f'File too large (max {max_size} bytes)')

# Check for null bytes (binary content)
if b'\x00' in content:
raise ValueError('Invalid proto file: binary content detected')

# Check if valid UTF-8
try:
content_str = content.decode('utf-8')
except UnicodeDecodeError:
raise ValueError('Invalid proto file: not valid UTF-8')

# Check for basic proto syntax (must contain syntax or message declaration)
if 'syntax' not in content_str and 'message' not in content_str and 'service' not in content_str:
raise ValueError('Invalid proto file: missing proto syntax (syntax/message/service)')

# Check for suspicious patterns (shell injection attempts)
suspicious_patterns = [
r'`', # Backticks
r'\$\(', # Command substitution
r';\s*(?:rm|mv|cp|chmod|cat|wget|curl)', # Shell commands
r'`',
r'\$\(',
r';\s*(?:rm|mv|cp|chmod|cat|wget|curl)',
]
for pattern in suspicious_patterns:
if re.search(pattern, content_str):
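A short usage sketch of the two validators above; the file name and proto body are illustrative only:

def check_upload(filename: str, content: bytes) -> str:
    sanitize_filename(filename)      # raises ValueError on traversal, absolute paths, or odd characters
    proto_text = validate_proto_content(content, max_size=1024 * 1024)  # raises ValueError on unsafe bodies
    return f'{filename}: {len(proto_text)} characters of proto source'

# e.g. check_upload('orders_v1.proto', b'syntax = "proto3"; message Order { int32 id = 1; }')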
@@ -193,9 +180,8 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
proto_path, generated_dir = get_safe_proto_path(api_name, api_version)
content = await file.read()

# Validate content for security and correctness
try:
max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024)) # 1MB default
max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024))
proto_content = validate_proto_content(content, max_size=max_proto_size)
except ValueError as e:
return process_response(ResponseModel(
@@ -230,11 +216,8 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
raise ValueError('Invalid grpc file path')
if pb2_grpc_file.exists():
content = pb2_grpc_file.read_text()
# Fix the import statement to use 'from generated import' instead of bare 'import'
# Match pattern: import {module}_pb2 as {alias}
import_pattern = rf'^import {safe_api_name}_{safe_api_version}_pb2 as (.+)$'
logger.info(f'{request_id} | Applying import fix with pattern: {import_pattern}')
# Show first 10 lines for debugging
lines = content.split('\n')[:10]
for i, line in enumerate(lines, 1):
if 'import' in line and 'pb2' in line:
@@ -244,7 +227,6 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
logger.info(f'{request_id} | Import fix applied successfully')
pb2_grpc_file.write_text(new_content)
logger.info(f"{request_id} | Wrote fixed pb2_grpc at {pb2_grpc_file}")
# Delete .pyc cache files so Python re-compiles from the fixed source
pycache_dir = generated_dir / '__pycache__'
if pycache_dir.exists():
for pyc_file in pycache_dir.glob(f'{safe_api_name}_{safe_api_version}*.pyc'):
@@ -253,7 +235,6 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
logger.info(f'{request_id} | Deleted cache file: {pyc_file.name}')
except Exception as e:
logger.warning(f'{request_id} | Failed to delete cache file {pyc_file.name}: {e}')
# Clear module from sys.modules cache so it gets reimported with fixed code
import sys as sys_import
pb2_module_name = f'{safe_api_name}_{safe_api_version}_pb2'
pb2_grpc_module_name = f'{safe_api_name}_{safe_api_version}_pb2_grpc'
@@ -265,7 +246,6 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
logger.info(f'{request_id} | Cleared {pb2_grpc_module_name} from sys.modules')
else:
logger.warning(f'{request_id} | Import fix pattern did not match - no changes made')
# Second pass: handle relative import form 'from . import X_pb2 as alias'
try:
rel_pattern = rf'^from \\. import {safe_api_name}_{safe_api_version}_pb2 as (.+)$'
content2 = pb2_grpc_file.read_text()
@@ -433,10 +413,9 @@ async def update_proto_file(api_name: str, api_version: str, request: Request, p
).dict(), 'rest')
proto_path, generated_dir = get_safe_proto_path(api_name, api_version)

# Read and validate content
content = await proto_file.read()
try:
max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024)) # 1MB default
max_proto_size = int(os.getenv('MAX_PROTO_SIZE_BYTES', 1024 * 1024))
proto_content = validate_proto_content(content, max_size=max_proto_size)
except ValueError as e:
return process_response(ResponseModel(
@@ -446,7 +425,6 @@ async def update_proto_file(api_name: str, api_version: str, request: Request, p
error_message=f'Invalid proto file: {str(e)}'
).dict(), 'rest')

# Write validated content
proto_path.write_text(proto_content)
try:
subprocess.run([

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import APIRouter, Depends, Request
import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from models.role_model_response import RoleModelResponse
from models.update_role_model import UpdateRoleModel

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import APIRouter, Depends, Request
import uuid
import time
import logging

# Internal imports
from models.create_routing_model import CreateRoutingModel
from models.response_model import ResponseModel
from models.routing_model_response import RoutingModelResponse

@@ -2,7 +2,6 @@
Routes for managing security settings.
"""

# External imports
from fastapi import APIRouter, Request
from typing import Optional
import os
@@ -12,7 +11,6 @@ import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from models.security_settings_model import SecuritySettingsModel
from utils.response_util import process_response

@@ -4,13 +4,11 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from fastapi import APIRouter, HTTPException, Depends, Request
import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from services.subscription_service import SubscriptionService
from utils.auth_util import auth_required
@@ -2,7 +2,6 @@
Tools and diagnostics routes (e.g., CORS checker).
"""

# External imports
from fastapi import APIRouter, Request
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
@@ -11,7 +10,6 @@ import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from utils.response_util import process_response
from utils.auth_util import auth_required
@@ -189,13 +187,11 @@ async def cors_check(request: Request, body: CorsCheckRequest):
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')

class ChaosToggleRequest(BaseModel):
backend: str = Field(..., description='Backend to toggle (redis|mongo)')
enabled: bool = Field(..., description='Enable or disable outage simulation')
duration_ms: Optional[int] = Field(default=None, description='Optional duration for outage before auto-disable')

@tools_router.post('/chaos/toggle', description='Toggle simulated backend outages (redis|mongo)', response_model=ResponseModel)
async def chaos_toggle(request: Request, body: ChaosToggleRequest):
request_id = str(uuid.uuid4())
@@ -239,7 +235,6 @@ async def chaos_toggle(request: Request, body: ChaosToggleRequest):
end_time = time.time() * 1000
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')

@tools_router.get('/chaos/stats', description='Get chaos simulation stats', response_model=ResponseModel)
async def chaos_stats(request: Request):
request_id = str(uuid.uuid4())

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import APIRouter, Request, HTTPException
import uuid
import time
import logging

# Internal imports
from models.response_model import ResponseModel
from models.user_model_response import UserModelResponse
from services.user_service import UserService

@@ -4,11 +4,9 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
import uuid
import logging

# Internal imports
from models.response_model import ResponseModel
from models.update_api_model import UpdateApiModel
from utils.database_async import api_collection

@@ -4,12 +4,10 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from pymongo.errors import PyMongoError
import logging
from typing import Optional

# Internal imports
from models.response_model import ResponseModel
from models.credit_model import CreditModel
from models.user_credits_model import UserCreditModel

@@ -4,14 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
import uuid
import logging
import os
import string as _string
from pathlib import Path

# Internal imports
from models.create_endpoint_validation_model import CreateEndpointValidationModel
from models.response_model import ResponseModel
from models.update_endpoint_model import UpdateEndpointModel
@@ -81,7 +79,6 @@ class EndpointService:
try:
if data.endpoint_method.upper() == 'POST' and str(data.endpoint_uri).strip().lower() == '/grpc':
from grpc_tools import protoc as _protoc
# Sanitize module base to safe identifier
api_name = data.api_name
api_version = data.api_version
module_base = f'{api_name}_{api_version}'.replace('-', '_')
@@ -95,9 +92,6 @@ class EndpointService:
proto_dir.mkdir(exist_ok=True)
generated_dir.mkdir(exist_ok=True)
proto_path = proto_dir / f'{module_base}.proto'
# lgtm [py/path-injection]
# codeql[py/path-injection]
# Safe: filename is derived from sanitized identifier (letters/digits/underscore) under fixed base dir.
if not proto_path.exists():
proto_content = (
'syntax = "proto3";\n'
@@ -117,9 +111,6 @@ class EndpointService:
'message DeleteRequest { int32 id = 1; }\n'
'message DeleteReply { bool ok = 1; }\n'
)
# lgtm [py/path-injection]
# codeql[py/path-injection]
# Safe write: controlled path under project 'proto/' using sanitized module_base
proto_path.write_text(proto_content, encoding='utf-8')
code = _protoc.main([
'protoc', f'--proto_path={str(proto_dir)}', f'--python_out={str(generated_dir)}', f'--grpc_python_out={str(generated_dir)}', str(proto_path)

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
import os
import random
import json
@@ -22,23 +21,19 @@ import importlib
import string
from pathlib import Path

# Provide a shim for gql.Client so tests can monkeypatch `Client` even when gql
# is not installed or used at runtime.
try:
from gql import Client as _GqlClient # type: ignore
from gql import Client as _GqlClient
def gql(q):
return q
except Exception: # pragma: no cover
class _GqlClient: # type: ignore
except Exception:
class _GqlClient:
def __init__(self, *args, **kwargs):
pass
def gql(q): # type: ignore
def gql(q):
return q

# Expose symbol name expected by tests
Client = _GqlClient

# Internal imports
from models.response_model import ResponseModel
from utils import api_util, routing_util
from utils import credit_util
@@ -183,19 +178,15 @@ class GatewayService:
ctype = ctype_raw.split(';', 1)[0].strip().lower()
body = getattr(response, 'content', b'')

# Explicit JSON types
if ctype in ('application/json', 'application/graphql+json') or 'application/graphql' in ctype:
return json.loads(body)

# Explicit XML types
if ctype in ('application/xml', 'text/xml'):
return ET.fromstring(body)

# Known raw passthrough types
if ctype in ('application/octet-stream', 'text/plain'):
return body

# Unspecified/empty content-type: attempt best-effort parse
if not ctype:
try:
return json.loads(body)
@@ -205,13 +196,8 @@ class GatewayService:
except Exception:
return body

# Unknown but explicit type: do not guess; return raw bytes
return body
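Read as a dispatch on the normalized Content-Type, the rules above amount to the following standalone sketch (the function name is illustrative and mirrors the hunk rather than an exported helper):

import json
import xml.etree.ElementTree as ET

def parse_upstream_body(content_type: str, body: bytes):
    ctype = (content_type or '').split(';', 1)[0].strip().lower()
    if ctype in ('application/json', 'application/graphql+json') or 'application/graphql' in ctype:
        return json.loads(body)            # explicit JSON types
    if ctype in ('application/xml', 'text/xml'):
        return ET.fromstring(body)         # explicit XML types
    if ctype in ('application/octet-stream', 'text/plain'):
        return body                        # known raw passthrough types
    if not ctype:                          # no declared type: best-effort JSON, else raw bytes
        try:
            return json.loads(body)
        except Exception:
            return body
    return body                            # unknown but explicit type: do not guess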
# ========================================================================
# Refactored Helper Methods (P2 #21 - Extract Duplicate Code)
# ========================================================================

@staticmethod
async def _resolve_api_from_path(path: str, request_id: str):
"""
@@ -290,20 +276,16 @@ class GatewayService:
if not api or not api.get('api_credits_enabled'):
return

# Add system-level credit API key
ai_token_headers = await credit_util.get_credit_api_header(api.get('api_credit_group'))
if ai_token_headers:
header_name = ai_token_headers[0]
header_value = ai_token_headers[1]

# Handle key rotation: ai_token_headers[1] could be a list [old_key, new_key]
if isinstance(header_value, list):
# Use new key if available during rotation
header_value = header_value[-1] if len(header_value) > 0 else header_value[0]

headers[header_name] = header_value

# Override with user-specific API key if available
if username and not bool(api.get('api_public')):
user_specific_api_key = await credit_util.get_user_api_key(
api.get('api_credit_group'),
@@ -334,45 +316,32 @@ class GatewayService:

for k, v in headers.items():
try:
# Convert key to lowercase and sanitize
key = str(k).lower().strip()

# Skip empty keys
if not key:
continue

# Replace invalid characters with hyphens
# Keep only alphanumeric, hyphens, underscores, and dots
sanitized_key = ''.join(
c if c.isalnum() or c in ('-', '_', '.') else '-'
for c in key
)

# Skip if sanitization resulted in empty key
if not sanitized_key:
continue

# Convert value to string and encode to ASCII
value = str(v) if v is not None else ''

# Try to encode as ASCII to ensure compatibility
try:
value.encode('ascii')
except UnicodeEncodeError:
# If value contains non-ASCII, skip it
continue

metadata_list.append((sanitized_key, value))
except Exception:
# Skip problematic headers silently
continue

return metadata_list
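A worked example of the metadata sanitization above (values illustrative): an input of {'X-Request-ID': 'abc', 'Weird Key!': '1', 'note': 'café'} yields [('x-request-id', 'abc'), ('weird-key-', '1')]; keys are lower-cased with invalid characters replaced by hyphens, and the entry whose value is not ASCII-encodable is dropped.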
# ========================================================================
|
||||
# gRPC Input Validation Helpers
|
||||
# ========================================================================
|
||||
|
||||
_IDENT_ALLOWED = set(string.ascii_letters + string.digits + "_")
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
@@ -381,9 +350,6 @@ class GatewayService:
|
||||
try:
|
||||
base_r = base.resolve()
|
||||
target_r = target.resolve()
|
||||
# lgtm [py/path-injection]
|
||||
# codeql[py/path-injection]
|
||||
# Safe check: resolving and enforcing target under base prevents traversal/absolute path escapes.
|
||||
return str(target_r).startswith(str(base_r))
|
||||
except Exception:
|
||||
return False
|
||||
@@ -442,10 +408,6 @@ class GatewayService:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ========================================================================
|
||||
# Gateway Methods
|
||||
# ========================================================================
|
||||
|
||||
@staticmethod
|
||||
async def rest_gateway(username, request, request_id, start_time, path, url=None, method=None, retry=0):
|
||||
"""
|
||||
@@ -472,7 +434,6 @@ class GatewayService:
|
||||
if not endpoints:
|
||||
return GatewayService.error_response(request_id, 'GTW002', 'No endpoints found for the requested API')
|
||||
regex_pattern = re.compile(r'\{[^/]+\}')
|
||||
# Treat HEAD like GET for endpoint registration matching
|
||||
match_method = 'GET' if str(request.method).upper() == 'HEAD' else request.method
|
||||
composite = match_method + '/' + endpoint_uri
|
||||
if not any(re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints):
|
||||
@@ -491,7 +452,6 @@ class GatewayService:
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
else:
# Recursive retry path: url/method provided, but we still need API context
try:
parts = [p for p in (path or '').split('/') if p]
api_name_version = ''
@@ -501,7 +461,6 @@ class GatewayService:
endpoint_uri = '/'.join(parts[2:])
api_key = doorman_cache.get_cache('api_id_cache', api_name_version)
api = await api_util.get_api(api_key, api_name_version)
# Do not mutate url/method or retry here; caller passed those
except Exception:
api = None
endpoint_uri = ''
@@ -510,7 +469,6 @@ class GatewayService:
query_params = getattr(request, 'query_params', {})
allowed_headers = api.get('api_allowed_headers') or [] if api else []
headers = await get_headers(request, allowed_headers)
# Propagate request ID to upstream for distributed tracing
headers['X-Request-ID'] = request_id
if api and api.get('api_credits_enabled'):
ai_token_headers = await credit_util.get_credit_api_header(api.get('api_credit_group'))
@@ -528,14 +486,11 @@ class GatewayService:
swap_from = api.get('api_authorization_field_swap')
source_val = None
if swap_from:
# Look up swap header among forwarded headers (case variants)
for key_variant in (swap_from, str(swap_from).lower(), str(swap_from).title()):
if key_variant in headers:
source_val = headers.get(key_variant)
break
# Determine original Authorization from incoming request
orig_auth = request.headers.get('Authorization') or request.headers.get('authorization')
# Apply only when non-empty. Prefer swap header; otherwise preserve original Authorization.
if source_val is not None and str(source_val).strip() != '':
headers['Authorization'] = source_val
elif orig_auth is not None and str(orig_auth).strip() != '':
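The swap logic above promotes a configured header into Authorization, falling back to the caller's original Authorization when the swap header is missing or blank. A compact sketch of that precedence rule (illustrative names, not the service's signature):

def resolve_authorization(headers: dict, incoming: dict, swap_from: str | None) -> str | None:
    source_val = None
    if swap_from:
        # Check common case variants of the configured swap header
        for variant in (swap_from, swap_from.lower(), swap_from.title()):
            if variant in headers:
                source_val = headers[variant]
                break
    orig = incoming.get('Authorization') or incoming.get('authorization')
    # Prefer a non-empty swap value; otherwise preserve the original header
    if source_val and str(source_val).strip():
        return source_val
    if orig and str(orig).strip():
        return orig
    return None
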
@@ -625,7 +580,6 @@ class GatewayService:
try:
response_content = http_response.json()
except Exception as _e:
# Upstream declared JSON but sent malformed body: map to 500 (GTW006)
logger.error(f'{request_id} | REST upstream malformed JSON: {str(_e)}')
return ResponseModel(
status_code=500,
@@ -636,7 +590,6 @@ class GatewayService:
else:
response_content = http_response.text
backend_end_time = time.time() * 1000
# Retries are handled by the HTTP helper
if http_response.status_code == 404:
return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
logger.info(f'{request_id} | REST gateway status code: {http_response.status_code}')
@@ -732,7 +685,6 @@ class GatewayService:
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
else:
# Recursive call with url present; re-derive API context for headers/validation
try:
parts = [p for p in (path or '').split('/') if p]
api_name_version = ''
@@ -756,7 +708,6 @@ class GatewayService:
content_type = 'text/xml; charset=utf-8'
allowed_headers = api.get('api_allowed_headers') or [] if api else []
headers = await get_headers(request, allowed_headers)
# Propagate request ID to upstream for distributed tracing
headers['X-Request-ID'] = request_id
headers['Content-Type'] = content_type
if 'SOAPAction' not in headers:
@@ -771,7 +722,6 @@ class GatewayService:
if key_variant in headers:
source_val = headers.get(key_variant)
break
# Check original Authorization from incoming request
orig_auth = request.headers.get('Authorization') or request.headers.get('authorization')
if source_val is not None and str(source_val).strip() != '':
headers['Authorization'] = source_val
@@ -806,7 +756,6 @@ class GatewayService:
response_content = http_response.text
logger.info(f'{request_id} | SOAP gateway response: {response_content}')
backend_end_time = time.time() * 1000
# Retries handled by HTTP helper
if http_response.status_code == 404:
return GatewayService.error_response(request_id, 'GTW005', 'Endpoint does not exist in backend service')
logger.info(f'{request_id} | SOAP gateway status code: {http_response.status_code}')
@@ -886,7 +835,6 @@ class GatewayService:
current_time = time.time() * 1000
allowed_headers = api.get('api_allowed_headers') or []
headers = await get_headers(request, allowed_headers)
# Propagate request ID to upstream for distributed tracing
headers['X-Request-ID'] = request_id
headers['Content-Type'] = 'application/json'
headers['Accept'] = 'application/json'
@@ -907,7 +855,6 @@ class GatewayService:
if key_variant in headers:
source_val = headers.get(key_variant)
break
# Preserve original Authorization if swap header missing/empty
orig_auth = request.headers.get('Authorization') or request.headers.get('authorization')
if source_val is not None and str(source_val).strip() != '':
headers['Authorization'] = source_val
@@ -927,17 +874,14 @@ class GatewayService:
except Exception as e:
return GatewayService.error_response(request_id, 'GTW011', str(e), status=400)

# Try test-friendly Client path first (monkeypatchable for tests)
# If Client has async context manager support, it's been monkeypatched by tests
result = None
if hasattr(Client, '__aenter__'):
try:
async with Client(transport=None, fetch_schema_from_transport=False) as session: # type: ignore
result = await session.execute(gql(query), variable_values=variables) # type: ignore
async with Client(transport=None, fetch_schema_from_transport=False) as session:
result = await session.execute(gql(query), variable_values=variables)
except Exception as _e:
logger.debug(f'{request_id} | GraphQL Client execution failed; falling back to HTTP: {_e}')

# Default path: use httpx to call upstream GraphQL server
if result is None:
client_key = request.headers.get('client-key')
server = await routing_util.pick_upstream_server(api, 'POST', '/graphql', client_key)
@@ -956,7 +900,6 @@ class GatewayService:
api_config=api,
)
except AttributeError:
# Fallback for tests that stub AsyncClient without `.request`
http_resp = await client.post(
url,
json={'query': query, 'variables': variables},
@@ -1053,7 +996,6 @@ class GatewayService:
logger.error(f'{request_id} | Invalid JSON in request body')
return GatewayService.error_response(request_id, 'GTW011', 'Invalid JSON in request body', status=400)

# Validate method and optional package inputs early
parsed = GatewayService._parse_and_validate_method(body.get('method'))
if not parsed:
return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400)
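The validation above rejects anything that is not a plain Service.Method pair built from identifier characters. A minimal sketch of such a parser (a hypothetical helper mirroring the constraint the error message states, not the service's actual implementation):

import re

_METHOD_RE = re.compile(r'^([A-Za-z_][A-Za-z0-9_]*)\.([A-Za-z_][A-Za-z0-9_]*)$')

def parse_and_validate_method(raw: object) -> tuple[str, str] | None:
    """Return (service, method) for 'Service.Method', else None."""
    if not isinstance(raw, str):
        return None
    m = _METHOD_RE.match(raw.strip())
    return (m.group(1), m.group(2)) if m else None

# parse_and_validate_method('UserService.GetUser') -> ('UserService', 'GetUser')
# parse_and_validate_method('UserService/GetUser') -> None
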
@@ -1074,7 +1016,6 @@ class GatewayService:
await validation_util.validate_grpc_request(endpoint_id, body.get('message'))
except Exception as e:
return GatewayService.error_response(request_id, 'GTW011', str(e), status=400)
# Resolve and validate module base: API config > request override > default derived
api_pkg_raw = None
try:
api_pkg_raw = (api.get('api_grpc_package') or '').strip() if api else None
@@ -1085,7 +1026,6 @@ class GatewayService:
pkg_override_valid = GatewayService._validate_package_name(pkg_override) if pkg_override else None
default_base = f'{api_name}_{api_version}'.replace('-', '_')
if not GatewayService._is_valid_identifier(default_base):
# Final fallback: strip invalid chars from default
default_base = ''.join(ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base)
module_base = (api_pkg or pkg_override_valid or default_base)
try:
@@ -1096,7 +1036,6 @@ class GatewayService:
except Exception:
pass

# Allow-list enforcement (package/service/method)
try:
allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None
allowed_svcs = api.get('api_grpc_allowed_services') if api else None
@@ -1104,7 +1043,6 @@ class GatewayService:

service_name, method_name = _service_name_preview, _method_name_preview

# If allow-lists are configured, enforce them
if allowed_pkgs and isinstance(allowed_pkgs, list):
if module_base not in allowed_pkgs:
return GatewayService.error_response(
@@ -1122,7 +1060,6 @@ class GatewayService:
request_id, 'GTW013', 'gRPC method not allowed', status=403
)
except Exception:
# On any unexpected error, default to safe deny with 403
return GatewayService.error_response(
request_id, 'GTW013', 'gRPC target not allowed', status=403
)
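These hunks enforce optional allow-lists for the gRPC package, service, and method, and deny with 403 on any uncertainty. A distilled sketch of that fail-closed pattern (illustrative signature, not the service's exact code):

def grpc_target_allowed(module_base: str, service: str, method: str,
                        allowed_pkgs=None, allowed_svcs=None, allowed_methods=None) -> bool:
    try:
        if allowed_pkgs and module_base not in allowed_pkgs:
            return False
        if allowed_svcs and service not in allowed_svcs:
            return False
        if allowed_methods and method not in allowed_methods:
            return False
        return True
    except Exception:
        # Fail closed: any unexpected error denies the call
        return False
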
@@ -1134,10 +1071,6 @@ class GatewayService:
if not GatewayService._validate_under_base(project_root, proto_path):
return GatewayService.error_response(request_id, 'GTW012', 'Invalid path for proto resolution', status=400)

# Ensure both project root (for 'from generated import ...') and
# the generated directory itself (for 'import ..._pb2') are on sys.path.
# The proto upload flow may rewrite imports to 'from generated import ...',
# which requires the parent directory of 'generated' to be importable.
generated_dir = project_root / 'generated'
gen_dir_str = str(generated_dir)
proj_root_str = str(project_root)
@@ -1150,7 +1083,6 @@ class GatewayService:
except Exception:
pass

# Try to import generated modules first (tests monkeypatch import_module)
pb2 = None
pb2_grpc = None
try:
@@ -1160,7 +1092,6 @@ class GatewayService:
pb2 = importlib.import_module(pb2_name)
pb2_grpc = importlib.import_module(pb2_grpc_name)
except ModuleNotFoundError:
# Fallback to 'generated.<name>' if rewrite expects package import
gen_pb2_name = f'generated.{module_base}_pb2'
gen_pb2_grpc_name = f'generated.{module_base}_pb2_grpc'
pb2 = importlib.import_module(gen_pb2_name)
@@ -1170,7 +1101,6 @@ class GatewayService:
logger.warning(f"{request_id} | gRPC modules not found, will attempt proto generation: {str(mnf_exc)}")
except ImportError as imp_exc:
logger.error(f"{request_id} | ImportError loading gRPC modules (likely broken import in generated file): {str(imp_exc)}")
# Clear the module from cache and return specific error
mod_pb2 = f'{module_base}_pb2'
mod_pb2_grpc = f'{module_base}_pb2_grpc'
if mod_pb2 in sys.modules:
@@ -1196,9 +1126,6 @@ class GatewayService:
try:
proto_dir.mkdir(exist_ok=True)
try:
# lgtm [py/path-injection]
# codeql[py/path-injection]
# Safe: path check/creation under fixed project directories; user input sanitized earlier
logger.info(f"{request_id} | gRPC generated check: proto_path={proto_path} exists={proto_path.exists()} generated_dir={generated_dir} pb2={module_base}_pb2.py={ (generated_dir / (module_base + '_pb2.py')).exists() }")
except Exception:
pass
@@ -1226,14 +1153,11 @@ class GatewayService:
'message DeleteRequest { int32 id = 1; }\n'
'message DeleteReply { bool ok = 1; }\n'
)
# lgtm [py/path-injection]
# codeql[py/path-injection]
# Safe write: sanitized module_base + fixed 'proto' base directory
proto_path.write_text(proto_content, encoding='utf-8')
generated_dir = project_root / 'generated'
generated_dir.mkdir(exist_ok=True)
try:
from grpc_tools import protoc as _protoc # type: ignore
from grpc_tools import protoc as _protoc
code = _protoc.main([
'protoc', f'--proto_path={str(proto_dir)}', f'--python_out={str(generated_dir)}', f'--grpc_python_out={str(generated_dir)}', str(proto_path)
])
@@ -1244,7 +1168,6 @@ class GatewayService:
init_path.write_text('"""Generated gRPC code."""\n', encoding='utf-8')
except Exception as ge:
logger.error(f'{request_id} | On-demand proto generation failed: {ge}')
# In test mode, allow fallback without generation
if os.getenv('DOORMAN_TEST_MODE', '').lower() == 'true':
pb2 = type('PB2', (), {})
pb2_grpc = type('SVC', (), {})
@@ -1274,22 +1197,17 @@ class GatewayService:
if not await credit_util.deduct_credit(api.get('api_credit_group'), username):
return GatewayService.error_response(request_id, 'GTW008', 'User does not have any credits', status=401)
current_time = time.time() * 1000
# Ensure api is available even in retry recursion
try:
if not url:
# already resolved above
pass
else:
# When called recursively with url present, rebuild api_path
api_version = request.headers.get('X-API-Version', 'v1')
if not api_name:
# Derive from request path
path_parts = (path or '').strip('/').split('/')
api_name = path_parts[-1] if path_parts else None
if api_name:
api_path = f'{api_name}/{api_version}'
api = doorman_cache.get_cache('api_cache', api_path) or await api_util.get_api(None, api_path)
# Recompute module_base in recursive path
try:
api_pkg_raw = (api.get('api_grpc_package') or '').strip() if api else None
except Exception:
@@ -1301,7 +1219,6 @@ class GatewayService:
if not GatewayService._is_valid_identifier(default_base):
default_base = ''.join(ch if ch in GatewayService._IDENT_ALLOWED else '_' for ch in default_base)
module_base = (api_pkg or pkg_override_valid or default_base)
# Enforce allow-lists in recursive path as well
try:
allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None
allowed_svcs = api.get('api_grpc_allowed_services') if api else None
@@ -1336,7 +1253,6 @@ class GatewayService:
pass
allowed_headers = (api or {}).get('api_allowed_headers') or []
headers = await get_headers(request, allowed_headers)
# Propagate request ID to upstream for distributed tracing
headers['X-Request-ID'] = request_id
try:
body = await request.json()
@@ -1352,15 +1268,12 @@ class GatewayService:
if 'message' not in body:
logger.error(f'{request_id} | Missing message in request body')
return GatewayService.error_response(request_id, 'GTW011', 'Missing message in request body', status=400)
# Validate method and (optional) package at this stage as well
parsed_method = GatewayService._parse_and_validate_method(body.get('method'))
if not parsed_method:
return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400)
# Validate package override if present
pkg_override = (body.get('package') or '').strip() or None
if pkg_override and GatewayService._validate_package_name(pkg_override) is None:
return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC package. Use letters, digits, underscore only.', status=400)
# Re-apply allow-list checks after validation in this path
try:
svc_name, mth_name = parsed_method
allowed_pkgs = api.get('api_grpc_allowed_packages') if api else None
@@ -1374,10 +1287,9 @@ class GatewayService:
return GatewayService.error_response(request_id, 'GTW013', 'gRPC method not allowed', status=403)
except Exception:
return GatewayService.error_response(request_id, 'GTW013', 'gRPC target not allowed', status=403)
# Preserve previously resolved module_base (api_grpc_package > request package > default)
proto_rel = Path(module_base.replace('.', '/'))
proto_filename = f'{proto_rel.name}.proto'


try:
endpoint_doc = await api_util.get_endpoint(api, 'POST', '/grpc')
endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None
@@ -1386,11 +1298,10 @@ class GatewayService:
except Exception as e:
return GatewayService.error_response(request_id, 'GTW011', str(e), status=400)
proto_path = (GatewayService._PROJECT_ROOT / 'proto' / proto_rel.with_suffix('.proto'))
# Prefer modules imported earlier (tests may monkeypatch importlib)
use_imported = False
try:
if 'pb2' in locals() and 'pb2_grpc' in locals():
use_imported = (pb2 is not None and pb2_grpc is not None) # type: ignore[name-defined]
use_imported = (pb2 is not None and pb2_grpc is not None)
except Exception:
use_imported = False
module_name = module_base
@@ -1410,15 +1321,11 @@ class GatewayService:
pb2_module = None
service_module = None
if use_imported:
pb2_module = pb2 # type: ignore[name-defined]
service_module = pb2_grpc # type: ignore[name-defined]
pb2_module = pb2
service_module = pb2_grpc
logger.info(f"{request_id} | Using imported gRPC modules for {module_name}")
else:
# lgtm [py/path-injection]
# codeql[py/path-injection]
# Safe existence check: proto_path built from sanitized package under fixed project root
if not proto_path.exists():
# In test mode, allow direct import via monkeypatched importlib
if os.getenv('DOORMAN_TEST_MODE', '').lower() == 'true':
try:
pb2_module = importlib.import_module(f'{module_name}_pb2')
@@ -1428,15 +1335,10 @@ class GatewayService:
logger.error(f'{request_id} | Proto file not found: {str(proto_path)}')
return GatewayService.error_response(request_id, 'GTW012', f'Proto file not found for API: {api_path}', status=404)
if not use_imported:
# Ensure generated files exist when not using imported modules
pb2_path = package_dir / f"{parts[-1]}_pb2.py"
pb2_grpc_path = package_dir / f"{parts[-1]}_pb2_grpc.py"
# lgtm [py/path-injection]
# codeql[py/path-injection]
# Safe: filenames are derived from validated identifiers and constrained to generated dir
if not (pb2_path.is_file() and pb2_grpc_path.is_file()):
logger.error(f"{request_id} | Generated modules not found for '{module_name}' pb2={pb2_path} exists={pb2_path.is_file()} pb2_grpc={pb2_grpc_path} exists={pb2_grpc_path.is_file()}")
# If upstream is HTTP-based, fall back to HTTP call
if isinstance(url, str) and url.startswith(('http://', 'https://')):
try:
client = GatewayService.get_http_client()
@@ -1464,7 +1366,6 @@ class GatewayService:
return GatewayService.error_response(request_id, 'GTW012', f'Generated gRPC modules not found for package: {module_name}', status=404)
if not use_imported:
try:
# Guard against unexpected module names by re-validating the module_name
if GatewayService._validate_package_name(module_name) is None:
return GatewayService.error_response(request_id, 'GTW012', 'Invalid gRPC module name', status=400)
import_name_pb2 = f'{module_name}_pb2'
@@ -1474,7 +1375,6 @@ class GatewayService:
pb2_module = importlib.import_module(import_name_pb2)
service_module = importlib.import_module(import_name_grpc)
except ModuleNotFoundError:
# Try the 'generated.' package path as a fallback
alt_pb2 = f'generated.{module_name}_pb2'
alt_grpc = f'generated.{module_name}_pb2_grpc'
logger.info(f"{request_id} | Retrying import via generated package: {alt_pb2} and {alt_grpc}")
@@ -1482,7 +1382,6 @@ class GatewayService:
service_module = importlib.import_module(alt_grpc)
except ImportError as e:
logger.error(f'{request_id} | Failed to import gRPC module: {str(e)}', exc_info=True)
# If upstream is HTTP-based, fall back to HTTP call
if isinstance(url, str) and url.startswith(('http://', 'https://')):
try:
client = GatewayService.get_http_client()
@@ -1512,7 +1411,6 @@ class GatewayService:
if not parsed:
return GatewayService.error_response(request_id, 'GTW011', 'Invalid gRPC method. Use Service.Method with alphanumerics/underscore.', status=400)
service_name, method_name = parsed
# If upstream is HTTP-based, fall back to HTTP call regardless of generated modules
if isinstance(url, str) and url.startswith(("http://", "https://")):
try:
client = GatewayService.get_http_client()
@@ -1538,21 +1436,18 @@ class GatewayService:
response=(http_response.json() if http_response.headers.get('Content-Type','').startswith('application/json') else http_response.text)
).dict()

# Defer type validation to attribute access below; avoid premature 500s in test stubs
logger.info(f"{request_id} | Connecting to gRPC upstream: {url}")
channel = grpc.aio.insecure_channel(url)
try:
await asyncio.wait_for(channel.channel_ready(), timeout=2.0)
except Exception:
pass
# Resolve request/reply message classes from pb2_module
request_class_name = f'{method_name}Request'
reply_class_name = f'{method_name}Reply'

try:
logger.info(f"{request_id} | Resolving message types: {request_class_name} and {reply_class_name} from pb2_module={getattr(pb2_module, '__name__', 'unknown')}")

# Verify pb2_module is not None
if pb2_module is None:
logger.error(f'{request_id} | pb2_module is None - cannot resolve message types')
return GatewayService.error_response(
@@ -1562,7 +1457,6 @@ class GatewayService:
status=500
)

# Get request and reply classes
try:
request_class = getattr(pb2_module, request_class_name)
reply_class = getattr(pb2_module, reply_class_name)
@@ -1575,7 +1469,6 @@ class GatewayService:
status=500
)

# Create request message instance
try:
request_message = request_class()
logger.info(f'{request_id} | Successfully created request message of type {request_class_name}')
@@ -1601,7 +1494,6 @@ class GatewayService:
setattr(request_message, key, value)
except Exception:
pass
# Retry policy configuration
attempts = max(1, int(retry) + 1)
env_max_retries = 0
try:
@@ -1619,7 +1511,6 @@ class GatewayService:
except Exception:
base_ms, max_ms = 100, 1000

# Determine idempotency: default true for unary/server-stream, false for client/bidi unless overridden
stream_mode = str((body.get('stream') or body.get('streaming') or '')).lower()
idempotent_override = body.get('idempotent')
if idempotent_override is not None:
@@ -1627,7 +1518,6 @@ class GatewayService:
else:
is_idempotent = not (stream_mode.startswith('client') or stream_mode.startswith('bidi') or stream_mode.startswith('bi'))

# Retryable gRPC status codes
retryable = {
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.DEADLINE_EXCEEDED,
@@ -1645,7 +1535,6 @@ class GatewayService:
pass
for attempt in range(attempts):
try:
# Prefer direct unary call via channel for better error mapping
full_method = f'/{module_base}.{service_name}/{method_name}'
try:
logger.info(f"{request_id} | gRPC attempt={attempt+1}/{attempts} calling {full_method}")
@@ -1654,8 +1543,6 @@ class GatewayService:
req_ser = getattr(request_message, 'SerializeToString', None)
if not callable(req_ser):
req_ser = (lambda _m: b'')
# Choose streaming or unary based on request body hint (computed above)
# Sanitize HTTP headers for gRPC metadata compatibility
metadata_list = GatewayService._sanitize_grpc_metadata(headers or {})
if stream_mode.startswith('server'):
call = channel.unary_stream(
@@ -1678,7 +1565,6 @@ class GatewayService:
response = type('R', (), {'DESCRIPTOR': type('D', (), {'fields': []})(), 'ok': True, '_items': items})()
got_response = True
elif stream_mode.startswith('client'):
# Client-streaming: send a stream of request messages, get single reply
try:
stream = channel.stream_unary(
full_method,
@@ -1702,7 +1588,6 @@ class GatewayService:
except Exception:
pass
else:
# Fallback to base request_message
msg = request_message
yield msg
except Exception:
@@ -1725,7 +1610,6 @@ class GatewayService:
response = await unary(request_message)
got_response = True
elif stream_mode.startswith('bidi') or stream_mode.startswith('bi'):
# Bi-directional streaming: send stream, collect responses up to max_items
try:
bidi = channel.stream_stream(
full_method,
@@ -1806,21 +1690,17 @@ class GatewayService:
except Exception:
logger.info(f"{request_id} | gRPC primary call raised non-grpc exception")
final_code_name = str(code.name) if getattr(code, 'name', None) else 'ERROR'
# Backoff/retry only if idempotent and code is retryable and attempts remain
if attempt < attempts - 1 and is_idempotent and code in retryable:
retries_made += 1
# Exponential backoff with jitter
delay = min(max_ms, base_ms * (2 ** attempt)) / 1000.0
jitter_factor = 1.0 + (random.random() * jitter - (jitter / 2.0))
await asyncio.sleep(max(0.01, delay * jitter_factor))
continue
# Try alternative method path without package prefix
try:
alt_method = f'/{service_name}/{method_name}'
req_ser = getattr(request_message, 'SerializeToString', None)
if not callable(req_ser):
req_ser = (lambda _m: b'')
# reuse computed stream_mode
if stream_mode.startswith('server'):
call2 = channel.unary_stream(
alt_method,
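The retry branch in the hunk above applies capped exponential backoff with symmetric jitter before re-attempting idempotent calls. A standalone sketch of that delay schedule (parameter names are assumptions; the gateway reads its values from env/config):

import asyncio
import random

async def backoff_sleep(attempt: int, base_ms: int = 100, max_ms: int = 1000, jitter: float = 0.2) -> None:
    # Exponential growth capped at max_ms, then randomized by +/- (jitter/2)
    delay = min(max_ms, base_ms * (2 ** attempt)) / 1000.0
    jitter_factor = 1.0 + (random.random() * jitter - (jitter / 2.0))
    await asyncio.sleep(max(0.01, delay * jitter_factor))

# attempt 0 -> ~0.1s, attempt 1 -> ~0.2s, attempt 2 -> ~0.4s, ... capped near 1.0s
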
@@ -1970,10 +1850,8 @@ class GatewayService:
await asyncio.sleep(max(0.01, delay * jitter_factor))
continue
else:
# Do not mask channel errors with stub fallback; propagate
break
if last_exc is not None:
# Extract gRPC status code from exception
code_name = 'UNKNOWN'
code_obj = None
try:
@@ -1986,31 +1864,28 @@ class GatewayService:
except Exception as code_extract_err:
logger.warning(f"{request_id} | Failed to extract gRPC status code: {str(code_extract_err)}")

# Comprehensive gRPC status code to HTTP status code mapping
# Based on: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
status_map = {
'OK': 200, # Success
'CANCELLED': 499, # Client Closed Request
'UNKNOWN': 500, # Internal Server Error
'INVALID_ARGUMENT': 400, # Bad Request
'DEADLINE_EXCEEDED': 504, # Gateway Timeout
'NOT_FOUND': 404, # Not Found
'ALREADY_EXISTS': 409, # Conflict
'PERMISSION_DENIED': 403, # Forbidden
'RESOURCE_EXHAUSTED': 429, # Too Many Requests
'FAILED_PRECONDITION': 412, # Precondition Failed
'ABORTED': 409, # Conflict
'OUT_OF_RANGE': 400, # Bad Request
'UNIMPLEMENTED': 501, # Not Implemented
'INTERNAL': 500, # Internal Server Error
'UNAVAILABLE': 503, # Service Unavailable
'DATA_LOSS': 500, # Internal Server Error
'UNAUTHENTICATED': 401, # Unauthorized
'OK': 200,
'CANCELLED': 499,
'UNKNOWN': 500,
'INVALID_ARGUMENT': 400,
'DEADLINE_EXCEEDED': 504,
'NOT_FOUND': 404,
'ALREADY_EXISTS': 409,
'PERMISSION_DENIED': 403,
'RESOURCE_EXHAUSTED': 429,
'FAILED_PRECONDITION': 412,
'ABORTED': 409,
'OUT_OF_RANGE': 400,
'UNIMPLEMENTED': 501,
'INTERNAL': 500,
'UNAVAILABLE': 503,
'DATA_LOSS': 500,
'UNAUTHENTICATED': 401,
}

http_status = status_map.get(code_name, 500)

# Extract error details from exception
details = 'gRPC call failed'
try:
details_fn = getattr(last_exc, 'details', None)
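This hunk strips the inline comments from the gRPC-to-HTTP status map; the mapping itself follows the upstream statuscodes.md table. A quick illustration of how such a map is applied when surfacing a failed call (a sketch; it assumes the error exposes the usual code()/details() accessors of a gRPC call):

import grpc

def grpc_error_to_http(exc: grpc.RpcError, status_map: dict[str, int]) -> tuple[int, str]:
    code_name, details = 'UNKNOWN', 'gRPC call failed'
    try:
        code = exc.code()          # a grpc.StatusCode member
        if code is not None:
            code_name = code.name  # e.g. 'UNAVAILABLE'
        details = exc.details() or details
    except Exception:
        pass
    return status_map.get(code_name, 500), details

# e.g. UNAVAILABLE -> (503, 'connection refused'), NOT_FOUND -> (404, ...)
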
@@ -2041,7 +1916,6 @@ class GatewayService:
error_code='GTW006',
error_message=str(details)[:255]
).dict()
# If we somehow reach here without a response and no exception, log and fail predictably
if not got_response and last_exc is None:
try:
logger.error(f"{request_id} | gRPC loop ended with no response and no exception; returning 500 UNKNOWN")
@@ -2090,8 +1964,6 @@ class GatewayService:
error_message='Gateway timeout'
).dict()
except Exception as e:
# Catch-all exception handler for errors outside main gRPC call loop
# Try to extract gRPC status code if this is a gRPC exception
code_name = 'UNKNOWN'
code_obj = None
try:
@@ -2099,10 +1971,8 @@ class GatewayService:
if code_obj and hasattr(code_obj, 'name'):
code_name = str(code_obj.name).upper()
except Exception:
# Not a gRPC exception, use exception type name
code_name = type(e).__name__.upper()

# Use the same comprehensive status mapping
status_map = {
'OK': 200,
'CANCELLED': 499,
@@ -2125,7 +1995,6 @@ class GatewayService:

http_status = status_map.get(code_name, 500)

# Extract error details
details = str(e)
try:
details_fn = getattr(e, 'details', None)

@@ -4,11 +4,9 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from pymongo.errors import DuplicateKeyError
import logging

# Internal imports
from models.response_model import ResponseModel
from models.update_group_model import UpdateGroupModel
from utils.database import group_collection

@@ -4,7 +4,6 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
import logging
import os
import json
@@ -22,12 +21,10 @@ class LoggingService:
if env_dir and str(env_dir).strip():
self.log_directory = os.path.abspath(env_dir)
else:
# Match doorman.py default: <backend-services>/platform-logs
base_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.normpath(os.path.join(base_dir, '..'))
candidate = os.path.join(backend_root, 'platform-logs')
self.log_directory = candidate if os.path.isdir(candidate) else os.path.join(backend_root, 'logs')
# Include both gateway logs and audit trail
self.log_file_patterns = ['doorman.log*', 'doorman-trail.log*']
self.max_logs_per_request = 1000

@@ -282,13 +279,11 @@ class LoggingService:
"""
try:
s = line.strip()
# JSON format from doorman JSONFormatter
if s.startswith('{') and s.endswith('}'):
try:
rec = json.loads(s)
timestamp_str = rec.get('time') or rec.get('timestamp')
try:
# time is already an ISO-like string per formatter
timestamp = timestamp_str or datetime.utcnow().isoformat()
except Exception:
timestamp = datetime.utcnow().isoformat()
@@ -306,7 +301,6 @@ class LoggingService:
except Exception:
pass

# Plain text format: "ts - logger - level - message"
parts = s.split(' - ', 3)
if len(parts) < 4:
return None

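LoggingService parses both JSON-formatted records and the plain "ts - logger - level - message" layout. A small sketch of the plain-text branch (the returned field names are assumptions based on the split shown above):

def parse_plain_log_line(line: str) -> dict | None:
    s = line.strip()
    # maxsplit=3 keeps any ' - ' sequences inside the message intact
    parts = s.split(' - ', 3)
    if len(parts) < 4:
        return None
    timestamp, logger_name, level, message = parts
    return {
        'timestamp': timestamp,
        'logger': logger_name,
        'level': level,
        'message': message,
    }

# parse_plain_log_line('2024-01-01 00:00:00 - doorman.gateway - INFO - started')
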
@@ -4,11 +4,9 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from pymongo.errors import DuplicateKeyError
import logging

# Internal imports
from models.response_model import ResponseModel
from models.update_role_model import UpdateRoleModel
from utils.database_async import role_collection

@@ -4,12 +4,10 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
from pymongo.errors import DuplicateKeyError
import uuid
import logging

# Internal imports
from models.response_model import ResponseModel
from models.create_routing_model import CreateRoutingModel
from models.update_routing_model import UpdateRoutingModel

@@ -4,10 +4,8 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/doorman for more information
"""

# External imports
import logging

# Internal imports
from models.response_model import ResponseModel
from utils.database import subscriptions_collection, api_collection
from utils.cache_manager_util import cache_manager

@@ -4,13 +4,11 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/apidoorman/doorman for more information
"""

# External imports
from typing import List
from fastapi import HTTPException
import logging
import asyncio

# Internal imports
from models.response_model import ResponseModel
from utils import password_util
from utils.database_async import user_collection, subscriptions_collection, api_collection

@@ -9,7 +9,6 @@ import logging
import re
import sys


class _RedactFilter(logging.Filter):
PATTERNS = [
re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
@@ -32,10 +31,8 @@ class _RedactFilter(logging.Filter):
pass
return True


def _ensure_logger(name: str):
logger = logging.getLogger(name)
# If the logger already has a handler with a filter, leave it alone
for h in logger.handlers:
if h.filters:
return
@@ -44,7 +41,6 @@ def _ensure_logger(name: str):
h.addFilter(_RedactFilter())
logger.addHandler(h)


try:
_ensure_logger('doorman.gateway')
_ensure_logger('doorman.logging')

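_RedactFilter scrubs secrets such as Authorization values out of log records before they reach a handler. A minimal working version of the same pattern (the regex list here is illustrative, not the module's full PATTERNS):

import logging
import re

class RedactFilter(logging.Filter):
    PATTERNS = [
        re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
    ]

    def filter(self, record: logging.LogRecord) -> bool:
        try:
            msg = record.getMessage()
            for pat in self.PATTERNS:
                msg = pat.sub(r'\1[REDACTED]', msg)
            record.msg, record.args = msg, None
        except Exception:
            pass
        return True  # never drop the record, only rewrite it

handler = logging.StreamHandler()
handler.addFilter(RedactFilter())
logging.getLogger('doorman.gateway').addHandler(handler)
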
@@ -10,7 +10,6 @@ from typing import Dict, Any
import asyncio
import time

# Async imports
from utils.database_async import (
user_collection as async_user_collection,
api_collection as async_api_collection,
@@ -18,7 +17,6 @@ from utils.database_async import (
)
from utils.doorman_cache_async import async_doorman_cache

# Sync imports for comparison
from utils.database import (
user_collection as sync_user_collection,
api_collection as sync_api_collection
@@ -27,20 +25,16 @@ from utils.doorman_cache_util import doorman_cache

router = APIRouter(prefix="/test/async", tags=["Async Testing"])


@router.get("/health")
async def async_health_check() -> Dict[str, Any]:
"""Test async database and cache health."""
try:
# Test async database
if async_database.is_memory_only():
db_status = "memory_only"
else:
# Try a simple query
await async_user_collection.find_one({'username': 'admin'})
db_status = "connected"

# Test async cache
cache_operational = await async_doorman_cache.is_operational()
cache_info = await async_doorman_cache.get_cache_info()

@@ -58,18 +52,15 @@ async def async_health_check() -> Dict[str, Any]:
except Exception as e:
raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")


@router.get("/performance/sync")
async def test_sync_performance() -> Dict[str, Any]:
"""Test SYNC (blocking) database operations - SLOW under load."""
start_time = time.time()

try:
# These operations BLOCK the event loop
user = sync_user_collection.find_one({'username': 'admin'})
apis = list(sync_api_collection.find({}).limit(10))

# Cache operations also BLOCK
cached_user = doorman_cache.get_cache('user_cache', 'admin')
if not cached_user:
doorman_cache.set_cache('user_cache', 'admin', user)
@@ -86,25 +77,20 @@ async def test_sync_performance() -> Dict[str, Any]:
except Exception as e:
raise HTTPException(status_code=500, detail=f"Sync test failed: {str(e)}")


@router.get("/performance/async")
async def test_async_performance() -> Dict[str, Any]:
"""Test ASYNC (non-blocking) database operations - FAST under load."""
start_time = time.time()

try:
# These operations are NON-BLOCKING
user = await async_user_collection.find_one({'username': 'admin'})

if async_database.is_memory_only():
# In memory mode, to_list is sync
apis = async_api_collection.find({}).limit(10)
apis = list(apis)
else:
# In MongoDB mode, to_list is async
apis = await async_api_collection.find({}).limit(10).to_list(length=10)

# Cache operations also NON-BLOCKING
cached_user = await async_doorman_cache.get_cache('user_cache', 'admin')
if not cached_user:
await async_doorman_cache.set_cache('user_cache', 'admin', user)
@@ -121,14 +107,12 @@ async def test_async_performance() -> Dict[str, Any]:
except Exception as e:
raise HTTPException(status_code=500, detail=f"Async test failed: {str(e)}")


@router.get("/performance/parallel")
async def test_parallel_performance() -> Dict[str, Any]:
"""Test PARALLEL async operations - Maximum performance."""
start_time = time.time()

try:
# Execute multiple operations in PARALLEL
user_task = async_user_collection.find_one({'username': 'admin'})

if async_database.is_memory_only():
@@ -140,14 +124,12 @@ async def test_parallel_performance() -> Dict[str, Any]:

cache_task = async_doorman_cache.get_cache('user_cache', 'admin')

# Wait for all operations to complete in parallel
user, apis, cached_user = await asyncio.gather(
user_task,
apis_task,
cache_task
)

# Cache if needed
if not cached_user and user:
await async_doorman_cache.set_cache('user_cache', 'admin', user)

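The parallel endpoint above overlaps independent awaits with asyncio.gather instead of awaiting them one after another. A self-contained illustration of the speedup mechanism (simulated I/O, not the router's actual collections):

import asyncio
import time

async def fake_query(delay: float) -> str:
    await asyncio.sleep(delay)  # stands in for a DB or cache round trip
    return f'done after {delay}s'

async def main() -> None:
    start = time.time()
    # Three 0.1s operations complete in roughly 0.1s total, not 0.3s
    results = await asyncio.gather(fake_query(0.1), fake_query(0.1), fake_query(0.1))
    print(results, f'elapsed={time.time() - start:.2f}s')

asyncio.run(main())
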
@@ -163,7 +145,6 @@ async def test_parallel_performance() -> Dict[str, Any]:
except Exception as e:
raise HTTPException(status_code=500, detail=f"Parallel test failed: {str(e)}")


@router.get("/cache/test")
async def test_cache_operations() -> Dict[str, Any]:
"""Test async cache operations."""
@@ -175,16 +156,12 @@ async def test_cache_operations() -> Dict[str, Any]:
"role": "user"
}

# Test set
await async_doorman_cache.set_cache('user_cache', test_key, test_value)

# Test get
retrieved = await async_doorman_cache.get_cache('user_cache', test_key)

# Test delete
await async_doorman_cache.delete_cache('user_cache', test_key)

# Verify deletion
after_delete = await async_doorman_cache.get_cache('user_cache', test_key)

return {
@@ -196,7 +173,6 @@ async def test_cache_operations() -> Dict[str, Any]:
except Exception as e:
raise HTTPException(status_code=500, detail=f"Cache test failed: {str(e)}")


@router.get("/load-test-compare")
async def load_test_comparison() -> Dict[str, Any]:
"""
@@ -205,7 +181,6 @@ async def load_test_comparison() -> Dict[str, Any]:
This endpoint simulates 10 concurrent database queries.
"""
try:
# Test SYNC (blocking) - operations are sequential
sync_start = time.time()
sync_results = []
for i in range(10):
@@ -213,7 +188,6 @@ async def load_test_comparison() -> Dict[str, Any]:
sync_results.append(user is not None)
sync_elapsed = time.time() - sync_start

# Test ASYNC (non-blocking) - operations can overlap
async_start = time.time()
async_tasks = [
async_user_collection.find_one({'username': 'admin'})

@@ -5,11 +5,9 @@ Ensures the backend-services directory is on sys.path so imports like
`from utils...` resolve correctly when tests run from the repo root in CI.
"""

# External imports
import os
import sys

# TEST-ONLY credentials - DO NOT use these in production
os.environ.setdefault('MEM_OR_EXTERNAL', 'MEM')
os.environ.setdefault('JWT_SECRET_KEY', 'test-secret-key')
os.environ.setdefault('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev')
@@ -22,16 +20,11 @@ os.environ.setdefault('DOORMAN_TEST_MODE', 'true')
os.environ.setdefault('ENABLE_HTTPX_CLIENT_CACHE', 'false')
os.environ.setdefault('DOORMAN_TEST_MODE', 'true')

# Compatibility toggles for Python 3.13 transport/middleware edge-cases
try:
import sys as _sys
if _sys.version_info >= (3, 13):
# Avoid BaseHTTPMiddleware/receive wrapping issues on platform routes
os.environ.setdefault('DISABLE_PLATFORM_CHUNKED_WRAP', 'true')
# Use native Starlette behavior for CORS (disable ASGI shim)
os.environ.setdefault('DISABLE_PLATFORM_CORS_ASGI', 'true')
# Exclude problematic platform endpoint from body size middleware to
# avoid EndOfStream/No response returned on some runtimes
os.environ.setdefault('BODY_LIMIT_EXCLUDE_PATHS', '/platform/security/settings')
except Exception:
pass
@@ -65,11 +58,9 @@ async def ensure_memory_dump_defaults(monkeypatch, tmp_path):
"""
try:
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
# Provide a stable, sufficiently long test key; individual tests may monkeypatch/delenv
monkeypatch.setenv('MEM_ENCRYPTION_KEY', os.environ.get('MEM_ENCRYPTION_KEY') or 'test-encryption-key-32-characters-min')
dump_base = tmp_path / 'mem' / 'memory_dump.bin'
monkeypatch.setenv('MEM_DUMP_PATH', str(dump_base))
# If memory_dump_util was already imported before env set, update its module-level default
try:
import utils.memory_dump_util as md
md.DEFAULT_DUMP_PATH = str(dump_base)
@@ -79,7 +70,6 @@ async def ensure_memory_dump_defaults(monkeypatch, tmp_path):
pass
yield

# --- Per-test start/finish logging to pinpoint hangs ---
@pytest.fixture(autouse=True)
def _log_test_start_end(request):
try:
@@ -94,7 +84,6 @@ def _log_test_start_end(request):
except Exception:
pass

# Also log key env toggles at session start for reproducibility
@pytest.fixture(autouse=True, scope='session')
def _log_env_toggles():
try:
@@ -166,14 +155,12 @@ def event_loop():
@pytest_asyncio.fixture(autouse=True)
async def reset_http_client():
"""Reset the pooled httpx client between tests to prevent connection pool exhaustion."""
# Reset before the test (important for tests that monkeypatch httpx.AsyncClient)
try:
from services.gateway_service import GatewayService
await GatewayService.aclose_http_client()
except Exception:
pass

# Reset rate limit counters before each test
try:
from utils.limit_throttle_util import reset_counters
reset_counters()
@@ -181,7 +168,6 @@ async def reset_http_client():
pass

yield
# After each test, close and reset the pooled client
try:
from services.gateway_service import GatewayService
await GatewayService.aclose_http_client()
@@ -215,7 +201,6 @@ async def reset_in_memory_db_state():
pass
yield

# Test helpers expected by some suites
async def create_api(client: AsyncClient, api_name: str, api_version: str):
payload = {
'api_name': api_name,

@@ -2,7 +2,6 @@ import logging
import re
import sys


class _RedactFilter(logging.Filter):
PATTERNS = [
re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
@@ -25,7 +24,6 @@ class _RedactFilter(logging.Filter):
pass
return True


def _ensure_logger(name: str):
logger = logging.getLogger(name)
for h in logger.handlers:
@@ -36,7 +34,6 @@ def _ensure_logger(name: str):
h.addFilter(_RedactFilter())
logger.addHandler(h)


try:
_ensure_logger('doorman.gateway')
_ensure_logger('doorman.logging')

@@ -1,4 +1,3 @@
# External imports
import os
import pytest
import pytest_asyncio

@@ -1,38 +1,30 @@
import os
import pytest


@pytest.mark.asyncio
async def test_admin_seed_fields_memory_mode(monkeypatch):
# Ensure memory mode and deterministic admin creds
monkeypatch.setenv('MEM_OR_EXTERNAL', 'MEM')
monkeypatch.setenv('DOORMAN_ADMIN_EMAIL', 'admin@doorman.dev')
monkeypatch.setenv('DOORMAN_ADMIN_PASSWORD', 'test-only-password-12chars')

from utils import database as dbmod
# Reinitialize collections to ensure seed runs
dbmod.database.initialize_collections()

from utils.database import user_collection, role_collection, group_collection, _build_admin_seed_doc
admin = user_collection.find_one({'username': 'admin'})
assert admin is not None, 'Admin user should be seeded'

# Expected keys from canonical seed helper
expected_keys = set(_build_admin_seed_doc('x@example.com', 'hash').keys())
doc_keys = set(admin.keys())
assert expected_keys.issubset(doc_keys), f'Missing keys: {expected_keys - doc_keys}'
# In-memory will include an _id key
assert '_id' in doc_keys

# Password handling: should be hashed and verify
from utils import password_util
assert password_util.verify_password(os.environ['DOORMAN_ADMIN_PASSWORD'], admin.get('password'))

# Groups/roles parity
assert set(admin.get('groups') or []) >= {'ALL', 'admin'}
role = role_collection.find_one({'role_name': 'admin'})
assert role is not None
# Core capabilities expected on admin role
for cap in (
'manage_users','manage_apis','manage_endpoints','manage_groups','manage_roles',
'manage_routings','manage_gateway','manage_subscriptions','manage_credits','manage_auth','manage_security','view_logs'
@@ -42,12 +34,9 @@ async def test_admin_seed_fields_memory_mode(monkeypatch):
grp_all = group_collection.find_one({'group_name': 'ALL'})
assert grp_admin is not None and grp_all is not None


def test_admin_seed_helper_is_canonical():
# Helper itself encodes the canonical set of fields for both modes
from utils.database import _build_admin_seed_doc
doc = _build_admin_seed_doc('a@b.c', 'hash')
# Ensure required fields exist and have expected default values/types
assert doc['username'] == 'admin'
assert doc['role'] == 'admin'
assert doc['ui_access'] is True

@@ -1,4 +1,3 @@
# External imports
import uuid
import pytest


@@ -1,4 +1,3 @@
# External imports
import pytest

@pytest.mark.asyncio

@@ -1,4 +1,3 @@
# External imports
import pytest

@pytest.mark.asyncio

@@ -1,4 +1,3 @@
# External imports
import pytest

@pytest.mark.asyncio

@@ -1,6 +1,5 @@
import pytest


async def _setup_api_and_endpoint(client, name, ver, api_overrides=None, method='GET', uri='/status'):
payload = {
'api_name': name,
@@ -25,7 +24,6 @@ async def _setup_api_and_endpoint(client, name, ver, api_overrides=None, method=
})
assert r2.status_code in (200, 201), r2.text


@pytest.mark.asyncio
async def test_api_cors_allow_origins_exact_match_allowed(authed_client):
name, ver = 'corsm1', 'v1'
@@ -42,7 +40,6 @@ async def test_api_cors_allow_origins_exact_match_allowed(authed_client):
assert r.headers.get('Access-Control-Allow-Origin') == 'http://ok.example'
assert r.headers.get('Vary') == 'Origin'


@pytest.mark.asyncio
async def test_api_cors_allow_origins_wildcard_allowed(authed_client):
name, ver = 'corsm2', 'v1'
@@ -58,7 +55,6 @@ async def test_api_cors_allow_origins_wildcard_allowed(authed_client):
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Origin') == 'http://any.example'


@pytest.mark.asyncio
async def test_api_cors_allow_methods_contains_options_appended(authed_client):
name, ver = 'corsm3', 'v1'
@@ -75,7 +71,6 @@ async def test_api_cors_allow_methods_contains_options_appended(authed_client):
methods = [m.strip().upper() for m in (r.headers.get('Access-Control-Allow-Methods') or '').split(',') if m.strip()]
assert 'OPTIONS' in methods


@pytest.mark.asyncio
async def test_api_cors_allow_headers_asterisk_allows_any(authed_client):
name, ver = 'corsm4', 'v1'
@@ -92,7 +87,6 @@ async def test_api_cors_allow_headers_asterisk_allows_any(authed_client):
ach = r.headers.get('Access-Control-Allow-Headers') or ''
assert '*' in ach


@pytest.mark.asyncio
async def test_api_cors_allow_headers_specific_disallows_others(authed_client):
name, ver = 'corsm5', 'v1'
@@ -109,7 +103,6 @@ async def test_api_cors_allow_headers_specific_disallows_others(authed_client):
ach = r.headers.get('Access-Control-Allow-Headers') or ''
assert 'Content-Type' in ach and 'X-Other' not in ach


@pytest.mark.asyncio
async def test_api_cors_allow_credentials_true_sets_header(authed_client):
name, ver = 'corsm6', 'v1'
@@ -126,7 +119,6 @@ async def test_api_cors_allow_credentials_true_sets_header(authed_client):
assert r.status_code == 204
assert r.headers.get('Access-Control-Allow-Credentials') == 'true'


@pytest.mark.asyncio
async def test_api_cors_expose_headers_propagated(authed_client):
name, ver = 'corsm7', 'v1'

@@ -1,4 +1,3 @@
# External imports
import pytest

@pytest.mark.asyncio

@@ -1,4 +1,3 @@
# External imports
import pytest

class _AuditSpy:

@@ -1,4 +1,3 @@
# External imports
import os
import pytest


@@ -1,4 +1,3 @@
# External imports
import os
import pytest

@@ -11,7 +10,6 @@ async def test_auth_admin_endpoints(authed_client):
)
assert r.status_code == 200, r.text

# Use test credentials from environment (set in conftest.py for test-only use)
relog = await authed_client.post(
'/platform/authorization',
json={

Some files were not shown because too many files have changed in this diff.