mirror of
https://github.com/apidoorman/doorman.git
synced 2026-02-09 11:07:05 -06:00
bug fix
This commit is contained in:
@@ -242,27 +242,14 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
return None
|
||||
|
||||
async def _default_get_rules(
|
||||
self,
|
||||
request: Request,
|
||||
user_id: str | None,
|
||||
api_name: str | None,
|
||||
endpoint_uri: str,
|
||||
ip_address: str,
|
||||
self, request: Request, user_id: str | None, api_name: str | None, endpoint_uri: str, ip_address: str
|
||||
) -> list[RateLimitRule]:
|
||||
"""
|
||||
Default function to get applicable rules
|
||||
|
||||
Priority order:
|
||||
1. If user has tier assigned → Use tier limits ONLY
|
||||
2. If user has NO tier → Use per-user rate limit rules
|
||||
3. Fall back to global rules
|
||||
|
||||
In production, this should query MongoDB for rules.
|
||||
This is a placeholder that returns default rules.
|
||||
Default implementation to get applicable rules
|
||||
|
||||
Args:
|
||||
request: Incoming request
|
||||
user_id: User ID
|
||||
request: Request object
|
||||
user_id: User identifier
|
||||
api_name: API name
|
||||
endpoint_uri: Endpoint URI
|
||||
ip_address: IP address
|
||||
@@ -314,21 +301,19 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
)
|
||||
|
||||
# Always add global rule as fallback
|
||||
rules.append(
|
||||
RateLimitRule(
|
||||
rule_id='default_global',
|
||||
rule_type=RuleType.GLOBAL,
|
||||
time_window=TimeWindow.MINUTE,
|
||||
limit=1000,
|
||||
priority=0,
|
||||
enabled=True,
|
||||
description='Global rate limit',
|
||||
# Add global rule as fallback if no rules were loaded
|
||||
if not rules:
|
||||
rules.append(
|
||||
RateLimitRule(
|
||||
rule_id='default_global',
|
||||
rule_type=RuleType.GLOBAL,
|
||||
time_window=TimeWindow.MINUTE,
|
||||
limit=1000,
|
||||
priority=0,
|
||||
enabled=True,
|
||||
description='Global rate limit',
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by priority (highest first)
|
||||
rules.sort(key=lambda r: r.priority, reverse=True)
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@ import logging
|
||||
from typing import Any, Dict, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.rate_limit_models import QuotaType
|
||||
from services.tier_service import TierService, get_tier_service
|
||||
from utils.database_async import async_database
|
||||
from utils.quota_tracker import QuotaTracker, get_quota_tracker
|
||||
from utils.auth_util import auth_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -80,15 +81,20 @@ async def get_tier_service_dep() -> TierService:
|
||||
return get_tier_service(db)
|
||||
|
||||
|
||||
def get_current_user_id() -> str:
|
||||
async def get_current_user_id(payload: dict = Depends(auth_required)) -> str:
|
||||
"""
|
||||
Get current user ID from request context
|
||||
|
||||
In production, this should extract from JWT token or session.
|
||||
For now, returns a placeholder.
|
||||
Get current user ID from JWT token
|
||||
|
||||
Args:
|
||||
payload: JWT payload from auth_required dependency
|
||||
|
||||
Returns:
|
||||
str: Username from JWT 'sub' claim
|
||||
"""
|
||||
# TODO: Extract from auth middleware
|
||||
return 'current_user'
|
||||
username = payload.get('sub')
|
||||
if not username:
|
||||
raise HTTPException(status_code=401, detail='Invalid token: missing user ID')
|
||||
return username
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
|
||||
from models.rate_limit_models import RateLimitRule, RuleType, TimeWindow
|
||||
from services.rate_limit_rule_service import RateLimitRuleService, get_rate_limit_rule_service
|
||||
from utils.database_async import async_database
|
||||
from utils.auth_util import auth_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -399,8 +400,27 @@ async def get_rule_statistics(rule_service: RateLimitRuleService = Depends(get_r
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_current_user_id(payload: dict = Depends(auth_required)) -> str:
|
||||
"""
|
||||
Get current user ID from JWT token
|
||||
|
||||
Args:
|
||||
payload: JWT payload from auth_required dependency
|
||||
|
||||
Returns:
|
||||
str: Username from JWT 'sub' claim
|
||||
"""
|
||||
username = payload.get('sub')
|
||||
if not username:
|
||||
raise HTTPException(status_code=401, detail='Invalid token: missing user ID')
|
||||
return username
|
||||
|
||||
|
||||
@rate_limit_rule_router.get('/status', response_model=Dict[str, Any])
|
||||
async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get_rule_service_dep)):
|
||||
async def get_rate_limit_status(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
rule_service: RateLimitRuleService = Depends(get_rule_service_dep),
|
||||
):
|
||||
"""
|
||||
Get current rate limit status for the authenticated user
|
||||
|
||||
@@ -408,8 +428,6 @@ async def get_rate_limit_status(rule_service: RateLimitRuleService = Depends(get
|
||||
This is a user-facing endpoint showing their current limits.
|
||||
"""
|
||||
try:
|
||||
# TODO: Get user_id from auth middleware
|
||||
user_id = 'current_user'
|
||||
|
||||
# Get applicable rules for user
|
||||
rules = await rule_service.get_applicable_rules(user_id=user_id)
|
||||
|
||||
Reference in New Issue
Block a user