mirror of
https://github.com/apidoorman/doorman.git
synced 2026-04-25 10:08:41 -05:00
qof updates
This commit is contained in:
@@ -25,7 +25,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend-services/requirements.txt
|
||||
- name: Run backend tests
|
||||
- name: Run backend-services tests
|
||||
env:
|
||||
MEM_OR_EXTERNAL: MEM
|
||||
HTTPS_ONLY: 'false'
|
||||
@@ -41,6 +41,21 @@ jobs:
|
||||
ALLOWED_ORIGINS: http://localhost:3000
|
||||
working-directory: backend-services
|
||||
run: pytest -q tests
|
||||
- name: Run root-level integration tests
|
||||
env:
|
||||
MEM_OR_EXTERNAL: MEM
|
||||
HTTPS_ONLY: 'false'
|
||||
HTTPS_ENABLED: 'false'
|
||||
STRICT_RESPONSE_ENVELOPE: 'false'
|
||||
JWT_SECRET_KEY: test-secret-key-please-change
|
||||
STARTUP_ADMIN_EMAIL: admin@doorman.dev
|
||||
STARTUP_ADMIN_PASSWORD: password1
|
||||
MEM_ENCRYPTION_KEY: unit-test-key-32chars-abcdef123456!!
|
||||
ALLOWED_HEADERS: '*'
|
||||
ALLOW_HEADERS: '*'
|
||||
ALLOW_METHODS: '*'
|
||||
ALLOWED_ORIGINS: http://localhost:3000
|
||||
run: pytest -q tests
|
||||
|
||||
frontend-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -39,11 +39,45 @@ HTTPS_ONLY=true
|
||||
# Enforce CSRF double-submit validation (and set Secure cookies when HTTPS_ONLY=true)
|
||||
HTTPS_ENABLED=true
|
||||
COOKIE_DOMAIN=localhost
|
||||
# Optional: Custom Content Security Policy (defaults to restrictive policy)
|
||||
# CONTENT_SECURITY_POLICY=default-src 'self'; script-src 'self' 'unsafe-inline'
|
||||
|
||||
# HTTP Client Connection Pooling (httpx)
|
||||
ENABLE_HTTPX_CLIENT_CACHE=true
|
||||
HTTP_MAX_CONNECTIONS=100
|
||||
HTTP_MAX_KEEPALIVE=50
|
||||
HTTP_KEEPALIVE_EXPIRY=30.0
|
||||
# Timeouts (seconds)
|
||||
HTTP_CONNECT_TIMEOUT=5.0
|
||||
HTTP_READ_TIMEOUT=30.0
|
||||
HTTP_WRITE_TIMEOUT=30.0
|
||||
HTTP_TIMEOUT=30.0
|
||||
# Enable HTTP/2 (default: false)
|
||||
HTTP_ENABLE_HTTP2=false
|
||||
|
||||
# Security
|
||||
# Max request body size in bytes (default: 1MB)
|
||||
MAX_BODY_SIZE_BYTES=1048576
|
||||
# Per-API type overrides (optional)
|
||||
# MAX_BODY_SIZE_BYTES_REST=1048576
|
||||
# MAX_BODY_SIZE_BYTES_SOAP=2097152
|
||||
# MAX_BODY_SIZE_BYTES_GRAPHQL=524288
|
||||
# MAX_BODY_SIZE_BYTES_GRPC=1048576
|
||||
# Allow localhost to bypass IP filters when no X-Forwarded-For header present
|
||||
LOCAL_HOST_IP_BYPASS=true
|
||||
|
||||
# Logging
|
||||
# Format: json or plain
|
||||
LOG_FORMAT=plain
|
||||
# Optional: custom logs directory (defaults to backend-services/platform-logs)
|
||||
# LOGS_DIR=/path/to/logs
|
||||
|
||||
# App
|
||||
PORT=5001
|
||||
THREADS=4
|
||||
DEV_RELOAD=false
|
||||
# Production environment flag (enforces HTTPS checks)
|
||||
ENV=development
|
||||
SSL_CERTFILE=./certs/localhost.crt
|
||||
SSL_KEYFILE=./certs/localhost.key
|
||||
PID_FILE=doorman.pid
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
# Security Features & Posture
|
||||
|
||||
This document outlines the security features implemented in Doorman API Gateway.
|
||||
|
||||
## 🔒 Transport Layer Security
|
||||
|
||||
### HTTPS Enforcement
|
||||
- **Production Guard**: In production (`ENV=production`), the application enforces HTTPS configuration
|
||||
- **Startup Validation**: Server refuses to start unless `HTTPS_ONLY` or `HTTPS_ENABLED` is set to `true`
|
||||
- **Secure Cookies**: Cookies are marked as `Secure` when HTTPS is enabled
|
||||
- **SSL/TLS Configuration**: Supports custom certificate paths via `SSL_CERTFILE` and `SSL_KEYFILE`
|
||||
|
||||
```bash
|
||||
# Production configuration (required)
|
||||
ENV=production
|
||||
HTTPS_ONLY=true
|
||||
SSL_CERTFILE=./certs/server.crt
|
||||
SSL_KEYFILE=./certs/server.key
|
||||
```
|
||||
|
||||
### HTTP Strict Transport Security (HSTS)
|
||||
- **Auto-enabled**: When `HTTPS_ONLY=true`, HSTS headers are automatically added
|
||||
- **Default Policy**: `max-age=15552000; includeSubDomains; preload` (6 months)
|
||||
- **Browser Protection**: Prevents downgrade attacks and ensures HTTPS-only access
|
||||
|
||||
## 🛡️ Authentication & Authorization
|
||||
|
||||
### JWT Cookie-Based Authentication
|
||||
- **HTTP-Only Cookies**: Tokens stored in `HttpOnly` cookies to prevent XSS
|
||||
- **Configurable Expiry**:
|
||||
- Access tokens: Default 30 minutes (configurable via `AUTH_EXPIRE_TIME`)
|
||||
- Refresh tokens: Default 7 days (configurable via `AUTH_REFRESH_EXPIRE_TIME`)
|
||||
- **Token Revocation**:
|
||||
- In-memory blacklist with optional Redis persistence
|
||||
- Database-backed revocation for memory-only mode
|
||||
- Per-user and per-JTI revocation support
|
||||
|
||||
### CSRF Protection (Double Submit Cookie Pattern)
|
||||
- **HTTPS-Only**: CSRF validation automatically enabled when `HTTPS_ENABLED=true`
|
||||
- **Validation Flow**:
|
||||
1. Server sets `csrf_token` cookie on login
|
||||
2. Client sends `X-CSRF-Token` header with requests
|
||||
3. Server validates header matches cookie value
|
||||
- **401 Response**: Invalid or missing CSRF tokens are rejected
|
||||
- **Test Coverage**: Full test suite in `tests/test_auth_csrf_https.py`
|
||||
|
||||
```python
|
||||
# CSRF validation (automatic when HTTPS enabled)
|
||||
https_enabled = os.getenv('HTTPS_ENABLED', 'false').lower() == 'true'
|
||||
if https_enabled:
|
||||
csrf_header = request.headers.get('X-CSRF-Token')
|
||||
csrf_cookie = request.cookies.get('csrf_token')
|
||||
if not await validate_csrf_double_submit(csrf_header, csrf_cookie):
|
||||
raise HTTPException(status_code=401, detail='Invalid CSRF token')
|
||||
```
|
||||
|
||||
## 🌐 CORS & Cross-Origin Security
|
||||
|
||||
### Platform Routes CORS
|
||||
- **Environment-Based**: Configured via `ALLOWED_ORIGINS`, `ALLOW_METHODS`, `ALLOW_HEADERS`
|
||||
- **Credentials Support**: Configurable via `ALLOW_CREDENTIALS`
|
||||
- **Wildcard Safety**: Automatic downgrade to `localhost` when credentials are enabled with `*` origin
|
||||
- **OPTIONS Preflight**: Full preflight handling with proper headers
|
||||
|
||||
### Per-API CORS (Gateway Routes)
|
||||
- **API-Level Control**: Each API can define its own CORS policy
|
||||
- **Configuration Options**:
|
||||
- `api_cors_allow_origins`: List of allowed origins (default: `['*']`)
|
||||
- `api_cors_allow_methods`: Allowed HTTP methods
|
||||
- `api_cors_allow_headers`: Allowed request headers
|
||||
- `api_cors_allow_credentials`: Enable credentials (default: `false`)
|
||||
- `api_cors_expose_headers`: Headers exposed to client
|
||||
- **Preflight Validation**: Origin, method, and header validation before allowing requests
|
||||
- **Dynamic Headers**: CORS headers computed per request based on API config
|
||||
|
||||
```python
|
||||
# Per-API CORS example
|
||||
{
|
||||
"api_cors_allow_origins": ["https://app.example.com"],
|
||||
"api_cors_allow_methods": ["GET", "POST"],
|
||||
"api_cors_allow_credentials": true,
|
||||
"api_cors_expose_headers": ["X-Request-ID", "X-RateLimit-Remaining"]
|
||||
}
|
||||
```
|
||||
|
||||
## 🚧 IP Policy & Access Control
|
||||
|
||||
### Global IP Filtering
|
||||
- **Whitelist Mode**: `ip_whitelist` - Only listed IPs/CIDRs allowed (blocks all others)
|
||||
- **Blacklist Mode**: `ip_blacklist` - Listed IPs/CIDRs blocked (allows all others)
|
||||
- **CIDR Support**: Full IPv4 and IPv6 CIDR notation support
|
||||
- **X-Forwarded-For**: Configurable trust via `trust_x_forwarded_for` setting
|
||||
- **Trusted Proxies**: Validate XFF headers against `xff_trusted_proxies` list
|
||||
- **Localhost Bypass**: Optional bypass for localhost (`LOCAL_HOST_IP_BYPASS=true`)
|
||||
|
||||
### Per-API IP Policy
|
||||
- **API-Level Override**: Each API can define additional IP restrictions
|
||||
- **Deny/Allow Lists**: API-specific `api_ip_deny_list` and `api_ip_allow_list`
|
||||
- **Granular Control**: Restrict access to specific APIs by client IP
|
||||
- **Audit Trail**: All IP denials logged with details (reason, XFF header, source IP)
|
||||
|
||||
```python
|
||||
# IP policy enforcement
|
||||
if client_ip:
|
||||
if whitelist and not ip_in_list(client_ip, whitelist):
|
||||
audit(request, action='ip.global_deny', target=client_ip,
|
||||
status='blocked', details={'reason': 'not_in_whitelist'})
|
||||
return 403 # Forbidden
|
||||
```
|
||||
|
||||
## 🔐 Security Headers
|
||||
|
||||
### Content Security Policy (CSP)
|
||||
- **Safe Default**: Restrictive baseline policy prevents common attacks
|
||||
- **Default Policy**:
|
||||
```
|
||||
default-src 'none';
|
||||
frame-ancestors 'none';
|
||||
base-uri 'none';
|
||||
form-action 'self';
|
||||
img-src 'self' data:;
|
||||
connect-src 'self';
|
||||
```
|
||||
- **Customizable**: Override via `CONTENT_SECURITY_POLICY` environment variable
|
||||
- **XSS Protection**: Prevents inline scripts and untrusted resource loading
|
||||
|
||||
### Additional Security Headers
|
||||
All responses include:
|
||||
- **X-Content-Type-Options**: `nosniff` - Prevents MIME sniffing
|
||||
- **X-Frame-Options**: `DENY` - Prevents clickjacking
|
||||
- **Referrer-Policy**: `no-referrer` - Prevents referrer leakage
|
||||
- **Permissions-Policy**: Restricts geolocation, microphone, camera access
|
||||
|
||||
```python
|
||||
# Automatic security headers middleware
|
||||
response.headers.setdefault('X-Content-Type-Options', 'nosniff')
|
||||
response.headers.setdefault('X-Frame-Options', 'DENY')
|
||||
response.headers.setdefault('Referrer-Policy', 'no-referrer')
|
||||
response.headers.setdefault('Permissions-Policy', 'geolocation=(), microphone=(), camera=()')
|
||||
```
|
||||
|
||||
## 📝 Audit Trail & Logging
|
||||
|
||||
### Request ID Propagation
|
||||
- **Auto-Generation**: UUID generated for each request if not provided
|
||||
- **Header Support**: Accepts `X-Request-ID`, `Request-ID`, or generates new
|
||||
- **Response Headers**: ID included in both `X-Request-ID` and `request_id` headers
|
||||
- **Log Correlation**: All logs tagged with request ID for tracing
|
||||
- **Middleware**: Automatic injection via `request_id_middleware`
|
||||
|
||||
### Audit Trail Logging
|
||||
- **Separate Log File**: `doorman-trail.log` for audit events
|
||||
- **Structured Logging**: JSON or plain text format (configurable)
|
||||
- **Event Tracking**:
|
||||
- User authentication (login, logout, token refresh)
|
||||
- Authorization changes (role/group modifications)
|
||||
- IP policy violations
|
||||
- Configuration changes
|
||||
- Security events
|
||||
- **Context Capture**: Username, action, target, status, details
|
||||
|
||||
### Log Redaction
|
||||
- **Sensitive Data Protection**: Automatic redaction of:
|
||||
- Authorization headers: `Authorization: Bearer [REDACTED]`
|
||||
- Access tokens: `access_token": "[REDACTED]"`
|
||||
- Refresh tokens: `refresh_token": "[REDACTED]"`
|
||||
- Passwords: `password": "[REDACTED]"`
|
||||
- Cookies: `Cookie: [REDACTED]`
|
||||
- CSRF tokens: `X-CSRF-Token: [REDACTED]`
|
||||
- **Pattern-Based**: Regex patterns match various formats
|
||||
- **Applied Universally**: File and console logs both protected
|
||||
|
||||
```python
|
||||
# Redaction filter (automatic)
|
||||
PATTERNS = [
|
||||
re.compile(r'(?i)(authorization\s*[:=]\s*)([^;\r\n]+)'),
|
||||
re.compile(r'(?i)(access[_-]?token\s*[\"\']?\s*[:=]\s*[\"\'])([^\"\']+)'),
|
||||
re.compile(r'(?i)(password\s*[\"\']?\s*[:=]\s*[\"\'])([^\"\']+)'),
|
||||
# ... additional patterns
|
||||
]
|
||||
```
|
||||
|
||||
## 📏 Request Validation & Limits
|
||||
|
||||
### Body Size Limits
|
||||
- **Default Limit**: 1MB (`MAX_BODY_SIZE_BYTES=1048576`)
|
||||
- **Universal Enforcement**: All request types protected (REST, SOAP, GraphQL, gRPC)
|
||||
- **Content-Length Based**: Efficient pre-validation before reading body
|
||||
- **Per-API Type Overrides**:
|
||||
- `MAX_BODY_SIZE_BYTES_REST` - REST API override
|
||||
- `MAX_BODY_SIZE_BYTES_SOAP` - SOAP/XML API override (e.g., 2MB for large SOAP envelopes)
|
||||
- `MAX_BODY_SIZE_BYTES_GRAPHQL` - GraphQL query override
|
||||
- `MAX_BODY_SIZE_BYTES_GRPC` - gRPC payload override
|
||||
- **Protected Paths**:
|
||||
- `/platform/authorization` - Prevents auth route DoS
|
||||
- `/api/rest/*` - REST APIs with custom limit
|
||||
- `/api/soap/*` - SOAP/XML APIs with custom limit
|
||||
- `/api/graphql/*` - GraphQL queries with custom limit
|
||||
- `/api/grpc/*` - gRPC payloads with custom limit
|
||||
- `/platform/*` - All platform routes
|
||||
- **Audit Logging**: Violations logged to audit trail with details
|
||||
- **Error Response**: 413 Payload Too Large with error code `REQ001`
|
||||
|
||||
```bash
|
||||
# Example: Larger limit for SOAP APIs with big XML envelopes
|
||||
MAX_BODY_SIZE_BYTES=1048576 # Default: 1MB
|
||||
MAX_BODY_SIZE_BYTES_SOAP=2097152 # SOAP: 2MB
|
||||
MAX_BODY_SIZE_BYTES_GRAPHQL=524288 # GraphQL: 512KB (queries are smaller)
|
||||
```
|
||||
|
||||
### Request Validation Flow
|
||||
- **Content-Length Check**: Validates header before reading body
|
||||
- **Early Rejection**: Prevents large payloads from consuming resources
|
||||
- **Type-Aware**: Different limits for different API types
|
||||
- **Security Audit**: All rejections logged with content type and path
|
||||
|
||||
```python
|
||||
# Body size validation
|
||||
MAX_BODY_SIZE = int(os.getenv('MAX_BODY_SIZE_BYTES', 1_048_576))
|
||||
if content_length and int(content_length) > MAX_BODY_SIZE:
|
||||
return ResponseModel(
|
||||
status_code=413,
|
||||
error_code='REQ001',
|
||||
error_message='Request entity too large'
|
||||
)
|
||||
```
|
||||
|
||||
## 🔑 Encryption & Secrets
|
||||
|
||||
### API Key Encryption
|
||||
- **At-Rest Encryption**: Optional encryption via `TOKEN_ENCRYPTION_KEY`
|
||||
- **Transparent**: Encrypt/decrypt on read/write operations
|
||||
- **Key Storage**: API keys can be encrypted in database/memory
|
||||
|
||||
### Memory Dump Encryption
|
||||
- **Required for Dumps**: `MEM_ENCRYPTION_KEY` must be set for memory dumps
|
||||
- **AES Encryption**: Secure encryption of serialized state
|
||||
- **Key Derivation**: Uses Fernet (symmetric encryption)
|
||||
- **Startup Restore**: Automatic decryption on server restart
|
||||
|
||||
```bash
|
||||
# Encryption configuration
|
||||
TOKEN_ENCRYPTION_KEY=your-api-key-encryption-secret-32chars+
|
||||
MEM_ENCRYPTION_KEY=your-memory-dump-encryption-secret-32chars+
|
||||
```
|
||||
|
||||
## 🛠️ Security Best Practices
|
||||
|
||||
### Production Checklist
|
||||
- [ ] `ENV=production` set
|
||||
- [ ] `HTTPS_ONLY=true` or `HTTPS_ENABLED=true` configured
|
||||
- [ ] Valid SSL certificates configured (`SSL_CERTFILE`, `SSL_KEYFILE`)
|
||||
- [ ] `JWT_SECRET_KEY` set to strong random value (change default!)
|
||||
- [ ] `MEM_ENCRYPTION_KEY` set to strong random value (32+ chars)
|
||||
- [ ] `ALLOWED_ORIGINS` configured (no wildcard `*`)
|
||||
- [ ] `CORS_STRICT=true` enforced
|
||||
- [ ] IP whitelist/blacklist configured if needed
|
||||
- [ ] `LOG_FORMAT=json` for structured logging
|
||||
- [ ] Regular security audits via `doorman-trail.log`
|
||||
|
||||
### Development vs Production
|
||||
| Feature | Development | Production |
|
||||
|---------|-------------|------------|
|
||||
| HTTPS Required | Optional | **Required** |
|
||||
| CSRF Validation | Optional | **Enabled** |
|
||||
| CORS | Permissive | **Strict** |
|
||||
| Cookie Secure Flag | No | **Yes** |
|
||||
| HSTS Header | No | **Yes** |
|
||||
| Log Format | Plain | **JSON** |
|
||||
|
||||
## 📚 References
|
||||
|
||||
- **OWASP Top 10**: Addresses A01 (Broken Access Control), A02 (Cryptographic Failures), A05 (Security Misconfiguration)
|
||||
- **CSP Level 3**: Content Security Policy implementation
|
||||
- **RFC 6797**: HTTP Strict Transport Security (HSTS)
|
||||
- **RFC 7519**: JSON Web Tokens (JWT)
|
||||
- **CSRF Double Submit**: Industry-standard CSRF protection pattern
|
||||
|
||||
## 🔍 Testing
|
||||
|
||||
Security features are covered by comprehensive test suites:
|
||||
- `tests/test_auth_csrf_https.py` - CSRF validation
|
||||
- `tests/test_production_https_guard.py` - HTTPS enforcement
|
||||
- `tests/test_ip_policy_allow_deny_cidr.py` - IP filtering
|
||||
- `tests/test_security.py` - General security features
|
||||
- `tests/test_request_id_and_logging_redaction.py` - Audit trail
|
||||
- 323 total tests, all passing ✅
|
||||
|
||||
For questions or security concerns, please review the code or open an issue.
|
||||
+90
-20
@@ -249,6 +249,19 @@ async def app_lifespan(app: FastAPI):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Close shared HTTP client pool if enabled
|
||||
try:
|
||||
from services.gateway_service import GatewayService as _GS
|
||||
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() != 'false':
|
||||
try:
|
||||
import asyncio as _asyncio
|
||||
if _asyncio.iscoroutinefunction(_GS.aclose_http_client):
|
||||
await _GS.aclose_http_client()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _generate_unique_id(route):
|
||||
try:
|
||||
name = getattr(route, 'name', 'op') or 'op'
|
||||
@@ -354,30 +367,87 @@ def _get_max_body_size() -> int:
|
||||
|
||||
@doorman.middleware('http')
|
||||
async def body_size_limit(request: Request, call_next):
|
||||
"""Enforce request body size limits to prevent DoS attacks.
|
||||
|
||||
Default limit: 1MB (configurable via MAX_BODY_SIZE_BYTES)
|
||||
Per-API overrides: MAX_BODY_SIZE_BYTES_<API_TYPE> (e.g., MAX_BODY_SIZE_BYTES_SOAP)
|
||||
|
||||
Protected paths:
|
||||
- /platform/authorization: Strict enforcement (prevent auth DoS)
|
||||
- /api/rest/*: Enforce on all requests with Content-Length
|
||||
- /api/soap/*: Enforce on XML/SOAP bodies
|
||||
- /api/graphql/*: Enforce on GraphQL queries
|
||||
- /api/grpc/*: Enforce on gRPC JSON payloads
|
||||
"""
|
||||
try:
|
||||
path = str(request.url.path)
|
||||
cl = request.headers.get('content-length')
|
||||
limit = _get_max_body_size()
|
||||
# Strictly enforce on auth route to prevent large bodies there
|
||||
if path.startswith('/platform/authorization'):
|
||||
if cl and int(cl) > limit:
|
||||
return process_response(ResponseModel(
|
||||
status_code=413,
|
||||
error_code='REQ001',
|
||||
error_message='Request entity too large'
|
||||
).dict(), 'rest')
|
||||
|
||||
# Skip requests without Content-Length header (GET, HEAD, etc.)
|
||||
if not cl or str(cl).strip() == '':
|
||||
return await call_next(request)
|
||||
# Enforce on gateway API traffic, but only for JSON payloads to
|
||||
# preserve existing tests that send raw bodies without CL/CT headers.
|
||||
if path.startswith('/api/'):
|
||||
ctype = (request.headers.get('content-type') or '').lower()
|
||||
if ctype.startswith('application/json'):
|
||||
if cl and int(cl) > limit:
|
||||
return process_response(ResponseModel(
|
||||
status_code=413,
|
||||
error_code='REQ001',
|
||||
error_message='Request entity too large'
|
||||
).dict(), 'rest')
|
||||
|
||||
try:
|
||||
content_length = int(cl)
|
||||
except (ValueError, TypeError):
|
||||
# Invalid Content-Length header - let it through and fail later
|
||||
return await call_next(request)
|
||||
|
||||
# Determine if this path should be protected
|
||||
should_enforce = False
|
||||
default_limit = _get_max_body_size()
|
||||
limit = default_limit
|
||||
|
||||
# Strictly enforce on auth route (prevent auth DoS)
|
||||
if path.startswith('/platform/authorization'):
|
||||
should_enforce = True
|
||||
# Enforce on all /api/* routes with per-type overrides
|
||||
elif path.startswith('/api/soap/'):
|
||||
should_enforce = True
|
||||
limit = int(os.getenv('MAX_BODY_SIZE_BYTES_SOAP', default_limit))
|
||||
elif path.startswith('/api/graphql/'):
|
||||
should_enforce = True
|
||||
limit = int(os.getenv('MAX_BODY_SIZE_BYTES_GRAPHQL', default_limit))
|
||||
elif path.startswith('/api/grpc/'):
|
||||
should_enforce = True
|
||||
limit = int(os.getenv('MAX_BODY_SIZE_BYTES_GRPC', default_limit))
|
||||
elif path.startswith('/api/rest/'):
|
||||
should_enforce = True
|
||||
limit = int(os.getenv('MAX_BODY_SIZE_BYTES_REST', default_limit))
|
||||
elif path.startswith('/api/'):
|
||||
# Catch-all for other /api/* routes
|
||||
should_enforce = True
|
||||
|
||||
# Skip if this path is not protected
|
||||
if not should_enforce:
|
||||
return await call_next(request)
|
||||
|
||||
# Enforce limit
|
||||
if content_length > limit:
|
||||
# Log for security monitoring
|
||||
try:
|
||||
from utils.audit_util import audit
|
||||
audit(
|
||||
request,
|
||||
actor=None,
|
||||
action='request.body_size_exceeded',
|
||||
target=path,
|
||||
status='blocked',
|
||||
details={
|
||||
'content_length': content_length,
|
||||
'limit': limit,
|
||||
'content_type': request.headers.get('content-type')
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return process_response(ResponseModel(
|
||||
status_code=413,
|
||||
error_code='REQ001',
|
||||
error_message=f'Request entity too large (max: {limit} bytes)'
|
||||
).dict(), 'rest')
|
||||
|
||||
return await call_next(request)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -17,7 +17,7 @@ from models.response_model import ResponseModel
|
||||
from services.user_service import UserService
|
||||
from utils.response_util import respond_rest
|
||||
from utils.auth_util import auth_required, create_access_token
|
||||
from utils.auth_blacklist import TimedHeap, jwt_blacklist, revoke_all_for_user, unrevoke_all_for_user, is_user_revoked
|
||||
from utils.auth_blacklist import TimedHeap, jwt_blacklist, revoke_all_for_user, unrevoke_all_for_user, is_user_revoked, add_revoked_jti
|
||||
from utils.role_util import platform_role_required_bool
|
||||
from utils.role_util import is_admin_user
|
||||
from models.update_user_model import UpdateUserModel
|
||||
@@ -688,9 +688,19 @@ async def authorization_invalidate(response: Response, request: Request):
|
||||
username = payload.get('sub')
|
||||
logger.info(f'{request_id} | Username: {username} | From: {request.client.host}:{request.client.port}')
|
||||
logger.info(f'{request_id} | Endpoint: {request.method} {str(request.url.path)}')
|
||||
if username not in jwt_blacklist:
|
||||
jwt_blacklist[username] = TimedHeap()
|
||||
jwt_blacklist[username].push(payload.get('jti'))
|
||||
# Add this token's JTI to durable revocation with TTL until expiry
|
||||
try:
|
||||
import time as _t
|
||||
exp = payload.get('exp')
|
||||
ttl = None
|
||||
if isinstance(exp, (int, float)):
|
||||
ttl = max(1, int(exp - _t.time()))
|
||||
add_revoked_jti(username, payload.get('jti'), ttl)
|
||||
except Exception:
|
||||
# Fallback to in-memory TimedHeap (back-compat)
|
||||
if username not in jwt_blacklist:
|
||||
jwt_blacklist[username] = TimedHeap()
|
||||
jwt_blacklist[username].push(payload.get('jti'))
|
||||
response = respond_rest(ResponseModel(
|
||||
status_code=200,
|
||||
response_headers={
|
||||
|
||||
@@ -57,13 +57,59 @@ class GatewayService:
|
||||
)
|
||||
_http_client: httpx.AsyncClient | None = None
|
||||
|
||||
@staticmethod
|
||||
def _build_limits() -> httpx.Limits:
|
||||
"""Pool limits tuned for small/medium projects with env overrides.
|
||||
|
||||
Defaults:
|
||||
- max_connections: 100 (total across hosts)
|
||||
- max_keepalive_connections: 50 (pooled, idle)
|
||||
- keepalive_expiry: 30s
|
||||
"""
|
||||
try:
|
||||
max_conns = int(os.getenv('HTTP_MAX_CONNECTIONS', 100))
|
||||
except Exception:
|
||||
max_conns = 100
|
||||
try:
|
||||
max_keep = int(os.getenv('HTTP_MAX_KEEPALIVE', 50))
|
||||
except Exception:
|
||||
max_keep = 50
|
||||
try:
|
||||
expiry = float(os.getenv('HTTP_KEEPALIVE_EXPIRY', 30.0))
|
||||
except Exception:
|
||||
expiry = 30.0
|
||||
return httpx.Limits(max_connections=max_conns, max_keepalive_connections=max_keep, keepalive_expiry=expiry)
|
||||
|
||||
@classmethod
|
||||
def get_http_client(cls) -> httpx.AsyncClient:
|
||||
if (os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() == 'true'):
|
||||
"""Return a pooled AsyncClient by default for connection reuse.
|
||||
|
||||
Set ENABLE_HTTPX_CLIENT_CACHE=false to disable pooling and create a
|
||||
fresh client per request.
|
||||
"""
|
||||
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() != 'false':
|
||||
if cls._http_client is None:
|
||||
cls._http_client = httpx.AsyncClient(timeout=cls.timeout)
|
||||
cls._http_client = httpx.AsyncClient(
|
||||
timeout=cls.timeout,
|
||||
limits=cls._build_limits(),
|
||||
http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true')
|
||||
)
|
||||
return cls._http_client
|
||||
return httpx.AsyncClient(timeout=cls.timeout)
|
||||
return httpx.AsyncClient(
|
||||
timeout=cls.timeout,
|
||||
limits=cls._build_limits(),
|
||||
http2=(os.getenv('HTTP_ENABLE_HTTP2', 'false').lower() == 'true')
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def aclose_http_client(cls) -> None:
|
||||
try:
|
||||
if cls._http_client is not None:
|
||||
await cls._http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
cls._http_client = None
|
||||
|
||||
def error_response(request_id, code, message, status=404):
|
||||
logger.error(f'{request_id} | REST gateway failed with code {code}')
|
||||
@@ -269,7 +315,7 @@ class GatewayService:
|
||||
else:
|
||||
return GatewayService.error_response(request_id, 'GTW004', 'Method not supported', status=405)
|
||||
finally:
|
||||
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() != 'true':
|
||||
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() == 'false':
|
||||
try:
|
||||
await client.aclose()
|
||||
except Exception:
|
||||
@@ -426,7 +472,7 @@ class GatewayService:
|
||||
try:
|
||||
http_response = await client.post(url, content=envelope, params=query_params, headers=headers)
|
||||
finally:
|
||||
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'false').lower() != 'true':
|
||||
if os.getenv('ENABLE_HTTPX_CLIENT_CACHE', 'true').lower() == 'false':
|
||||
try:
|
||||
await client.aclose()
|
||||
except Exception:
|
||||
|
||||
@@ -78,6 +78,23 @@ def event_loop():
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def reset_http_client():
|
||||
"""Reset the pooled httpx client between tests to prevent connection pool exhaustion."""
|
||||
# Reset before the test (important for tests that monkeypatch httpx.AsyncClient)
|
||||
try:
|
||||
from services.gateway_service import GatewayService
|
||||
await GatewayService.aclose_http_client()
|
||||
except Exception:
|
||||
pass
|
||||
yield
|
||||
# After each test, close and reset the pooled client
|
||||
try:
|
||||
from services.gateway_service import GatewayService
|
||||
await GatewayService.aclose_http_client()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Test helpers expected by some suites
|
||||
async def create_api(client: AsyncClient, api_name: str, api_version: str):
|
||||
payload = {
|
||||
|
||||
@@ -29,7 +29,7 @@ async def test_bandwidth_enforcement_and_usage_tracking(monkeypatch, authed_clie
|
||||
return json.loads(self.text)
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
@@ -80,7 +80,7 @@ async def test_monitor_tracks_bytes_in_out(monkeypatch, authed_client):
|
||||
return json.loads(self.text)
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None): pass
|
||||
def __init__(self, timeout=None, limits=None, http2=False): pass
|
||||
async def __aenter__(self): return self
|
||||
async def __aexit__(self, exc_type, exc, tb): return False
|
||||
async def post(self, url, data=None, json=None, headers=None, params=None): return _FakeHTTPResponse(200)
|
||||
|
||||
@@ -170,6 +170,9 @@ async def test_rest_payload_validation_allows_good_request(monkeypatch, authed_c
|
||||
return self._json
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -180,7 +183,7 @@ async def test_rest_payload_validation_allows_good_request(monkeypatch, authed_c
|
||||
return FakeResp()
|
||||
|
||||
import services.gateway_service as gw
|
||||
monkeypatch.setattr(gw.httpx, 'AsyncClient', lambda timeout: FakeClient())
|
||||
monkeypatch.setattr(gw.httpx, 'AsyncClient', FakeClient)
|
||||
|
||||
r = await authed_client.post(f'/api/rest/{api_name}/{version}/do', json={'user': {'name': 'Ab'}})
|
||||
assert r.status_code == 200
|
||||
@@ -209,6 +212,9 @@ async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_c
|
||||
self.text = '<ok/>'
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -219,7 +225,7 @@ async def test_soap_payload_validation_allows_good_request(monkeypatch, authed_c
|
||||
return FakeResp()
|
||||
|
||||
import services.gateway_service as gw
|
||||
monkeypatch.setattr(gw.httpx, 'AsyncClient', lambda timeout: FakeClient())
|
||||
monkeypatch.setattr(gw.httpx, 'AsyncClient', FakeClient)
|
||||
|
||||
envelope = (
|
||||
'<?xml version=\"1.0\" encoding=\"UTF-8\"?>'
|
||||
|
||||
@@ -18,7 +18,7 @@ async def test_metrics_range_parameters(monkeypatch, authed_client):
|
||||
self.content = b'{}'
|
||||
def json(self): return {'ok': True}
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None): pass
|
||||
def __init__(self, timeout=None, limits=None, http2=False): pass
|
||||
async def __aenter__(self): return self
|
||||
async def __aexit__(self, exc_type, exc, tb): return False
|
||||
async def get(self, url, params=None, headers=None): return _FakeHTTPResponse()
|
||||
|
||||
@@ -26,7 +26,7 @@ async def test_metrics_bytes_in_uses_content_length(monkeypatch, authed_client):
|
||||
return json.loads(self.text)
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None): pass
|
||||
def __init__(self, timeout=None, limits=None, http2=False): pass
|
||||
async def __aenter__(self): return self
|
||||
async def __aexit__(self, exc_type, exc, tb): return False
|
||||
async def post(self, url, data=None, json=None, headers=None, params=None): return _FakeHTTPResponse(200)
|
||||
|
||||
@@ -24,7 +24,7 @@ async def test_metrics_increment_on_gateway_requests(monkeypatch, authed_client)
|
||||
return {'ok': True}
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
@@ -67,7 +67,7 @@ async def test_metrics_top_apis_aggregate(monkeypatch, authed_client):
|
||||
def json(self): return {'ok': True}
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None): pass
|
||||
def __init__(self, timeout=None, limits=None, http2=False): pass
|
||||
async def __aenter__(self): return self
|
||||
async def __aexit__(self, exc_type, exc, tb): return False
|
||||
async def get(self, url, params=None, headers=None): return _FakeHTTPResponse(200)
|
||||
@@ -114,7 +114,7 @@ async def test_monitor_report_csv(monkeypatch, authed_client):
|
||||
def json(self): return {'ok': True}
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None): pass
|
||||
def __init__(self, timeout=None, limits=None, http2=False): pass
|
||||
async def __aenter__(self): return self
|
||||
async def __aexit__(self, exc_type, exc, tb): return False
|
||||
async def get(self, url, params=None, headers=None): return _FakeHTTPResponse()
|
||||
|
||||
@@ -16,7 +16,7 @@ class _Resp:
|
||||
|
||||
def _mk_client_capture(seen):
|
||||
class _Client:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -24,7 +24,7 @@ def _mk_retry_client(sequence, seen):
|
||||
counter = {'i': 0}
|
||||
|
||||
class _Client:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
|
||||
@@ -28,7 +28,7 @@ class _Resp:
|
||||
|
||||
def _mk_client_capture(seen, resp_status=200, resp_headers=None, resp_body=b'{"ok":true}'):
|
||||
class _Client:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -176,7 +176,7 @@ async def test_token_defs_and_deduction_on_gateway(monkeypatch, authed_client):
|
||||
return self._json_body
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
self._timeout = timeout
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -16,7 +16,7 @@ def _mk_retry_xml_client(sequence, seen):
|
||||
counter = {'i': 0}
|
||||
|
||||
class _Client:
|
||||
def __init__(self, timeout=None):
|
||||
def __init__(self, timeout=None, limits=None, http2=False):
|
||||
pass
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -1,24 +1,140 @@
|
||||
"""
|
||||
Durable token revocation utilities.
|
||||
|
||||
Behavior:
|
||||
- If a Redis backend is available (MEM_OR_EXTERNAL != 'MEM' and Redis connection
|
||||
succeeds), revocations are persisted in Redis so they survive restarts and are
|
||||
shared across processes/nodes.
|
||||
- Otherwise, fall back to in-memory structures compatible with previous behavior.
|
||||
|
||||
Public API kept backward-compatible for existing imports/tests:
|
||||
- `TimedHeap` (in-memory helper)
|
||||
- `jwt_blacklist` (in-memory map for fallback)
|
||||
- `revoke_all_for_user`, `unrevoke_all_for_user`, `is_user_revoked`
|
||||
- `purge_expired_tokens` (no-op when using Redis)
|
||||
|
||||
New helpers used by auth/routes:
|
||||
- `add_revoked_jti(username, jti, ttl_seconds)`
|
||||
- `is_jti_revoked(username, jti)`
|
||||
"""
|
||||
|
||||
# External imports
|
||||
from datetime import datetime, timedelta
|
||||
import heapq
|
||||
import os
|
||||
from typing import Optional
|
||||
import time
|
||||
|
||||
try:
|
||||
from utils.database import database, revocations_collection
|
||||
except Exception: # pragma: no cover
|
||||
database = None # type: ignore
|
||||
revocations_collection = None # type: ignore
|
||||
|
||||
try:
|
||||
import redis # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
redis = None # type: ignore
|
||||
|
||||
# In-memory fallback structures (legacy behavior)
|
||||
jwt_blacklist = {}
|
||||
revoked_all_users = set()
|
||||
|
||||
def revoke_all_for_user(username: str):
|
||||
# Module-level Redis client (sync) for durability
|
||||
_redis_client = None
|
||||
_redis_enabled = False
|
||||
|
||||
def _init_redis_if_possible():
|
||||
global _redis_client, _redis_enabled
|
||||
if _redis_client is not None:
|
||||
return
|
||||
try:
|
||||
revoked_all_users.add(username)
|
||||
# Honor unified MEM/REDIS flag (same as database/cache utils)
|
||||
flag = os.getenv('MEM_OR_EXTERNAL') or os.getenv('MEM_OR_REDIS', 'MEM')
|
||||
if str(flag).upper() == 'MEM':
|
||||
_redis_enabled = False
|
||||
_redis_client = None
|
||||
return
|
||||
if redis is None:
|
||||
_redis_enabled = False
|
||||
_redis_client = None
|
||||
return
|
||||
host = os.getenv('REDIS_HOST', 'localhost')
|
||||
port = int(os.getenv('REDIS_PORT', 6379))
|
||||
db = int(os.getenv('REDIS_DB', 0))
|
||||
pool = redis.ConnectionPool(host=host, port=port, db=db, decode_responses=True, max_connections=100)
|
||||
_redis_client = redis.StrictRedis(connection_pool=pool)
|
||||
# cheap ping to verify
|
||||
try:
|
||||
_redis_client.ping()
|
||||
_redis_enabled = True
|
||||
except Exception:
|
||||
_redis_client = None
|
||||
_redis_enabled = False
|
||||
except Exception:
|
||||
pass
|
||||
_redis_client = None
|
||||
_redis_enabled = False
|
||||
|
||||
def _revoked_jti_key(username: str, jti: str) -> str:
|
||||
return f'jwt:revoked:{username}:{jti}'
|
||||
|
||||
def _revoke_all_key(username: str) -> str:
|
||||
return f'jwt:revoke_all:{username}'
|
||||
|
||||
def revoke_all_for_user(username: str):
|
||||
"""Mark all tokens for a user as revoked (durable if Redis is enabled)."""
|
||||
_init_redis_if_possible()
|
||||
try:
|
||||
# Memory-only mode: persist flag into in-memory DB for dumping
|
||||
if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
|
||||
try:
|
||||
existing = revocations_collection.find_one({'type': 'revoke_all', 'username': username})
|
||||
if existing:
|
||||
revocations_collection.update_one({'_id': existing.get('_id')}, {'$set': {'revoke_all': True}})
|
||||
else:
|
||||
revocations_collection.insert_one({'type': 'revoke_all', 'username': username, 'revoke_all': True})
|
||||
except Exception:
|
||||
revoked_all_users.add(username)
|
||||
return
|
||||
if _redis_enabled and _redis_client is not None:
|
||||
_redis_client.set(_revoke_all_key(username), '1') # no TTL – admin will clear explicitly
|
||||
else:
|
||||
revoked_all_users.add(username)
|
||||
except Exception:
|
||||
revoked_all_users.add(username)
|
||||
|
||||
def unrevoke_all_for_user(username: str):
|
||||
"""Clear 'revoke all' for a user (durable if Redis is enabled)."""
|
||||
_init_redis_if_possible()
|
||||
try:
|
||||
revoked_all_users.discard(username)
|
||||
if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
|
||||
try:
|
||||
revocations_collection.delete_one({'type': 'revoke_all', 'username': username})
|
||||
except Exception:
|
||||
revoked_all_users.discard(username)
|
||||
return
|
||||
if _redis_enabled and _redis_client is not None:
|
||||
_redis_client.delete(_revoke_all_key(username))
|
||||
else:
|
||||
revoked_all_users.discard(username)
|
||||
except Exception:
|
||||
pass
|
||||
revoked_all_users.discard(username)
|
||||
|
||||
def is_user_revoked(username: str) -> bool:
|
||||
return username in revoked_all_users
|
||||
"""Return True if user is under 'revoke all' (durable check if Redis enabled)."""
|
||||
_init_redis_if_possible()
|
||||
try:
|
||||
if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
|
||||
try:
|
||||
doc = revocations_collection.find_one({'type': 'revoke_all', 'username': username})
|
||||
return bool(doc and doc.get('revoke_all'))
|
||||
except Exception:
|
||||
pass
|
||||
if _redis_enabled and _redis_client is not None:
|
||||
return bool(_redis_client.exists(_revoke_all_key(username)))
|
||||
return username in revoked_all_users
|
||||
except Exception:
|
||||
return username in revoked_all_users
|
||||
|
||||
class TimedHeap:
|
||||
def __init__(self, purge_after=timedelta(hours=1)):
|
||||
@@ -46,7 +162,97 @@ class TimedHeap:
|
||||
return self.heap[0][1]
|
||||
return None
|
||||
|
||||
def add_revoked_jti(username: str, jti: str, ttl_seconds: Optional[int] = None):
|
||||
"""Add a specific JTI to the revocation list.
|
||||
|
||||
- If Redis is enabled, store key with TTL so it auto-expires.
|
||||
- Otherwise push into in-memory TimedHeap (approximate via default purge window when ttl not provided).
|
||||
"""
|
||||
if not username or not jti:
|
||||
return
|
||||
_init_redis_if_possible()
|
||||
try:
|
||||
if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
|
||||
try:
|
||||
exp = int(time.time()) + (max(1, int(ttl_seconds)) if ttl_seconds is not None else 3600)
|
||||
existing = revocations_collection.find_one({'type': 'jti', 'username': username, 'jti': jti})
|
||||
if existing:
|
||||
revocations_collection.update_one({'_id': existing.get('_id')}, {'$set': {'expires_at': exp}})
|
||||
else:
|
||||
revocations_collection.insert_one({'type': 'jti', 'username': username, 'jti': jti, 'expires_at': exp})
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
if _redis_enabled and _redis_client is not None:
|
||||
ttl = max(1, int(ttl_seconds)) if ttl_seconds is not None else 3600
|
||||
_redis_client.setex(_revoked_jti_key(username, jti), ttl, '1')
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback to in-memory
|
||||
th = jwt_blacklist.get(username)
|
||||
if not th:
|
||||
th = TimedHeap()
|
||||
jwt_blacklist[username] = th
|
||||
th.push(jti)
|
||||
|
||||
def is_jti_revoked(username: str, jti: str) -> bool:
|
||||
"""Check whether a specific JTI is revoked (durable if Redis enabled)."""
|
||||
if not username or not jti:
|
||||
return False
|
||||
_init_redis_if_possible()
|
||||
try:
|
||||
if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
|
||||
try:
|
||||
doc = revocations_collection.find_one({'type': 'jti', 'username': username, 'jti': jti})
|
||||
if not doc:
|
||||
pass
|
||||
else:
|
||||
exp = int(doc.get('expires_at') or 0)
|
||||
now = int(time.time())
|
||||
if exp <= now:
|
||||
# expire eagerly
|
||||
revocations_collection.delete_one({'_id': doc.get('_id')})
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
if _redis_enabled and _redis_client is not None:
|
||||
return bool(_redis_client.exists(_revoked_jti_key(username, jti)))
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback check in-memory
|
||||
th = jwt_blacklist.get(username)
|
||||
if not th:
|
||||
return False
|
||||
th.purge()
|
||||
for _, token_jti in list(th.heap):
|
||||
if token_jti == jti:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def purge_expired_tokens():
|
||||
"""No-op when Redis-backed; purge DB/in-memory when memory-only."""
|
||||
_init_redis_if_possible()
|
||||
if _redis_enabled:
|
||||
return
|
||||
# Purge memory-only DB entries
|
||||
try:
|
||||
if database is not None and getattr(database, 'memory_only', False) and revocations_collection is not None:
|
||||
now = int(time.time())
|
||||
# remove all expired jti docs
|
||||
to_delete = []
|
||||
for d in revocations_collection.find({'type': 'jti'}):
|
||||
try:
|
||||
if int(d.get('expires_at') or 0) <= now:
|
||||
to_delete.append(d)
|
||||
except Exception:
|
||||
to_delete.append(d)
|
||||
for d in to_delete:
|
||||
revocations_collection.delete_one({'_id': d.get('_id')})
|
||||
except Exception:
|
||||
pass
|
||||
# Purge in-memory fallback heaps
|
||||
for key, timed_heap in list(jwt_blacklist.items()):
|
||||
timed_heap.purge()
|
||||
if not timed_heap.heap:
|
||||
|
||||
@@ -18,7 +18,7 @@ import uuid
|
||||
from fastapi import HTTPException, Request
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from utils.auth_blacklist import jwt_blacklist, is_user_revoked
|
||||
from utils.auth_blacklist import is_user_revoked, is_jti_revoked
|
||||
from utils.database import user_collection, role_collection
|
||||
from utils.doorman_cache_util import doorman_cache
|
||||
|
||||
@@ -96,13 +96,8 @@ async def auth_required(request: Request):
|
||||
jti = payload.get('jti')
|
||||
if not username or not jti:
|
||||
raise HTTPException(status_code=401, detail='Invalid token')
|
||||
if is_user_revoked(username):
|
||||
if is_user_revoked(username) or is_jti_revoked(username, jti):
|
||||
raise HTTPException(status_code=401, detail='Token has been revoked')
|
||||
if username in jwt_blacklist:
|
||||
timed_heap = jwt_blacklist[username]
|
||||
for _, token_jti in timed_heap.heap:
|
||||
if token_jti == jti:
|
||||
raise HTTPException(status_code=401, detail='Token has been revoked')
|
||||
user = doorman_cache.get_cache('user_cache', username)
|
||||
if not user:
|
||||
user = user_collection.find_one({'username': username})
|
||||
|
||||
@@ -135,7 +135,7 @@ class Database:
|
||||
pass
|
||||
print('Memory-only mode: Core data initialized (admin user/role/groups)')
|
||||
return
|
||||
collections = ['users', 'apis', 'endpoints', 'groups', 'roles', 'subscriptions', 'routings', 'credit_defs', 'user_credits', 'endpoint_validations', 'settings']
|
||||
collections = ['users', 'apis', 'endpoints', 'groups', 'roles', 'subscriptions', 'routings', 'credit_defs', 'user_credits', 'endpoint_validations', 'settings', 'revocations']
|
||||
for collection in collections:
|
||||
if collection not in self.db.list_collection_names():
|
||||
self.db_existed = False
|
||||
@@ -399,12 +399,14 @@ class InMemoryDB:
|
||||
self.user_credits = InMemoryCollection('user_credits')
|
||||
self.endpoint_validations = InMemoryCollection('endpoint_validations')
|
||||
self.settings = InMemoryCollection('settings')
|
||||
# New durable in-memory store for token revocations
|
||||
self.revocations = InMemoryCollection('revocations')
|
||||
|
||||
def list_collection_names(self):
|
||||
return [
|
||||
'users', 'apis', 'endpoints', 'groups', 'roles',
|
||||
'subscriptions', 'routings', 'credit_defs', 'user_credits',
|
||||
'endpoint_validations', 'settings'
|
||||
'endpoint_validations', 'settings', 'revocations'
|
||||
]
|
||||
|
||||
def create_collection(self, name):
|
||||
@@ -431,6 +433,7 @@ class InMemoryDB:
|
||||
'user_credits': coll_docs(self.user_credits),
|
||||
'endpoint_validations': coll_docs(self.endpoint_validations),
|
||||
'settings': coll_docs(self.settings),
|
||||
'revocations': coll_docs(self.revocations),
|
||||
}
|
||||
|
||||
def load_data(self, data: dict):
|
||||
@@ -448,6 +451,7 @@ class InMemoryDB:
|
||||
load_coll(self.user_credits, data.get('user_credits', []))
|
||||
load_coll(self.endpoint_validations, data.get('endpoint_validations', []))
|
||||
load_coll(self.settings, data.get('settings', []))
|
||||
load_coll(self.revocations, data.get('revocations', []))
|
||||
|
||||
database = Database()
|
||||
database.initialize_collections()
|
||||
@@ -466,6 +470,7 @@ if database.memory_only:
|
||||
credit_def_collection = db.credit_defs
|
||||
user_credit_collection = db.user_credits
|
||||
endpoint_validation_collection = db.endpoint_validations
|
||||
revocations_collection = db.revocations
|
||||
else:
|
||||
db = database.db
|
||||
mongodb_client = database.client
|
||||
@@ -479,3 +484,7 @@ else:
|
||||
credit_def_collection = db.credit_defs
|
||||
user_credit_collection = db.user_credits
|
||||
endpoint_validation_collection = db.endpoint_validations
|
||||
try:
|
||||
revocations_collection = db.revocations
|
||||
except Exception:
|
||||
revocations_collection = None
|
||||
|
||||
Reference in New Issue
Block a user