mirror of
https://github.com/apidoorman/doorman.git
synced 2026-05-01 05:39:58 -05:00
bug fixes
This commit is contained in:
@@ -14,6 +14,7 @@ import logging
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
# Internal imports
|
||||
@@ -212,7 +213,7 @@ async def upload_proto_file(api_name: str, api_version: str, file: UploadFile =
|
||||
proto_path.write_text(proto_content)
|
||||
try:
|
||||
subprocess.run([
|
||||
'python', '-m', 'grpc_tools.protoc',
|
||||
sys.executable, '-m', 'grpc_tools.protoc',
|
||||
f'--proto_path={proto_path.parent}',
|
||||
f'--python_out={generated_dir}',
|
||||
f'--grpc_python_out={generated_dir}',
|
||||
@@ -408,7 +409,7 @@ async def update_proto_file(api_name: str, api_version: str, request: Request, p
|
||||
proto_path.write_text(proto_content)
|
||||
try:
|
||||
subprocess.run([
|
||||
'python', '-m', 'grpc_tools.protoc',
|
||||
sys.executable, '-m', 'grpc_tools.protoc',
|
||||
f'--proto_path={proto_path.parent}',
|
||||
f'--python_out={generated_dir}',
|
||||
f'--grpc_python_out={generated_dir}',
|
||||
|
||||
@@ -14,8 +14,9 @@ os.environ.setdefault('JWT_SECRET_KEY', 'test-secret-key')
|
||||
os.environ.setdefault('STARTUP_ADMIN_EMAIL', 'admin@doorman.dev')
|
||||
os.environ.setdefault('STARTUP_ADMIN_PASSWORD', 'password1')
|
||||
os.environ.setdefault('COOKIE_DOMAIN', 'testserver')
|
||||
os.environ.setdefault('LOGIN_IP_RATE_LIMIT', '1000') # High limit for tests
|
||||
os.environ.setdefault('LOGIN_IP_RATE_WINDOW', '60') # 1000 requests per minute for tests
|
||||
os.environ.setdefault('LOGIN_IP_RATE_LIMIT', '1000000')
|
||||
os.environ.setdefault('LOGIN_IP_RATE_WINDOW', '60')
|
||||
os.environ.setdefault('LOGIN_IP_RATE_DISABLED', 'true')
|
||||
|
||||
_HERE = os.path.dirname(__file__)
|
||||
_PROJECT_ROOT = os.path.abspath(os.path.join(_HERE, os.pardir))
|
||||
@@ -26,6 +27,13 @@ import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from utils.database import database as _db
|
||||
_INITIAL_DB_SNAPSHOT: Optional[dict] = _db.db.dump_data() if getattr(_db, 'memory_only', True) else None
|
||||
except Exception:
|
||||
_INITIAL_DB_SNAPSHOT = None
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def authed_client():
|
||||
@@ -105,6 +113,33 @@ async def reset_http_client():
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True, scope='module')
|
||||
async def reset_in_memory_db_state():
|
||||
"""Restore in-memory DB and caches before each test to ensure isolation.
|
||||
|
||||
Prevents prior tests (e.g., password changes, user revocations, settings tweaks)
|
||||
from affecting later ones.
|
||||
"""
|
||||
try:
|
||||
if _INITIAL_DB_SNAPSHOT is not None:
|
||||
from utils.database import database as _db
|
||||
_db.db.load_data(_INITIAL_DB_SNAPSHOT)
|
||||
try:
|
||||
from utils.database import user_collection
|
||||
from utils import password_util as _pw
|
||||
pwd = os.environ.get('STARTUP_ADMIN_PASSWORD') or 'password1'
|
||||
user_collection.update_one({'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from utils.doorman_cache_util import doorman_cache
|
||||
doorman_cache.clear_all()
|
||||
except Exception:
|
||||
pass
|
||||
yield
|
||||
|
||||
# Test helpers expected by some suites
|
||||
async def create_api(client: AsyncClient, api_name: str, api_version: str):
|
||||
payload = {
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import Request, HTTPException
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Internal imports
|
||||
from utils.auth_util import auth_required
|
||||
@@ -141,11 +142,17 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
|
||||
await limit_by_ip(request, limit=5, window=300)
|
||||
"""
|
||||
try:
|
||||
# Get client IP (trust X-Forwarded-For if from trusted proxy)
|
||||
if os.getenv('LOGIN_IP_RATE_DISABLED', 'false').lower() == 'true':
|
||||
now = int(time.time())
|
||||
return {
|
||||
'limit': limit,
|
||||
'remaining': limit,
|
||||
'reset': now + window,
|
||||
'window': window
|
||||
}
|
||||
client_ip = _get_client_ip(request, trust_xff=True)
|
||||
if not client_ip:
|
||||
logger.warning('Unable to determine client IP for rate limiting, allowing request')
|
||||
# Return default headers even if IP detection fails
|
||||
return {
|
||||
'limit': limit,
|
||||
'remaining': limit,
|
||||
@@ -153,33 +160,27 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
|
||||
'window': window
|
||||
}
|
||||
|
||||
# Create time-bucketed key (changes every window seconds)
|
||||
now = int(time.time())
|
||||
bucket = now // window
|
||||
key = f'ip_rate_limit:{client_ip}:{bucket}'
|
||||
|
||||
# Get Redis client or fallback to in-memory
|
||||
redis_client = getattr(request.app.state, 'redis', None)
|
||||
client = redis_client or _fallback_counter
|
||||
|
||||
# Increment counter
|
||||
try:
|
||||
count = await client.incr(key)
|
||||
if count == 1:
|
||||
await client.expire(key, window)
|
||||
except Exception as e:
|
||||
# Fallback to in-memory counter if Redis fails
|
||||
logger.warning(f'Redis failure in IP rate limiting, using fallback: {str(e)}')
|
||||
count = await _fallback_counter.incr(key)
|
||||
if count == 1:
|
||||
await _fallback_counter.expire(key, window)
|
||||
|
||||
# Calculate rate limit headers
|
||||
remaining = max(0, limit - count)
|
||||
reset_time = (bucket + 1) * window # Start of next window
|
||||
reset_time = (bucket + 1) * window
|
||||
retry_after = window - (now % window)
|
||||
|
||||
# Store rate limit info in request state for middleware
|
||||
rate_limit_info = {
|
||||
'limit': limit,
|
||||
'remaining': remaining,
|
||||
@@ -187,7 +188,6 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
|
||||
'window': window
|
||||
}
|
||||
|
||||
# Check if limit exceeded
|
||||
if count > limit:
|
||||
logger.warning(f'IP rate limit exceeded for {client_ip}: {count}/{limit} in {window}s')
|
||||
raise HTTPException(
|
||||
@@ -205,7 +205,6 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
|
||||
}
|
||||
)
|
||||
|
||||
# Log if approaching limit (80% threshold)
|
||||
if count > (limit * 0.8):
|
||||
logger.info(f'IP {client_ip} approaching rate limit: {count}/{limit}')
|
||||
|
||||
@@ -214,7 +213,6 @@ async def limit_by_ip(request: Request, limit: int = 10, window: int = 60):
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Don't block requests if rate limiting fails
|
||||
logger.error(f'IP rate limiting error: {str(e)}', exc_info=True)
|
||||
return {
|
||||
'limit': limit,
|
||||
|
||||
@@ -214,6 +214,16 @@ def restore_memory_from_file(path: Optional[str] = None) -> dict:
|
||||
|
||||
data = _from_jsonable(payload.get('data', {}))
|
||||
database.db.load_data(data)
|
||||
try:
|
||||
from utils.database import user_collection
|
||||
from utils import password_util as _pw
|
||||
import os as _os
|
||||
admin = user_collection.find_one({'username': 'admin'})
|
||||
if admin is not None and not isinstance(admin.get('password'), (bytes, bytearray)):
|
||||
pwd = _os.getenv('STARTUP_ADMIN_PASSWORD') or 'password1'
|
||||
user_collection.update_one({'username': 'admin'}, {'$set': {'password': _pw.hash_password(pwd)}})
|
||||
except Exception:
|
||||
pass
|
||||
return {'version': payload.get('version', 1), 'created_at': payload.get('created_at')}
|
||||
|
||||
def find_latest_dump_path(path_hint: Optional[str] = None) -> Optional[str]:
|
||||
|
||||
+2
-2
@@ -14,8 +14,9 @@ os.environ.setdefault('JWT_SECRET_KEY', 'test-secret-key')
|
||||
os.environ.setdefault('STARTUP_ADMIN_EMAIL', 'admin@doorman.dev')
|
||||
os.environ.setdefault('STARTUP_ADMIN_PASSWORD', 'password1')
|
||||
os.environ.setdefault('COOKIE_DOMAIN', 'testserver')
|
||||
os.environ.setdefault('LOGIN_IP_RATE_LIMIT', '1000000')
|
||||
os.environ.setdefault('LOGIN_IP_RATE_WINDOW', '60')
|
||||
|
||||
# Ensure backend-services is on sys.path for imports like `from doorman import doorman`
|
||||
_HERE = os.path.dirname(__file__)
|
||||
_BACKEND_DIR = os.path.abspath(os.path.join(_HERE, os.pardir, 'backend-services'))
|
||||
if _BACKEND_DIR not in sys.path:
|
||||
@@ -79,4 +80,3 @@ async def authed_client():
|
||||
except Exception:
|
||||
pass
|
||||
return client
|
||||
|
||||
|
||||
Reference in New Issue
Block a user