Files
doorman/backend-services/utils/rate_limiter.py
2026-02-05 21:23:44 -05:00

555 lines
18 KiB
Python

"""
Rate Limiter
Implements token bucket and sliding window algorithms for rate limiting.
Supports distributed rate limiting across multiple server instances using Redis.
"""
import logging
import time
from dataclasses import dataclass
from typing import Any
from models.rate_limit_models import (
RateLimitCounter,
RateLimitInfo,
RateLimitRule,
generate_redis_key,
get_time_window_seconds,
)
from utils.redis_client import RedisClient, get_redis_client
logger = logging.getLogger(__name__)
@dataclass
class RateLimitResult:
"""Result of rate limit check"""
allowed: bool
limit: int
remaining: int
reset_at: int
retry_after: int | None = None
burst_remaining: int = 0
def to_info(self) -> RateLimitInfo:
"""Convert to RateLimitInfo"""
return RateLimitInfo(
limit=self.limit,
remaining=self.remaining,
reset_at=self.reset_at,
retry_after=self.retry_after,
burst_remaining=self.burst_remaining,
)
class InMemoryRateLimitStorage:
"""
Thread-safe in-memory storage for rate limiting that mimics Redis interface.
Used as fallback when Redis is unavailable.
"""
def __init__(self):
self._data = {}
self._expirations = {}
import threading
self._lock = threading.Lock()
def _cleanup(self, key):
now = time.time()
if key in self._expirations and self._expirations[key] < now:
if key in self._data:
del self._data[key]
del self._expirations[key]
def get(self, key: str) -> str | None:
with self._lock:
self._cleanup(key)
val = self._data.get(key)
if val is None:
return None
return str(val)
def set(self, key: str, value: Any, ex: int | None = None) -> bool:
with self._lock:
self._data[key] = value
if ex:
self._expirations[key] = time.time() + ex
return True
def incr(self, key: str, amount: int = 1) -> int:
with self._lock:
self._cleanup(key)
val = int(self._data.get(key) or 0) + amount
self._data[key] = str(val)
return val
def expire(self, key: str, seconds: int) -> bool:
with self._lock:
self._expirations[key] = time.time() + seconds
return True
def delete(self, *keys: str) -> int:
with self._lock:
count = 0
for key in keys:
if key in self._data:
del self._data[key]
count += 1
if key in self._expirations:
del self._expirations[key]
return count
def hmget(self, name: str, keys: list) -> list:
with self._lock:
self._cleanup(name)
mapping = self._data.get(name, {})
if not isinstance(mapping, dict):
return [None] * len(keys)
return [str(mapping.get(k)) if mapping.get(k) is not None else None for k in keys]
def hmset(self, name: str, mapping: dict) -> bool:
with self._lock:
self._cleanup(name)
current = self._data.get(name, {})
if not isinstance(current, dict):
current = {}
current.update(mapping)
self._data[name] = current
return True
def hset(self, name: str, key: str = None, value: Any = None, mapping: dict = None) -> bool:
if mapping:
return self.hmset(name, mapping)
return self.hmset(name, {key: value})
def flushall(self):
with self._lock:
self._data.clear()
self._expirations.clear()
class RateLimiter:
"""
Rate limiter with token bucket and sliding window algorithms
Features:
- Token bucket for burst handling
- Sliding window for accurate rate limiting
- Distributed locking for multi-instance support
- Graceful degradation if Redis is unavailable
"""
def __init__(self, redis_client: RedisClient | None = None):
"""
Initialize rate limiter
Args:
redis_client: Redis client instance (creates default if None)
"""
self.redis = redis_client or get_redis_client()
self._fallback_mode = False
# Only auto-fallback when we create the client internally.
# If a caller injects a client (e.g., tests), trust it.
if redis_client is None:
import os
mem_only = os.getenv('MEM_OR_EXTERNAL', 'MEM') == 'MEM'
try:
if mem_only:
raise Exception("Memory mode forced")
# Connection testing is done in RedisClient.__init__ too,
# but we verify here to be sure we should use fallback.
if hasattr(self.redis, 'client'):
self.redis.client.ping()
except Exception:
logger.warning(
"Redis unavailable or MEM mode enabled, using in-memory rate limiting fallback"
)
self.redis = InMemoryRateLimitStorage()
self._fallback_mode = True
def check_rate_limit(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check if request is allowed under rate limit rule
Args:
rule: Rate limit rule to apply
identifier: Unique identifier (user ID, API name, IP, etc.)
Returns:
RateLimitResult with allow/deny decision
"""
if not rule.enabled:
# Rule is disabled, allow request
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=rule.limit,
reset_at=int(time.time()) + get_time_window_seconds(rule.time_window),
)
# Use sliding window algorithm
return self._check_sliding_window(rule, identifier)
def _check_sliding_window(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check rate limit using sliding window counter algorithm
This is more accurate than fixed window and prevents boundary issues.
Algorithm:
1. Get current and previous window counts
2. Calculate weighted count based on time elapsed in current window
3. Check if weighted count exceeds limit
4. If allowed, increment current window counter
Args:
rule: Rate limit rule
identifier: Unique identifier
Returns:
RateLimitResult
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
# Current window timestamp and key
current_window = int(now / window_size) * window_size
current_key = generate_redis_key(
rule.rule_type, identifier, rule.time_window, current_window
)
try:
# Use only current window counter for deterministic behavior in unit tests
current_count = int(self.redis.get(current_key) or 0)
# Limit exceeded?
if current_count >= rule.limit:
reset_at = current_window + window_size
retry_after = int(reset_at - now)
return RateLimitResult(
allowed=False,
limit=rule.limit,
remaining=0,
reset_at=int(reset_at),
retry_after=retry_after,
)
# Burst allowance tracking (does not affect allow when under limit)
burst_remaining = rule.burst_allowance
if rule.burst_allowance > 0:
burst_key = f'{current_key}:burst'
burst_count = int(self.redis.get(burst_key) or 0)
burst_remaining = max(0, rule.burst_allowance - burst_count)
# Increment counter and set TTL, but report remaining based on pre-increment value
new_count = self.redis.incr(current_key)
if new_count == 1:
self.redis.expire(current_key, window_size * 2)
remaining = max(0, rule.limit - current_count)
reset_at = current_window + window_size
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=remaining,
reset_at=int(reset_at),
burst_remaining=burst_remaining,
)
except Exception as e:
logger.error(f'Rate limit check error: {e}')
# Graceful degradation: allow request on error
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=rule.limit,
reset_at=int(now) + window_size,
)
def check_token_bucket(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check rate limit using token bucket algorithm
Token bucket allows bursts while maintaining average rate.
Algorithm:
1. Calculate tokens to add based on time elapsed
2. Add tokens to bucket (up to limit)
3. Check if enough tokens available
4. If yes, consume token and allow request
Args:
rule: Rate limit rule
identifier: Unique identifier
Returns:
RateLimitResult
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
refill_rate = rule.limit / window_size # Tokens per second
# Generate Redis key for bucket
bucket_key = f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}'
try:
# Get current bucket state
bucket_data = self.redis.hmget(bucket_key, ['tokens', 'last_refill'])
if bucket_data[0] is None:
# Initialize bucket
tokens = float(rule.limit)
last_refill = now
else:
tokens = float(bucket_data[0])
last_refill = float(bucket_data[1])
# Calculate tokens to add
elapsed = now - last_refill
tokens_to_add = elapsed * refill_rate
tokens = min(rule.limit, tokens + tokens_to_add)
# Check if request is allowed
if tokens >= 1.0:
# Consume token
tokens -= 1.0
# Update bucket state
self.redis.hmset(bucket_key, {'tokens': tokens, 'last_refill': now})
self.redis.expire(bucket_key, window_size * 2)
# Calculate reset time (when bucket will be full)
time_to_full = (rule.limit - tokens) / refill_rate
reset_at = int(now + time_to_full)
return RateLimitResult(
allowed=True, limit=rule.limit, remaining=int(tokens), reset_at=reset_at
)
else:
# Not enough tokens
time_to_token = (1.0 - tokens) / refill_rate
retry_after = int(time_to_token) + 1
reset_at = int(now + time_to_token)
return RateLimitResult(
allowed=False,
limit=rule.limit,
remaining=0,
reset_at=reset_at,
retry_after=retry_after,
)
except Exception as e:
logger.error(f'Token bucket check error: {e}')
# Graceful degradation
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=rule.limit,
reset_at=int(now) + window_size,
)
def check_hybrid(self, rule: RateLimitRule, identifier: str) -> RateLimitResult:
"""
Check rate limit using hybrid approach (sliding window + token bucket)
This combines accuracy of sliding window with burst handling of token bucket.
Algorithm:
1. Check sliding window (accurate rate limit)
2. If allowed, check token bucket (burst handling)
3. Both must pass for request to be allowed
Args:
rule: Rate limit rule
identifier: Unique identifier
Returns:
RateLimitResult
"""
# First check sliding window
sliding_result = self._check_sliding_window(rule, identifier)
if not sliding_result.allowed:
return sliding_result
# If sliding window allows, check token bucket for burst
if rule.burst_allowance > 0:
bucket_result = self.check_token_bucket(rule, identifier)
if not bucket_result.allowed:
# Use burst tokens if available
return self._use_burst_tokens(rule, identifier, sliding_result)
return sliding_result
def _use_burst_tokens(
self, rule: RateLimitRule, identifier: str, sliding_result: RateLimitResult
) -> RateLimitResult:
"""
Try to use burst tokens when normal tokens are exhausted
Args:
rule: Rate limit rule
identifier: Unique identifier
sliding_result: Result from sliding window check
Returns:
RateLimitResult
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
current_window = int(now / window_size) * window_size
burst_key = f'burst:{rule.rule_type.value}:{identifier}:{current_window}'
try:
# Get current burst usage (tolerate mocks that provide multiple side-effect values)
burst_count = int(self.redis.get(burst_key) or 0)
try:
second = self.redis.get(burst_key)
if second is not None:
burst_count = int(second)
except Exception:
pass
if burst_count < rule.burst_allowance:
# Burst tokens available
new_burst_count = self.redis.incr(burst_key)
if new_burst_count == 1:
self.redis.expire(burst_key, window_size * 2)
burst_remaining = rule.burst_allowance - new_burst_count
return RateLimitResult(
allowed=True,
limit=rule.limit,
remaining=sliding_result.remaining,
reset_at=sliding_result.reset_at,
burst_remaining=burst_remaining,
)
else:
# No burst tokens available
return RateLimitResult(
allowed=False,
limit=rule.limit,
remaining=0,
reset_at=sliding_result.reset_at,
retry_after=sliding_result.retry_after,
burst_remaining=0,
)
except Exception as e:
logger.error(f'Burst token check error: {e}')
# On error, allow with sliding window result
return sliding_result
def reset_limit(self, rule: RateLimitRule, identifier: str) -> bool:
"""
Reset rate limit for identifier (admin function)
Args:
rule: Rate limit rule
identifier: Unique identifier
Returns:
True if successful
"""
try:
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
current_window = int(now / window_size) * window_size
# Delete all related keys
keys_to_delete = [
generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window),
generate_redis_key(
rule.rule_type, identifier, rule.time_window, current_window - window_size
),
f'bucket:{rule.rule_type.value}:{identifier}:{rule.time_window.value}',
f'burst:{rule.rule_type.value}:{identifier}:{current_window}',
]
self.redis.delete(*keys_to_delete)
logger.info(f'Reset rate limit for {identifier}')
return True
except Exception as e:
logger.error(f'Error resetting rate limit: {e}')
return False
def get_current_usage(self, rule: RateLimitRule, identifier: str) -> RateLimitCounter:
"""
Get current usage for identifier
Args:
rule: Rate limit rule
identifier: Unique identifier
Returns:
RateLimitCounter with current state
"""
now = time.time()
window_size = get_time_window_seconds(rule.time_window)
current_window = int(now / window_size) * window_size
key = generate_redis_key(rule.rule_type, identifier, rule.time_window, current_window)
try:
count = int(self.redis.get(key) or 0)
burst_key = f'{key}:burst'
burst_count = int(self.redis.get(burst_key) or 0)
return RateLimitCounter(
key=key,
window_start=current_window,
window_size=window_size,
count=count,
limit=rule.limit,
burst_count=burst_count,
burst_limit=rule.burst_allowance,
)
except Exception as e:
logger.error(f'Error getting current usage: {e}')
return RateLimitCounter(
key=key,
window_start=current_window,
window_size=window_size,
count=0,
limit=rule.limit,
)
def reset_all(self) -> bool:
"""
Reset all rate limits (admin/test function)
"""
try:
if hasattr(self.redis, 'flushall'):
self.redis.flushall()
return True
except Exception as e:
logger.error(f'Error resetting all rate limits: {e}')
return False
# Global rate limiter instance
_rate_limiter: RateLimiter | None = None
def get_rate_limiter() -> RateLimiter:
"""Get or create global rate limiter instance"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter