From c96cfa456cd32a513727ba38814c6fdd1edd48cd Mon Sep 17 00:00:00 2001 From: seniorswe Date: Sun, 14 Dec 2025 21:16:24 -0500 Subject: [PATCH] bug fix --- .../middleware/rate_limit_middleware.py | 47 +++++++------------ backend-services/routes/quota_routes.py | 22 +++++---- .../routes/rate_limit_rule_routes.py | 24 ++++++++-- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/backend-services/middleware/rate_limit_middleware.py b/backend-services/middleware/rate_limit_middleware.py index 6202a6a..b5e2ceb 100644 --- a/backend-services/middleware/rate_limit_middleware.py +++ b/backend-services/middleware/rate_limit_middleware.py @@ -242,27 +242,14 @@ class RateLimitMiddleware(BaseHTTPMiddleware): return None async def _default_get_rules( - self, - request: Request, - user_id: str | None, - api_name: str | None, - endpoint_uri: str, - ip_address: str, + self, request: Request, user_id: str | None, api_name: str | None, endpoint_uri: str, ip_address: str ) -> list[RateLimitRule]: """ - Default function to get applicable rules - - Priority order: - 1. If user has tier assigned → Use tier limits ONLY - 2. If user has NO tier → Use per-user rate limit rules - 3. Fall back to global rules - - In production, this should query MongoDB for rules. - This is a placeholder that returns default rules. + Default implementation to get applicable rules Args: - request: Incoming request - user_id: User ID + request: Request object + user_id: User identifier api_name: API name endpoint_uri: Endpoint URI ip_address: IP address @@ -314,21 +301,19 @@ class RateLimitMiddleware(BaseHTTPMiddleware): ) ) - # Always add global rule as fallback - rules.append( - RateLimitRule( - rule_id='default_global', - rule_type=RuleType.GLOBAL, - time_window=TimeWindow.MINUTE, - limit=1000, - priority=0, - enabled=True, - description='Global rate limit', + # Add global rule as fallback if no rules were loaded + if not rules: + rules.append( + RateLimitRule( + rule_id='default_global', + rule_type=RuleType.GLOBAL, + time_window=TimeWindow.MINUTE, + limit=1000, + priority=0, + enabled=True, + description='Global rate limit', + ) ) - ) - - # Sort by priority (highest first) - rules.sort(key=lambda r: r.priority, reverse=True) return rules diff --git a/backend-services/routes/quota_routes.py b/backend-services/routes/quota_routes.py index 0b00f8f..5845eda 100644 --- a/backend-services/routes/quota_routes.py +++ b/backend-services/routes/quota_routes.py @@ -9,13 +9,14 @@ import logging from typing import Any, Dict, List from datetime import datetime -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel from models.rate_limit_models import QuotaType from services.tier_service import TierService, get_tier_service from utils.database_async import async_database from utils.quota_tracker import QuotaTracker, get_quota_tracker +from utils.auth_util import auth_required logger = logging.getLogger(__name__) @@ -80,15 +81,20 @@ async def get_tier_service_dep() -> TierService: return get_tier_service(db) -def get_current_user_id() -> str: +async def get_current_user_id(payload: dict = Depends(auth_required)) -> str: """ - Get current user ID from request context - - In production, this should extract from JWT token or session. - For now, returns a placeholder. + Get current user ID from JWT token + + Args: + payload: JWT payload from auth_required dependency + + Returns: + str: Username from JWT 'sub' claim """ - # TODO: Extract from auth middleware - return 'current_user' + username = payload.get('sub') + if not username: + raise HTTPException(status_code=401, detail='Invalid token: missing user ID') + return username # ============================================================================ diff --git a/backend-services/routes/rate_limit_rule_routes.py b/backend-services/routes/rate_limit_rule_routes.py index 5fa4bf4..d82b78e 100644 --- a/backend-services/routes/rate_limit_rule_routes.py +++ b/backend-services/routes/rate_limit_rule_routes.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, Field from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow from services.rate_limit_rule_service import RateLimitRuleService, get_rate_limit_rule_service from utils.database_async import async_database +from utils.auth_util import auth_required logger = logging.getLogger(__name__) @@ -399,8 +400,27 @@ async def get_rule_statistics(rule_service: RateLimitRuleService = Depends(get_r # ============================================================================ +async def get_current_user_id(payload: dict = Depends(auth_required)) -> str: + """ + Get current user ID from JWT token + + Args: + payload: JWT payload from auth_required dependency + + Returns: + str: Username from JWT 'sub' claim + """ + username = payload.get('sub') + if not username: + raise HTTPException(status_code=401, detail='Invalid token: missing user ID') + return username + + @rate_limit_rule_router.get('/status', response_model=Dict[str, Any]) -async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)): +async def get_rate_limit_status( + user_id: str = Depends(get_current_user_id), + rule_service: RateLimitRuleService = Depends(get_rule_service_dep), +): """ Get current rate limit status for the authenticated user @@ -408,8 +428,6 @@ async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get This is a user-facing endpoint showing their current limits. """ try: - # TODO: Get user_id from auth middleware - user_id = 'current_user' # Get applicable rules for user rules = await rule_service.get_applicable_rules(user_id=user_id)