mirror of
https://github.com/apidoorman/doorman.git
synced 2026-02-08 18:18:46 -06:00
325 lines
11 KiB
Python
325 lines
11 KiB
Python
"""
|
|
Durable token revocation utilities.
|
|
|
|
**IMPORTANT: Process-local fallback - NOT safe for multi-worker deployments**
|
|
|
|
**Backend Priority:**
|
|
1. Redis (sync client) - REQUIRED for multi-worker/multi-node deployments
|
|
2. Memory-only MongoDB (revocations_collection) - Single-process only
|
|
3. In-memory fallback (jwt_blacklist, revoked_all_users) - Single-process only
|
|
|
|
**Behavior:**
|
|
- If Redis is configured (MEM_OR_EXTERNAL=REDIS) and connection succeeds:
|
|
Revocations are persisted in Redis (sync client) and survive restarts.
|
|
Shared across all workers/nodes in distributed deployments.
|
|
|
|
- If database.memory_only is True and revocations_collection exists:
|
|
Revocations stored in memory-only MongoDB for single-process persistence.
|
|
Included in memory dumps but NOT shared across workers.
|
|
|
|
- Otherwise:
|
|
Falls back to in-memory Python structures (jwt_blacklist, revoked_all_users).
|
|
Process-local only - NOT shared across workers.
|
|
|
|
**Multi-Worker Safety:**
|
|
Production deployments with THREADS>1 MUST configure Redis (MEM_OR_EXTERNAL=REDIS).
|
|
The in-memory and memory-only DB fallbacks are NOT safe for multi-worker setups
|
|
and will allow revoked tokens to remain valid on other workers.
|
|
|
|
**Note on Redis Client:**
|
|
This module uses a synchronous Redis client (_redis_client) because token
|
|
revocation checks occur in synchronous code paths. For async rate limiting,
|
|
see limit_throttle_util.py which uses the async Redis client (app.state.redis).
|
|
|
|
**Public API (backward-compatible):**
|
|
- `TimedHeap` (in-memory helper)
|
|
- `jwt_blacklist` (in-memory map for fallback)
|
|
- `revoke_all_for_user`, `unrevoke_all_for_user`, `is_user_revoked`
|
|
- `purge_expired_tokens` (no-op when using Redis)
|
|
- `add_revoked_jti(username, jti, ttl_seconds)`
|
|
- `is_jti_revoked(username, jti)`
|
|
|
|
**See Also:**
|
|
- doorman.py validate_token_revocation_config() for multi-worker validation
|
|
- doorman.py app_lifespan() for production Redis requirement enforcement
|
|
"""
|
|
|
|
import heapq
|
|
import os
|
|
import time
|
|
from datetime import datetime, timedelta
|
|
|
|
try:
|
|
from utils.database import database, revocations_collection
|
|
except Exception:
|
|
database = None
|
|
revocations_collection = None
|
|
|
|
try:
|
|
import redis
|
|
except Exception:
|
|
redis = None
|
|
|
|
# Process-local fallback store: username -> TimedHeap of revoked JTIs.
jwt_blacklist: dict = {}
# Process-local fallback store: usernames under a "revoke all tokens" flag.
revoked_all_users: set = set()

# Synchronous Redis client; populated lazily by _init_redis_if_possible().
_redis_client = None
# True only while _redis_client holds a client whose ping succeeded.
_redis_enabled = False
|
|
|
|
|
|
def _init_redis_if_possible():
    """Lazily connect the module-level synchronous Redis client.

    Does nothing when a client is already held. Otherwise the backend mode is
    read from ``MEM_OR_EXTERNAL`` (falling back to ``MEM_OR_REDIS``, default
    ``'MEM'``). Redis is attempted only when the mode is not ``'MEM'``
    (case-insensitive) and the ``redis`` package imported successfully; the
    connection uses ``REDIS_HOST``, ``REDIS_PORT`` and ``REDIS_DB``.

    Any failure (bad env value, connection error, failed ping) leaves
    ``_redis_client`` as ``None`` and ``_redis_enabled`` as ``False``, so the
    next call attempts the connection again.
    """
    global _redis_client, _redis_enabled
    if _redis_client is not None:
        return

    connected = False
    try:
        backend = os.getenv('MEM_OR_EXTERNAL') or os.getenv('MEM_OR_REDIS', 'MEM')
        if str(backend).upper() != 'MEM' and redis is not None:
            pool = redis.ConnectionPool(
                host=os.getenv('REDIS_HOST', 'localhost'),
                port=int(os.getenv('REDIS_PORT', 6379)),
                db=int(os.getenv('REDIS_DB', 0)),
                decode_responses=True,
                max_connections=100,
            )
            _redis_client = redis.StrictRedis(connection_pool=pool)
            # Only a successful round-trip counts as "enabled".
            _redis_client.ping()
            connected = True
    except Exception:
        connected = False

    if not connected:
        _redis_client = None
    _redis_enabled = connected
|
|
|
|
|
|
def _revoked_jti_key(username: str, jti: str) -> str:
    """Return the Redis key that marks token ``jti`` of ``username`` as revoked."""
    return 'jwt:revoked:{}:{}'.format(username, jti)
|
|
|
|
|
|
def _revoke_all_key(username: str) -> str:
    """Return the Redis key that flags every token of ``username`` as revoked."""
    return 'jwt:revoke_all:{}'.format(username)
|
|
|
|
|
|
def revoke_all_for_user(username: str):
    """Flag every token of ``username`` as revoked.

    Storage is chosen in this order:

    1. The memory-only database's ``revocations_collection`` (a ``revoke_all``
       document is updated if present, inserted otherwise).
    2. Redis, when a client is connected (key set without expiry).
    3. The process-local ``revoked_all_users`` set.

    If the selected backend raises, the flag is recorded in
    ``revoked_all_users`` instead, so the call never propagates a backend error.
    """
    _init_redis_if_possible()
    try:
        uses_memory_db = (
            database is not None
            and getattr(database, 'memory_only', False)
            and revocations_collection is not None
        )
        if uses_memory_db:
            try:
                found = revocations_collection.find_one(
                    {'type': 'revoke_all', 'username': username}
                )
                if not found:
                    revocations_collection.insert_one(
                        {'type': 'revoke_all', 'username': username, 'revoke_all': True}
                    )
                else:
                    revocations_collection.update_one(
                        {'_id': found.get('_id')}, {'$set': {'revoke_all': True}}
                    )
            except Exception:
                revoked_all_users.add(username)
            return

        if not (_redis_enabled and _redis_client is not None):
            revoked_all_users.add(username)
            return
        _redis_client.set(_revoke_all_key(username), '1')
    except Exception:
        revoked_all_users.add(username)
|
|
|
|
|
|
def unrevoke_all_for_user(username: str):
    """Clear the "revoke all tokens" flag for ``username``.

    The flag is removed from the active durable backend (the memory-only
    database's ``revocations_collection`` when available, otherwise Redis when
    connected) and *always* from the process-local ``revoked_all_users`` set.

    The unconditional local cleanup matters because the revoke path records
    the flag in ``revoked_all_users`` whenever its durable write fails; without
    it, such a leftover entry would outlive a successful unrevoke and make the
    user appear revoked again whenever a lookup falls back to memory.

    Backend errors are swallowed (best effort, like the rest of this module).
    """
    _init_redis_if_possible()
    try:
        if (
            database is not None
            and getattr(database, 'memory_only', False)
            and revocations_collection is not None
        ):
            revocations_collection.delete_one({'type': 'revoke_all', 'username': username})
        elif _redis_enabled and _redis_client is not None:
            _redis_client.delete(_revoke_all_key(username))
    except Exception:
        # Best effort: a failing backend must not prevent the local cleanup below.
        pass
    revoked_all_users.discard(username)
|
|
|
|
|
|
def is_user_revoked(username: str) -> bool:
    """Return True if ``username`` is under a "revoke all tokens" flag.

    The durable backend is consulted first (the memory-only database's
    ``revocations_collection`` when available, otherwise Redis when connected).
    A negative or failed durable lookup is *not* final: the process-local
    ``revoked_all_users`` set is checked as well, because that is where the
    flag is stored when a durable write fails. Ignoring it would let a revoked
    user pass as soon as the backend is reachable again.
    """
    _init_redis_if_possible()
    try:
        if (
            database is not None
            and getattr(database, 'memory_only', False)
            and revocations_collection is not None
        ):
            try:
                doc = revocations_collection.find_one(
                    {'type': 'revoke_all', 'username': username}
                )
                return bool(doc and doc.get('revoke_all')) or username in revoked_all_users
            except Exception:
                # Collection lookup failed; fall through to the other stores.
                pass
        if _redis_enabled and _redis_client is not None:
            if _redis_client.exists(_revoke_all_key(username)):
                return True
    except Exception:
        pass
    return username in revoked_all_users
|
|
|
|
|
|
class TimedHeap:
    """Min-heap of items ordered by expiry time, with lazy removal of expired entries.

    ``self.heap`` holds ``(expire_time, item)`` tuples managed with ``heapq``;
    ``expire_time`` is a naive ``datetime`` taken from ``datetime.now()``.
    Expired entries are dropped whenever ``purge``, ``pop`` or ``peek`` runs.
    """

    def __init__(self, purge_after=timedelta(hours=1)):
        """Create an empty heap.

        Args:
            purge_after: Default lifetime of pushed items, as a ``timedelta``
                or a number of seconds.
        """
        self.heap = []
        self.purge_after = purge_after

    def push(self, item, ttl=None):
        """Add ``item`` with an expiry of now + lifetime.

        Args:
            item: Value to store. Items that share an identical expiry time
                are compared with each other by ``heapq``, so they must be
                mutually orderable.
            ttl: Optional per-item lifetime (``timedelta`` or seconds). When
                omitted, ``purge_after`` is used.
        """
        lifetime = self.purge_after if ttl is None else ttl
        if not isinstance(lifetime, timedelta):
            lifetime = timedelta(seconds=lifetime)
        heapq.heappush(self.heap, (datetime.now() + lifetime, item))

    def pop(self):
        """Remove and return the live item that expires soonest.

        Raises:
            IndexError: If no unexpired item remains.
        """
        self.purge()
        if self.heap:
            return heapq.heappop(self.heap)[1]
        raise IndexError('pop from an empty priority queue')

    def purge(self):
        """Drop every entry whose expiry time is strictly in the past."""
        current_time = datetime.now()
        while self.heap and self.heap[0][0] < current_time:
            heapq.heappop(self.heap)

    def peek(self):
        """Return the live item that expires soonest without removing it, or None."""
        self.purge()
        if self.heap:
            return self.heap[0][1]
        return None
|
|
|
|
|
|
def add_revoked_jti(username: str, jti: str, ttl_seconds: int | None = None):
    """Record a single token (by JTI) as revoked for ``username``.

    Args:
        username: Owner of the token. Empty values make the call a no-op.
        jti: Token identifier. Empty values make the call a no-op.
        ttl_seconds: How long the revocation must be remembered, normally the
            token's remaining lifetime. Values below 1 are raised to 1. When
            omitted or not convertible to ``int``, 3600 seconds is used for
            the durable stores and the heap's default window for the
            in-memory store.

    Storage order: the memory-only database's ``revocations_collection``, then
    Redis (key with expiry), then the process-local ``jwt_blacklist`` heap.
    A backend that raises is skipped in favour of the next one.

    The in-memory store honours ``ttl_seconds`` as well. Using only the heap's
    default window there would let a revoked token become valid again once
    that window elapsed, even though the token itself had not yet expired.
    """
    if not username or not jti:
        return

    # Parse the TTL once so a malformed value cannot knock out the durable stores.
    explicit_ttl = None
    if ttl_seconds is not None:
        try:
            explicit_ttl = max(1, int(ttl_seconds))
        except (TypeError, ValueError, OverflowError):
            explicit_ttl = None
    ttl = explicit_ttl if explicit_ttl is not None else 3600

    _init_redis_if_possible()
    try:
        if (
            database is not None
            and getattr(database, 'memory_only', False)
            and revocations_collection is not None
        ):
            try:
                exp = int(time.time()) + ttl
                existing = revocations_collection.find_one(
                    {'type': 'jti', 'username': username, 'jti': jti}
                )
                if existing:
                    revocations_collection.update_one(
                        {'_id': existing.get('_id')}, {'$set': {'expires_at': exp}}
                    )
                else:
                    revocations_collection.insert_one(
                        {'type': 'jti', 'username': username, 'jti': jti, 'expires_at': exp}
                    )
                return
            except Exception:
                # Collection write failed; try the next store.
                pass
        if _redis_enabled and _redis_client is not None:
            _redis_client.setex(_revoked_jti_key(username, jti), ttl, '1')
            return
    except Exception:
        pass

    th = jwt_blacklist.get(username)
    if not th:
        th = TimedHeap()
        jwt_blacklist[username] = th
    if explicit_ttl is None:
        # No usable TTL supplied: rely on the heap's own purge window.
        th.push(jti)
        return
    try:
        expires_at = datetime.now() + timedelta(seconds=explicit_ttl)
    except OverflowError:
        expires_at = datetime.max
    # Same (expire_time, item) entry shape that TimedHeap.push produces.
    heapq.heappush(th.heap, (expires_at, jti))
|
|
|
|
|
|
def is_jti_revoked(username: str, jti: str) -> bool:
    """Return True if token ``jti`` of ``username`` has been revoked.

    Empty ``username`` or ``jti`` yields False. Otherwise the stores are
    consulted in order and the first positive answer wins:

    1. The memory-only database's ``revocations_collection``; an expired
       document is deleted and treated as "not found".
    2. Redis, when a client is connected.
    3. The process-local ``jwt_blacklist`` heap for the user.

    A negative or failing durable lookup does not end the search: revocations
    are written to the in-memory heap whenever the durable write raises, so
    skipping that heap would report such tokens as valid.
    """
    if not username or not jti:
        return False
    _init_redis_if_possible()
    try:
        if (
            database is not None
            and getattr(database, 'memory_only', False)
            and revocations_collection is not None
        ):
            try:
                doc = revocations_collection.find_one(
                    {'type': 'jti', 'username': username, 'jti': jti}
                )
                if doc:
                    if int(doc.get('expires_at') or 0) > int(time.time()):
                        return True
                    # Expired entry: remove it and keep checking the other stores.
                    revocations_collection.delete_one({'_id': doc.get('_id')})
            except Exception:
                pass
        if _redis_enabled and _redis_client is not None:
            if _redis_client.exists(_revoked_jti_key(username, jti)):
                return True
    except Exception:
        pass

    th = jwt_blacklist.get(username)
    if not th:
        return False
    th.purge()
    # Scan a snapshot so the heap list is not iterated while it may be changing.
    return any(entry_jti == jti for _, entry_jti in tuple(th.heap))
|
|
|
|
|
|
async def purge_expired_tokens():
    """Remove expired revocation entries from the non-Redis stores.

    Returns immediately when Redis is enabled. Otherwise:

    - in memory-only database mode, every ``jti`` document in
      ``revocations_collection`` whose ``expires_at`` is missing, unreadable,
      or not in the future is deleted (errors are swallowed);
    - every heap in ``jwt_blacklist`` is purged, and users whose heap ends up
      empty are removed from the mapping.
    """
    _init_redis_if_possible()
    if _redis_enabled:
        return

    try:
        if (
            database is not None
            and getattr(database, 'memory_only', False)
            and revocations_collection is not None
        ):
            cutoff = int(time.time())

            def _is_stale(doc):
                # Documents with an unreadable expiry are treated as expired.
                try:
                    return int(doc.get('expires_at') or 0) <= cutoff
                except Exception:
                    return True

            stale_docs = [
                doc for doc in revocations_collection.find({'type': 'jti'}) if _is_stale(doc)
            ]
            for doc in stale_docs:
                revocations_collection.delete_one({'_id': doc.get('_id')})
    except Exception:
        pass

    # Iterate over a snapshot because entries are deleted during the loop.
    for user, user_heap in list(jwt_blacklist.items()):
        user_heap.purge()
        if not user_heap.heap:
            del jwt_blacklist[user]
|