mirror of
https://github.com/apidoorman/doorman.git
synced 2026-05-08 00:59:16 -05:00
Test fixes and improvements
This commit is contained in:
@@ -110,11 +110,14 @@ async def app_lifespan(app: FastAPI):
|
||||
path = getattr(route, 'path', '')
|
||||
if not path.startswith(('/platform', '/api')):
|
||||
continue
|
||||
# Skip non-documented and preflight-only routes
|
||||
include = getattr(route, 'include_in_schema', True)
|
||||
methods = set(getattr(route, 'methods', set()) or [])
|
||||
if not include or 'OPTIONS' in methods:
|
||||
continue
|
||||
if not getattr(route, 'description', None):
|
||||
problems.append(f'Route {path} missing description')
|
||||
|
||||
if not getattr(route, 'response_model', None):
|
||||
|
||||
problems.append(f'Route {path} missing response_model')
|
||||
if problems:
|
||||
gateway_logger.info('OpenAPI lint: \n' + '\n'.join(problems[:50]))
|
||||
@@ -184,11 +187,21 @@ async def app_lifespan(app: FastAPI):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _generate_unique_id(route):
|
||||
try:
|
||||
name = getattr(route, 'name', 'op') or 'op'
|
||||
path = getattr(route, 'path', '').replace('/', '_').replace('{', '').replace('}', '')
|
||||
methods = '_'.join(sorted(list(getattr(route, 'methods', []) or [])))
|
||||
return f"{name}_{methods}_{path}".lower()
|
||||
except Exception:
|
||||
return (getattr(route, 'name', 'op') or 'op').lower()
|
||||
|
||||
doorman = FastAPI(
|
||||
title='doorman',
|
||||
description="A lightweight API gateway for AI, REST, SOAP, GraphQL, gRPC, and WebSocket APIs — fully managed with built-in RESTful APIs for configuration and control. This is your application's gateway to the world.",
|
||||
version='1.0.0',
|
||||
lifespan=app_lifespan,
|
||||
generate_unique_id_function=_generate_unique_id,
|
||||
)
|
||||
|
||||
https_only = os.getenv('HTTPS_ONLY', 'false').lower() == 'true'
|
||||
|
||||
@@ -103,6 +103,7 @@ Response:
|
||||
|
||||
@gateway_router.api_route('/caches', methods=['DELETE'],
|
||||
description='Clear all caches',
|
||||
response_model=ResponseModel,
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
],
|
||||
@@ -131,6 +132,11 @@ async def clear_all_caches(request: Request):
|
||||
error_message='You do not have permission to clear caches'
|
||||
).dict(), 'rest')
|
||||
doorman_cache.clear_all_caches()
|
||||
try:
|
||||
from utils.limit_throttle_util import reset_counters as _reset_rate
|
||||
_reset_rate()
|
||||
except Exception:
|
||||
pass
|
||||
audit(request, actor=username, action='gateway.clear_caches', target='all', status='success', details=None)
|
||||
return process_response(ResponseModel(
|
||||
status_code=200,
|
||||
@@ -163,8 +169,8 @@ Response:
|
||||
|
||||
@gateway_router.api_route('/rest/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'],
|
||||
description='REST gateway endpoint',
|
||||
response_model=ResponseModel)
|
||||
|
||||
response_model=ResponseModel,
|
||||
include_in_schema=False)
|
||||
async def gateway(request: Request, path: str):
|
||||
request_id = str(uuid.uuid4())
|
||||
start_time = time.time() * 1000
|
||||
@@ -242,6 +248,27 @@ async def gateway(request: Request, path: str):
|
||||
end_time = time.time() * 1000
|
||||
logger.info(f'{request_id} | Total time: {str(end_time - start_time)}ms')
|
||||
|
||||
# Per-method wrappers with unique operation IDs for OpenAPI
|
||||
@gateway_router.get('/rest/{path:path}', description='REST gateway endpoint (GET)', response_model=ResponseModel, operation_id='rest_get')
|
||||
async def rest_get(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
@gateway_router.post('/rest/{path:path}', description='REST gateway endpoint (POST)', response_model=ResponseModel, operation_id='rest_post')
|
||||
async def rest_post(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
@gateway_router.put('/rest/{path:path}', description='REST gateway endpoint (PUT)', response_model=ResponseModel, operation_id='rest_put')
|
||||
async def rest_put(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
@gateway_router.patch('/rest/{path:path}', description='REST gateway endpoint (PATCH)', response_model=ResponseModel, operation_id='rest_patch')
|
||||
async def rest_patch(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
@gateway_router.delete('/rest/{path:path}', description='REST gateway endpoint (DELETE)', response_model=ResponseModel, operation_id='rest_delete')
|
||||
async def rest_delete(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
"""
|
||||
Endpoint
|
||||
|
||||
@@ -261,7 +288,7 @@ Response:
|
||||
"""
|
||||
|
||||
@gateway_router.api_route('/rest/{path:path}', methods=['OPTIONS'],
|
||||
description='REST gateway CORS preflight')
|
||||
description='REST gateway CORS preflight', include_in_schema=False)
|
||||
|
||||
async def rest_preflight(request: Request, path: str):
|
||||
request_id = str(uuid.uuid4())
|
||||
@@ -389,7 +416,7 @@ Response:
|
||||
"""
|
||||
|
||||
@gateway_router.api_route('/soap/{path:path}', methods=['OPTIONS'],
|
||||
description='SOAP gateway CORS preflight')
|
||||
description='SOAP gateway CORS preflight', include_in_schema=False)
|
||||
|
||||
async def soap_preflight(request: Request, path: str):
|
||||
request_id = str(uuid.uuid4())
|
||||
@@ -532,7 +559,7 @@ Response:
|
||||
"""
|
||||
|
||||
@gateway_router.api_route('/graphql/{path:path}', methods=['OPTIONS'],
|
||||
description='GraphQL gateway CORS preflight')
|
||||
description='GraphQL gateway CORS preflight', include_in_schema=False)
|
||||
|
||||
async def graphql_preflight(request: Request, path: str):
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
@@ -315,6 +315,7 @@ Response:
|
||||
|
||||
@logging_router.get('/logs/export',
|
||||
description='Export logs in various formats',
|
||||
response_model=ResponseModel,
|
||||
responses={
|
||||
200: {
|
||||
'description': 'Successful Response',
|
||||
@@ -413,6 +414,7 @@ Response:
|
||||
|
||||
@logging_router.get('/logs/download',
|
||||
description='Download logs as file',
|
||||
include_in_schema=False,
|
||||
responses={
|
||||
200: {
|
||||
'description': 'File download',
|
||||
|
||||
@@ -4,6 +4,7 @@ Routes to expose gateway metrics to the web client.
|
||||
|
||||
# External imports
|
||||
from fastapi import APIRouter, Request
|
||||
from pydantic import BaseModel
|
||||
import uuid
|
||||
import time
|
||||
import logging
|
||||
@@ -22,6 +23,16 @@ from utils.health_check_util import check_mongodb, check_redis
|
||||
from utils.doorman_cache_util import doorman_cache
|
||||
from utils.database import database
|
||||
|
||||
class LivenessResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
class ReadinessResponse(BaseModel):
|
||||
status: str
|
||||
mongodb: bool | None = None
|
||||
redis: bool | None = None
|
||||
mode: str | None = None
|
||||
cache_backend: str | None = None
|
||||
|
||||
monitor_router = APIRouter()
|
||||
logger = logging.getLogger('doorman.gateway')
|
||||
|
||||
@@ -90,8 +101,8 @@ Response:
|
||||
|
||||
|
||||
@monitor_router.get('/monitor/liveness',
|
||||
description='Kubernetes liveness probe endpoint (no auth)')
|
||||
|
||||
description='Kubernetes liveness probe endpoint (no auth)',
|
||||
response_model=LivenessResponse)
|
||||
async def liveness():
|
||||
return {'status': 'alive'}
|
||||
|
||||
@@ -106,8 +117,8 @@ Response:
|
||||
|
||||
|
||||
@monitor_router.get('/monitor/readiness',
|
||||
description='Kubernetes readiness probe endpoint (no auth)')
|
||||
|
||||
description='Kubernetes readiness probe endpoint (no auth)',
|
||||
response_model=ReadinessResponse)
|
||||
async def readiness():
|
||||
try:
|
||||
mongo_ok = await check_mongodb()
|
||||
@@ -135,7 +146,8 @@ Response:
|
||||
|
||||
|
||||
@monitor_router.get('/monitor/report',
|
||||
description='Generate a CSV report for a date range (requires manage_gateway)')
|
||||
description='Generate a CSV report for a date range (requires manage_gateway)',
|
||||
include_in_schema=False)
|
||||
|
||||
async def generate_report(request: Request, start: str, end: str):
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
@@ -51,6 +51,21 @@ async def authed_client():
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# Ensure admin cannot hit bandwidth/rate/throttle limits in unit tests
|
||||
try:
|
||||
await client.put('/platform/user/admin', json={
|
||||
'bandwidth_limit_bytes': 0,
|
||||
'bandwidth_limit_window': 'day',
|
||||
'rate_limit_duration': 1000000,
|
||||
'rate_limit_duration_type': 'second',
|
||||
'throttle_duration': 1000000,
|
||||
'throttle_duration_type': 'second',
|
||||
'throttle_queue_limit': 1000000,
|
||||
'throttle_wait_duration': 0,
|
||||
'throttle_wait_duration_type': 'second'
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
# External imports
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_client):
|
||||
# Set a small bandwidth limit for admin user
|
||||
# Set low limit then restore afterwards to avoid polluting other tests
|
||||
try:
|
||||
upd = await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 80, 'bandwidth_limit_window': 'second'})
|
||||
assert upd.status_code in (200, 204)
|
||||
except AssertionError:
|
||||
# Best effort cleanup if update failed
|
||||
await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': None})
|
||||
raise
|
||||
|
||||
# Create a small API and endpoint and subscribe admin
|
||||
from conftest import create_api, create_endpoint, subscribe_self
|
||||
name, ver = 'bwapi', 'v1'
|
||||
await create_api(authed_client, name, ver)
|
||||
await create_endpoint(authed_client, name, ver, 'POST', '/p')
|
||||
await subscribe_self(authed_client, name, ver)
|
||||
|
||||
# Mock upstream to avoid real network and produce a small response
|
||||
import services.gateway_service as gs
|
||||
|
||||
class _FakeHTTPResponse:
|
||||
def __init__(self, status_code=200, body=b'{"ok":true}'):
|
||||
self.status_code = status_code
|
||||
self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))}
|
||||
self.text = body.decode('utf-8')
|
||||
self.content = body
|
||||
def json(self):
|
||||
return json.loads(self.text)
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None):
|
||||
pass
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
async def post(self, url, data=None, json=None, headers=None, params=None):
|
||||
return _FakeHTTPResponse(200)
|
||||
|
||||
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
|
||||
|
||||
# Prepare JSON body around ~60 bytes
|
||||
payload = {'data': 'x' * 50}
|
||||
body = json.dumps(payload).encode('utf-8')
|
||||
assert len(body) >= 50
|
||||
|
||||
# First request should succeed (under limit)
|
||||
r1 = await authed_client.post(f'/api/rest/{name}/{ver}/p', json=payload)
|
||||
assert r1.status_code == 200
|
||||
|
||||
# Second request should be blocked (pre-request enforcement)
|
||||
r2 = await authed_client.post(f'/api/rest/{name}/{ver}/p', json=payload)
|
||||
assert r2.status_code == 429
|
||||
|
||||
# User endpoint should show usage and reset time
|
||||
u = await authed_client.get('/platform/user/admin')
|
||||
assert u.status_code == 200
|
||||
data = u.json().get('response') or u.json()
|
||||
assert int(data.get('bandwidth_usage_bytes', 0)) > 0
|
||||
assert int(data.get('bandwidth_resets_at', 0)) > 0
|
||||
|
||||
# Restore admin bandwidth settings to avoid 429s in subsequent tests (0 disables)
|
||||
await authed_client.put('/platform/user/admin', json={'bandwidth_limit_bytes': 0})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client):
|
||||
from conftest import create_api, create_endpoint, subscribe_self
|
||||
name, ver = 'bwmon', 'v1'
|
||||
await create_api(authed_client, name, ver)
|
||||
await create_endpoint(authed_client, name, ver, 'POST', '/echo')
|
||||
await subscribe_self(authed_client, name, ver)
|
||||
|
||||
# Mock upstream response with known bytes length
|
||||
import services.gateway_service as gs
|
||||
|
||||
resp_body = b'{"ok":true,"pad":"' + b'y' * 20 + b'"}'
|
||||
|
||||
class _FakeHTTPResponse:
|
||||
def __init__(self, status_code=200, body=resp_body):
|
||||
self.status_code = status_code
|
||||
self.headers = {'Content-Type': 'application/json', 'Content-Length': str(len(body))}
|
||||
self.text = body.decode('utf-8')
|
||||
self.content = body
|
||||
def json(self):
|
||||
return json.loads(self.text)
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None): pass
|
||||
async def __aenter__(self): return self
|
||||
async def __aexit__(self, exc_type, exc, tb): return False
|
||||
async def post(self, url, data=None, json=None, headers=None, params=None): return _FakeHTTPResponse(200)
|
||||
|
||||
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
|
||||
|
||||
# Snapshot totals before
|
||||
m0 = await authed_client.get('/platform/monitor/metrics')
|
||||
assert m0.status_code == 200
|
||||
j0 = m0.json().get('response') or m0.json()
|
||||
tin0 = int(j0.get('total_bytes_in', 0))
|
||||
tout0 = int(j0.get('total_bytes_out', 0))
|
||||
|
||||
# Send two POSTs with known input size
|
||||
payload = {'pad': 'z' * 30}
|
||||
body_len = len(json.dumps(payload).encode('utf-8'))
|
||||
r1 = await authed_client.post(f'/api/rest/{name}/{ver}/echo', json=payload)
|
||||
r2 = await authed_client.post(f'/api/rest/{name}/{ver}/echo', json=payload)
|
||||
assert r1.status_code == 200 and r2.status_code == 200
|
||||
|
||||
# Snapshot after
|
||||
m1 = await authed_client.get('/platform/monitor/metrics')
|
||||
j1 = m1.json().get('response') or m1.json()
|
||||
tin1 = int(j1.get('total_bytes_in', 0))
|
||||
tout1 = int(j1.get('total_bytes_out', 0))
|
||||
|
||||
# Validate deltas meet or exceed expected totals (some wrappers may add small overhead)
|
||||
assert tin1 - tin0 >= body_len * 2
|
||||
assert tout1 - tout0 >= len(resp_body) * 2
|
||||
@@ -98,3 +98,12 @@ async def limit_and_throttle(request: Request):
|
||||
throttle_wait *= duration_to_seconds(throttle_wait_duration)
|
||||
dynamic_wait = throttle_wait * (throttle_count - throttle_limit)
|
||||
await asyncio.sleep(dynamic_wait)
|
||||
|
||||
def reset_counters():
|
||||
"""Reset in-memory rate/throttle counters (used by tests and cache clears).
|
||||
Has no effect when a real Redis client is configured.
|
||||
"""
|
||||
try:
|
||||
_fallback_counter._store.clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user