mirror of
https://github.com/apidoorman/doorman.git
synced 2026-05-04 07:09:53 -05:00
100 lines
3.8 KiB
Python
100 lines
3.8 KiB
Python
from fastapi import Request, HTTPException
|
|
from utils.auth_util import auth_required
|
|
from utils.database import user_collection
|
|
from utils.doorman_cache_util import doorman_cache
|
|
import asyncio
|
|
import time
|
|
|
|
|
|
class InMemoryWindowCounter:
    """Process-local counter that mimics the Redis ``INCR``/``EXPIRE`` pair.

    Not distributed: counts live only in this process. Used as the fallback
    when no Redis client is configured or when Redis calls fail.

    Expired entries are swept lazily from ``incr`` (at most once every
    ``sweep_interval`` seconds). Callers build keys that are never reused once
    their window has passed, so without the sweep the store would grow forever.
    """

    def __init__(self, sweep_interval: int = 60):
        # key -> {'count': int, 'expires_at': int epoch seconds}
        self._store = {}
        # Minimum number of seconds between sweeps of expired entries.
        self._sweep_interval = sweep_interval
        self._last_sweep = int(time.time())

    def _purge_expired(self, now: int) -> None:
        """Drop every entry whose expiry is at or before *now*."""
        stale = [k for k, e in self._store.items() if e['expires_at'] <= now]
        for k in stale:
            del self._store[k]
        self._last_sweep = now

    async def incr(self, key: str) -> int:
        """Increment *key* and return its new count.

        A missing or expired key restarts at 1 with a short 1-second TTL;
        callers are expected to extend it via :meth:`expire`.
        """
        now = int(time.time())
        if now - self._last_sweep >= self._sweep_interval:
            self._purge_expired(now)
        entry = self._store.get(key)
        if entry and entry['expires_at'] > now:
            entry['count'] += 1
        else:
            # Set a short default TTL; caller should extend via expire()
            entry = {'count': 1, 'expires_at': now + 1}
            self._store[key] = entry
        return entry['count']

    async def expire(self, key: str, ttl_seconds: int) -> None:
        """Make *key* expire ``ttl_seconds`` from now; no-op if *key* is absent."""
        entry = self._store.get(key)
        if entry:
            entry['expires_at'] = int(time.time()) + int(ttl_seconds)
|
|
|
|
|
|
# Module-wide, process-local counter used by limit_and_throttle when
# app.state.redis is not set or when a Redis call raises.
_fallback_counter = InMemoryWindowCounter()
|
|
|
|
def duration_to_seconds(duration: str) -> int:
    """Convert a duration unit name to its length in seconds.

    Recognised units: ``second``, ``minute``, ``hour``, ``day``, ``week``,
    ``month`` (30 days) and ``year`` (365 days). Matching is case-insensitive,
    tolerates surrounding whitespace and an optional plural ``s``.
    Empty/``None`` input or an unrecognised unit falls back to 60 (one minute).
    """
    mapping = {
        "second": 1,
        "minute": 60,
        "hour": 3600,
        "day": 86400,
        "week": 604800,
        "month": 2592000,
        "year": 31536000,
    }
    if not duration:
        return 60
    # Normalise before stripping the plural so "HOURS" and "Days" match too.
    unit = duration.strip().lower()
    if unit.endswith("s"):
        unit = unit[:-1]
    return mapping.get(unit, 60)
|
|
|
|
async def _count_hit(redis_client, key: str, window: int) -> int:
    """Increment the fixed-window counter *key* and return its new value.

    Uses *redis_client* when one is configured, otherwise the process-local
    ``_fallback_counter``. The TTL is set to *window* seconds on the first hit.
    Any error from the primary client falls back to the in-memory counter, so a
    Redis outage degrades to per-process limiting instead of failing the request.
    """
    client = redis_client or _fallback_counter
    try:
        count = await client.incr(key)
        if count == 1:
            await client.expire(key, window)
        return count
    except Exception:
        # Fallback to in-memory counter on any client error
        count = await _fallback_counter.incr(key)
        if count == 1:
            await _fallback_counter.expire(key, window)
        return count


async def limit_and_throttle(request: Request):
    """Enforce the authenticated user's rate-limit and throttle settings.

    Both limits use fixed time windows keyed by username and window index.
    Requests over the rate limit or the throttle queue limit are rejected;
    requests over the throttle limit (but within the queue limit) are delayed.

    Raises:
        HTTPException: 401 if the token's user no longer exists;
            429 when the rate limit or throttle queue limit is exceeded.
    """
    payload = await auth_required(request)
    username = payload.get("sub")
    redis_client = getattr(request.app.state, 'redis', None)
    user = doorman_cache.get_cache("user_cache", username)
    if not user:
        user = await user_collection.find_one({"username": username})
    if not user:
        # The token names a user that does not exist; fail cleanly instead of
        # crashing on user.get() below with a 500.
        raise HTTPException(status_code=401, detail="User not found")
    now = int(time.time())

    # Rate limit. Per the stored schema, "rate_limit_duration" holds the
    # allowed request count and "rate_limit_duration_type" the window unit.
    rate = int(user.get("rate_limit_duration") or 1)
    window = duration_to_seconds(user.get("rate_limit_duration_type", "minute"))
    count = await _count_hit(
        redis_client, f"rate_limit:{username}:{now // window}", window
    )
    if count > rate:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    # Throttle. "throttle_duration" holds the request count per throttle window.
    throttle_limit = int(user.get("throttle_duration") or 5)
    throttle_window = duration_to_seconds(
        user.get("throttle_duration_type", "second")
    )
    throttle_count = await _count_hit(
        redis_client,
        f"throttle_limit:{username}:{now // throttle_window}",
        throttle_window,
    )
    throttle_queue_limit = int(user.get("throttle_queue_limit") or 10)
    if throttle_count > throttle_queue_limit:
        raise HTTPException(status_code=429, detail="Throttle queue limit exceeded")
    if throttle_count > throttle_limit:
        throttle_wait = float(user.get("throttle_wait_duration", 0.5) or 0.5)
        throttle_wait_duration = user.get("throttle_wait_duration_type", "second")
        if throttle_wait_duration != "second":
            throttle_wait *= duration_to_seconds(throttle_wait_duration)
        # Linear backoff: each request beyond the limit in this window waits longer.
        dynamic_wait = throttle_wait * (throttle_count - throttle_limit)
        await asyncio.sleep(dynamic_wait)
|