This commit is contained in:
seniorswe
2025-12-14 21:16:24 -05:00
parent 662fcd966d
commit c96cfa456c
3 changed files with 51 additions and 42 deletions

View File

@@ -242,27 +242,14 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return None
async def _default_get_rules(
self,
request: Request,
user_id: str | None,
api_name: str | None,
endpoint_uri: str,
ip_address: str,
self, request: Request, user_id: str | None, api_name: str | None, endpoint_uri: str, ip_address: str
) -> list[RateLimitRule]:
"""
Default function to get applicable rules
Priority order:
1. If user has tier assigned → Use tier limits ONLY
2. If user has NO tier → Use per-user rate limit rules
3. Fall back to global rules
In production, this should query MongoDB for rules.
This is a placeholder that returns default rules.
Default implementation to get applicable rules
Args:
request: Incoming request
user_id: User ID
request: Request object
user_id: User identifier
api_name: API name
endpoint_uri: Endpoint URI
ip_address: IP address
@@ -314,21 +301,19 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
)
)
# Always add global rule as fallback
rules.append(
RateLimitRule(
rule_id='default_global',
rule_type=RuleType.GLOBAL,
time_window=TimeWindow.MINUTE,
limit=1000,
priority=0,
enabled=True,
description='Global rate limit',
# Add global rule as fallback if no rules were loaded
if not rules:
rules.append(
RateLimitRule(
rule_id='default_global',
rule_type=RuleType.GLOBAL,
time_window=TimeWindow.MINUTE,
limit=1000,
priority=0,
enabled=True,
description='Global rate limit',
)
)
)
# Sort by priority (highest first)
rules.sort(key=lambda r: r.priority, reverse=True)
return rules

View File

@@ -9,13 +9,14 @@ import logging
from typing import Any, Dict, List
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from models.rate_limit_models import QuotaType
from services.tier_service import TierService, get_tier_service
from utils.database_async import async_database
from utils.quota_tracker import QuotaTracker, get_quota_tracker
from utils.auth_util import auth_required
logger = logging.getLogger(__name__)
@@ -80,15 +81,20 @@ async def get_tier_service_dep() -> TierService:
return get_tier_service(db)
def get_current_user_id() -> str:
async def get_current_user_id(payload: dict = Depends(auth_required)) -> str:
"""
Get current user ID from request context
In production, this should extract from JWT token or session.
For now, returns a placeholder.
Get current user ID from JWT token
Args:
payload: JWT payload from auth_required dependency
Returns:
str: Username from JWT 'sub' claim
"""
# TODO: Extract from auth middleware
return 'current_user'
username = payload.get('sub')
if not username:
raise HTTPException(status_code=401, detail='Invalid token: missing user ID')
return username
# ============================================================================

View File

@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow
from services.rate_limit_rule_service import RateLimitRuleService, get_rate_limit_rule_service
from utils.database_async import async_database
from utils.auth_util import auth_required
logger = logging.getLogger(__name__)
@@ -399,8 +400,27 @@ async def get_rule_statistics(rule_service: RateLimitRuleService = Depends(get_r
# ============================================================================
async def get_current_user_id(payload: dict = Depends(auth_required)) -> str:
"""
Get current user ID from JWT token
Args:
payload: JWT payload from auth_required dependency
Returns:
str: Username from JWT 'sub' claim
"""
username = payload.get('sub')
if not username:
raise HTTPException(status_code=401, detail='Invalid token: missing user ID')
return username
@rate_limit_rule_router.get('/status', response_model=Dict[str, Any])
async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)):
async def get_rate_limit_status(
user_id: str = Depends(get_current_user_id),
rule_service: RateLimitRuleService = Depends(get_rule_service_dep),
):
"""
Get current rate limit status for the authenticated user
@@ -408,8 +428,6 @@ async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get
This is a user-facing endpoint showing their current limits.
"""
try:
# TODO: Get user_id from auth middleware
user_id = 'current_user'
# Get applicable rules for user
rules = await rule_service.get_applicable_rules(user_id=user_id)