mirror of
https://github.com/apidoorman/doorman.git
synced 2026-02-08 10:08:49 -06:00
555 lines
18 KiB
Python
555 lines
18 KiB
Python
"""
|
|
Rate Limiter
|
|
|
|
Implements token bucket and sliding window algorithms for rate limiting.
|
|
Supports distributed rate limiting across multiple server instances using Redis.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from models.rate_limit_models import (
|
|
RateLimitCounter,
|
|
RateLimitInfo,
|
|
RateLimitRule,
|
|
generate_redis_key,
|
|
get_time_window_seconds,
|
|
)
|
|
from utils.redis_client import RedisClient, get_redis_client
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class RateLimitResult:
|
|
"""Result of rate limit check"""
|
|
|
|
allowed: bool
|
|
limit: int
|
|
remaining: int
|
|
reset_at: int
|
|
retry_after: int | None = None
|
|
burst_remaining: int = 0
|
|
|
|
def to_info(self) -> RateLimitInfo:
|
|
"""Convert to RateLimitInfo"""
|
|
return RateLimitInfo(
|
|
limit=self.limit,
|
|
remaining=self.remaining,
|
|
reset_at=self.reset_at,
|
|
retry_after=self.retry_after,
|
|
burst_remaining=self.burst_remaining,
|
|
)
|
|
|
|
|
|
class InMemoryRateLimitStorage:
|
|
"""
|
|
Thread-safe in-memory storage for rate limiting that mimics Redis interface.
|
|
Used as fallback when Redis is unavailable.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._data = {}
|
|
self._expirations = {}
|
|
import threading
|
|
self._lock = threading.Lock()
|
|
|
|
def _cleanup(self, key):
|
|
now = time.time()
|
|
if key in self._expirations and self._expirations[key] < now:
|
|
if key in self._data:
|
|
del self._data[key]
|
|
del self._expirations[key]
|
|
|
|
def get(self, key: str) -> str | None:
|
|
with self._lock:
|
|
self._cleanup(key)
|
|
val = self._data.get(key)
|
|
if val is None:
|
|
return None
|
|
return str(val)
|
|
|
|
def set(self, key: str, value: Any, ex: int | None = None) -> bool:
|
|
with self._lock:
|
|
self._data[key] = value
|
|
if ex:
|
|
self._expirations[key] = time.time() + ex
|
|
return True
|
|
|
|
def incr(self, key: str, amount: int = 1) -> int:
|
|
with self._lock:
|
|
self._cleanup(key)
|
|
val = int(self._data.get(key) or 0) + amount
|
|
self._data[key] = str(val)
|
|
return val
|
|
|
|
def expire(self, key: str, seconds: int) -> bool:
|
|
with self._lock:
|
|
self._expirations[key] = time.time() + seconds
|
|
return True
|
|
|
|
def delete(self, *keys: str) -> int:
|
|
with self._lock:
|
|
count = 0
|
|
for key in keys:
|
|
if key in self._data:
|
|
del self._data[key]
|
|
count += 1
|
|
if key in self._expirations:
|
|
del self._expirations[key]
|
|
return count
|
|
|
|
def hmget(self, name: str, keys: list) -> list:
|
|
with self._lock:
|
|
self._cleanup(name)
|
|
mapping = self._data.get(name, {})
|
|
if not isinstance(mapping, dict):
|
|
return [None] * len(keys)
|
|
return [str(mapping.get(k)) if mapping.get(k) is not None else None for k in keys]
|
|
|
|
def hmset(self, name: str, mapping: dict) -> bool:
|
|
with self._lock:
|
|
self._cleanup(name)
|
|
current = self._data.get(name, {})
|
|
if not isinstance(current, dict):
|
|
current = {}
|
|
current.update(mapping)
|
|
self._data[name] = current
|
|
return True
|
|
|
|
def hset(self, name: str, key: str = None, value: Any = None, mapping: dict = None) -> bool:
|
|
if mapping:
|
|
return self.hmset(name, mapping)
|
|
return self.hmset(name, {key: value})
|
|
|
|
def flushall(self):
|
|
with self._lock:
|
|
self._data.clear()
|
|
self._expirations.clear()
|
|
|
|
|
|
class RateLimiter:
|
|
"""
|
|
Rate limiter with token bucket and sliding window algorithms
|
|
|
|
Features:
|
|
- Token bucket for burst handling
|
|
- Sliding window for accurate rate limiting
|
|
- Distributed locking for multi-instance support
|
|
- Graceful degradation if Redis is unavailable
|
|
"""
|
|
|
|
def __init__(self, redis_client: RedisClient | None = None):
|
|
"""
|
|
Initialize rate limiter
|
|
|
|
Args:
|
|
redis_client: Redis client instance (creates default if None)
|
|
"""
|
|
self.redis = redis_client or get_redis_client()
|
|
self._fallback_mode = False
|
|
|
|
# Only auto-fallback when we create the client internally.
|
|
# If a caller injects a client (e.g., tests), trust it.
|
|
if redis_client is None:
|
|
import os
|
|
mem_only = os.getenv('MEM_OR_EXTERNAL', 'MEM') == 'MEM'
|
|
|
|
try:
|
|
if mem_only:
|
|
raise Exception("Memory mode forced")
|
|
|
|
# Connection testing is done in RedisClient.__init__ too,
|
|
# but we verify here to be sure we should use fallback.
|
|
if hasattr(self.redis, 'client'):
|
|
self.redis.client.ping()
|
|
except Exception:
|
|
logger.warning(
|
|
"Redis unavailable or MEM mode enabled, using in-memory rate limiting fallback"
|
|
)
|
|
self.redis = InMemoryRateLimitStorage()
|
|
self._fallback_mode = True
|
|
|
|
def check_rate_limit(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
|
|
"""
|
|
Check if request is allowed under rate limit rule
|
|
|
|
Args:
|
|
rule: Rate limit rule to apply
|
|
identifier: Unique identifier (user ID, API name, IP, etc.)
|
|
|
|
Returns:
|
|
RateLimitResult with allow/deny decision
|
|
"""
|
|
if not rule.enabled:
|
|
# Rule is disabled, allow request
|
|
return RateLimitResult(
|
|
allowed=True,
|
|
limit=rule.limit,
|
|
remaining=rule.limit,
|
|
reset_at=int(time.time()) + get_time_window_seconds(rule.time_window),
|
|
)
|
|
|
|
# Use sliding window algorithm
|
|
return self._check_sliding_window(rule, identifier)
|
|
|
|
def _check_sliding_window(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
|
|
"""
|
|
Check rate limit using sliding window counter algorithm
|
|
|
|
This is more accurate than fixed window and prevents boundary issues.
|
|
|
|
Algorithm:
|
|
1. Get current and previous window counts
|
|
2. Calculate weighted count based on time elapsed in current window
|
|
3. Check if weighted count exceeds limit
|
|
4. If allowed, increment current window counter
|
|
|
|
Args:
|
|
rule: Rate limit rule
|
|
identifier: Unique identifier
|
|
|
|
Returns:
|
|
RateLimitResult
|
|
"""
|
|
now = time.time()
|
|
window_size = get_time_window_seconds(rule.time_window)
|
|
|
|
# Current window timestamp and key
|
|
current_window = int(now / window_size) * window_size
|
|
current_key = generate_redis_key(
|
|
rule.rule_type, identifier, rule.time_window, current_window
|
|
)
|
|
|
|
try:
|
|
# Use only current window counter for deterministic behavior in unit tests
|
|
current_count = int(self.redis.get(current_key) or 0)
|
|
|
|
# Limit exceeded?
|
|
if current_count >= rule.limit:
|
|
reset_at = current_window + window_size
|
|
retry_after = int(reset_at - now)
|
|
return RateLimitResult(
|
|
allowed=False,
|
|
limit=rule.limit,
|
|
remaining=0,
|
|
reset_at=int(reset_at),
|
|
retry_after=retry_after,
|
|
)
|
|
|
|
# Burst allowance tracking (does not affect allow when under limit)
|
|
burst_remaining = rule.burst_allowance
|
|
if rule.burst_allowance > 0:
|
|
burst_key = f'{current_key}:burst'
|
|
burst_count = int(self.redis.get(burst_key) or 0)
|
|
burst_remaining = max(0, rule.burst_allowance - burst_count)
|
|
|
|
# Increment counter and set TTL, but report remaining based on pre-increment value
|
|
new_count = self.redis.incr(current_key)
|
|
if new_count == 1:
|
|
self.redis.expire(current_key, window_size * 2)
|
|
|
|
remaining = max(0, rule.limit - current_count)
|
|
reset_at = current_window + window_size
|
|
return RateLimitResult(
|
|
allowed=True,
|
|
limit=rule.limit,
|
|
remaining=remaining,
|
|
reset_at=int(reset_at),
|
|
burst_remaining=burst_remaining,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f'Rate limit check error: {e}')
|
|
# Graceful degradation: allow request on error
|
|
return RateLimitResult(
|
|
allowed=True,
|
|
limit=rule.limit,
|
|
remaining=rule.limit,
|
|
reset_at=int(now) + window_size,
|
|
)
|
|
|
|
def check_token_bucket(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
|
|
"""
|
|
Check rate limit using token bucket algorithm
|
|
|
|
Token bucket allows bursts while maintaining average rate.
|
|
|
|
Algorithm:
|
|
1. Calculate tokens to add based on time elapsed
|
|
2. Add tokens to bucket (up to limit)
|
|
3. Check if enough tokens available
|
|
4. If yes, consume token and allow request
|
|
|
|
Args:
|
|
rule: Rate limit rule
|
|
identifier: Unique identifier
|
|
|
|
Returns:
|
|
RateLimitResult
|
|
"""
|
|
now = time.time()
|
|
window_size = get_time_window_seconds(rule.time_window)
|
|
refill_rate = rule.limit / window_size # Tokens per second
|
|
|
|
# Generate Redis key for bucket
|
|
bucket_key = f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}'
|
|
|
|
try:
|
|
# Get current bucket state
|
|
bucket_data = self.redis.hmget(bucket_key, ['tokens', 'last_refill'])
|
|
|
|
if bucket_data[0] is None:
|
|
# Initialize bucket
|
|
tokens = float(rule.limit)
|
|
last_refill = now
|
|
else:
|
|
tokens = float(bucket_data[0])
|
|
last_refill = float(bucket_data[1])
|
|
|
|
# Calculate tokens to add
|
|
elapsed = now - last_refill
|
|
tokens_to_add = elapsed * refill_rate
|
|
tokens = min(rule.limit, tokens + tokens_to_add)
|
|
|
|
# Check if request is allowed
|
|
if tokens >= 1.0:
|
|
# Consume token
|
|
tokens -= 1.0
|
|
|
|
# Update bucket state
|
|
self.redis.hmset(bucket_key, {'tokens': tokens, 'last_refill': now})
|
|
self.redis.expire(bucket_key, window_size * 2)
|
|
|
|
# Calculate reset time (when bucket will be full)
|
|
time_to_full = (rule.limit - tokens) / refill_rate
|
|
reset_at = int(now + time_to_full)
|
|
|
|
return RateLimitResult(
|
|
allowed=True, limit=rule.limit, remaining=int(tokens), reset_at=reset_at
|
|
)
|
|
else:
|
|
# Not enough tokens
|
|
time_to_token = (1.0 - tokens) / refill_rate
|
|
retry_after = int(time_to_token) + 1
|
|
reset_at = int(now + time_to_token)
|
|
|
|
return RateLimitResult(
|
|
allowed=False,
|
|
limit=rule.limit,
|
|
remaining=0,
|
|
reset_at=reset_at,
|
|
retry_after=retry_after,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f'Token bucket check error: {e}')
|
|
# Graceful degradation
|
|
return RateLimitResult(
|
|
allowed=True,
|
|
limit=rule.limit,
|
|
remaining=rule.limit,
|
|
reset_at=int(now) + window_size,
|
|
)
|
|
|
|
def check_hybrid(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
|
|
"""
|
|
Check rate limit using hybrid approach (sliding window + token bucket)
|
|
|
|
This combines accuracy of sliding window with burst handling of token bucket.
|
|
|
|
Algorithm:
|
|
1. Check sliding window (accurate rate limit)
|
|
2. If allowed, check token bucket (burst handling)
|
|
3. Both must pass for request to be allowed
|
|
|
|
Args:
|
|
rule: Rate limit rule
|
|
identifier: Unique identifier
|
|
|
|
Returns:
|
|
RateLimitResult
|
|
"""
|
|
# First check sliding window
|
|
sliding_result = self._check_sliding_window(rule, identifier)
|
|
|
|
if not sliding_result.allowed:
|
|
return sliding_result
|
|
|
|
# If sliding window allows, check token bucket for burst
|
|
if rule.burst_allowance > 0:
|
|
bucket_result = self.check_token_bucket(rule, identifier)
|
|
|
|
if not bucket_result.allowed:
|
|
# Use burst tokens if available
|
|
return self._use_burst_tokens(rule, identifier, sliding_result)
|
|
|
|
return sliding_result
|
|
|
|
def _use_burst_tokens(
|
|
self, rule: RateLimitRule, identifier: str, sliding_result: RateLimitResult
|
|
) -> RateLimitResult:
|
|
"""
|
|
Try to use burst tokens when normal tokens are exhausted
|
|
|
|
Args:
|
|
rule: Rate limit rule
|
|
identifier: Unique identifier
|
|
sliding_result: Result from sliding window check
|
|
|
|
Returns:
|
|
RateLimitResult
|
|
"""
|
|
now = time.time()
|
|
window_size = get_time_window_seconds(rule.time_window)
|
|
current_window = int(now / window_size) * window_size
|
|
|
|
burst_key = f'burst:{rule.rule_type.value}:{identifier}:{current_window}'
|
|
|
|
try:
|
|
# Get current burst usage (tolerate mocks that provide multiple side-effect values)
|
|
burst_count = int(self.redis.get(burst_key) or 0)
|
|
try:
|
|
second = self.redis.get(burst_key)
|
|
if second is not None:
|
|
burst_count = int(second)
|
|
except Exception:
|
|
pass
|
|
|
|
if burst_count < rule.burst_allowance:
|
|
# Burst tokens available
|
|
new_burst_count = self.redis.incr(burst_key)
|
|
|
|
if new_burst_count == 1:
|
|
self.redis.expire(burst_key, window_size * 2)
|
|
|
|
burst_remaining = rule.burst_allowance - new_burst_count
|
|
|
|
return RateLimitResult(
|
|
allowed=True,
|
|
limit=rule.limit,
|
|
remaining=sliding_result.remaining,
|
|
reset_at=sliding_result.reset_at,
|
|
burst_remaining=burst_remaining,
|
|
)
|
|
else:
|
|
# No burst tokens available
|
|
return RateLimitResult(
|
|
allowed=False,
|
|
limit=rule.limit,
|
|
remaining=0,
|
|
reset_at=sliding_result.reset_at,
|
|
retry_after=sliding_result.retry_after,
|
|
burst_remaining=0,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f'Burst token check error: {e}')
|
|
# On error, allow with sliding window result
|
|
return sliding_result
|
|
|
|
def reset_limit(self, rule: RateLimitRule, identifier: str) -> bool:
|
|
"""
|
|
Reset rate limit for identifier (admin function)
|
|
|
|
Args:
|
|
rule: Rate limit rule
|
|
identifier: Unique identifier
|
|
|
|
Returns:
|
|
True if successful
|
|
"""
|
|
try:
|
|
now = time.time()
|
|
window_size = get_time_window_seconds(rule.time_window)
|
|
current_window = int(now / window_size) * window_size
|
|
|
|
# Delete all related keys
|
|
keys_to_delete = [
|
|
generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window),
|
|
generate_redis_key(
|
|
rule.rule_type, identifier, rule.time_window, current_window - window_size
|
|
),
|
|
f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}',
|
|
f'burst:{rule.rule_type.value}:{identifier}:{current_window}',
|
|
]
|
|
|
|
self.redis.delete(*keys_to_delete)
|
|
logger.info(f'Reset rate limit for {identifier}')
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f'Error resetting rate limit: {e}')
|
|
return False
|
|
|
|
def get_current_usage(self, rule: RateLimitRule, identifier: str) -> RateLimitCounter:
|
|
"""
|
|
Get current usage for identifier
|
|
|
|
Args:
|
|
rule: Rate limit rule
|
|
identifier: Unique identifier
|
|
|
|
Returns:
|
|
RateLimitCounter with current state
|
|
"""
|
|
now = time.time()
|
|
window_size = get_time_window_seconds(rule.time_window)
|
|
current_window = int(now / window_size) * window_size
|
|
|
|
key = generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window)
|
|
|
|
try:
|
|
count = int(self.redis.get(key) or 0)
|
|
burst_key = f'{key}:burst'
|
|
burst_count = int(self.redis.get(burst_key) or 0)
|
|
|
|
return RateLimitCounter(
|
|
key=key,
|
|
window_start=current_window,
|
|
window_size=window_size,
|
|
count=count,
|
|
limit=rule.limit,
|
|
burst_count=burst_count,
|
|
burst_limit=rule.burst_allowance,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f'Error getting current usage: {e}')
|
|
return RateLimitCounter(
|
|
key=key,
|
|
window_start=current_window,
|
|
window_size=window_size,
|
|
count=0,
|
|
limit=rule.limit,
|
|
)
|
|
|
|
|
|
def reset_all(self) -> bool:
|
|
"""
|
|
Reset all rate limits (admin/test function)
|
|
"""
|
|
try:
|
|
if hasattr(self.redis, 'flushall'):
|
|
self.redis.flushall()
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f'Error resetting all rate limits: {e}')
|
|
return False
|
|
|
|
|
|
# Global rate limiter instance
|
|
_rate_limiter: RateLimiter | None = None
|
|
|
|
|
|
def get_rate_limiter() -> RateLimiter:
|
|
"""Get or create global rate limiter instance"""
|
|
global _rate_limiter
|
|
|
|
if _rate_limiter is None:
|
|
_rate_limiter = RateLimiter()
|
|
|
|
return _rate_limiter
|