mirror of
https://github.com/apidoorman/doorman.git
synced 2026-05-12 11:58:25 -05:00
Added subscription support. Expanded test cases with new framework
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -12,6 +12,8 @@ class ApiModel(BaseModel):
|
||||
api_name: str = Field(..., min_length=1, max_length=25)
|
||||
api_version: str = Field(..., min_length=1, max_length=2)
|
||||
api_description: str = Field(None, min_length=1, max_length=127)
|
||||
api_allowed_roles: List[str] = Field(default_factory=list)
|
||||
api_allowed_groups: List[str] = Field(default_factory=list)
|
||||
api_servers: List[str] = Field(default_factory=list)
|
||||
api_type: str = None
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
The contents of this file are property of pygate.org
|
||||
Review the Apache License 2.0 for valid authorization of use
|
||||
See https://github.com/pypeople-dev/pygate for more information
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class SubscribeModel(BaseModel):
|
||||
|
||||
username: str = Field(..., min_length=3, max_length=50)
|
||||
api_name: str = Field(..., min_length=3, max_length=50)
|
||||
api_version: str = Field(..., min_length=1, max_length=5)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -4,6 +4,7 @@ Review the Apache License 2.0 for valid authorization of use
|
||||
See https://github.com/pypeople-dev/pygate for more information
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_jwt_auth import AuthJWT
|
||||
@@ -65,6 +66,10 @@ class Settings(BaseSettings):
|
||||
authjwt_cookie_domain: str = domain
|
||||
authjwt_cookie_path: str = "/"
|
||||
authjwt_cookie_samesite: str = 'lax'
|
||||
authjwt_cookie_csrf_protect: bool = False
|
||||
|
||||
authjwt_access_token_expires: timedelta = timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRES_MINUTES", 15)))
|
||||
authjwt_refresh_token_expires: timedelta = timedelta(days=int(os.getenv("REFRESH_TOKEN_EXPIRES_DAYS", 30)))
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
markers =
|
||||
order: mark test to run in a specific order
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+14
-13
@@ -4,11 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
|
||||
See https://github.com/pypeople-dev/pygate for more information
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from services.api_service import ApiService
|
||||
from utils.auth_util import auth_required
|
||||
from utils.subscription_util import subscription_required
|
||||
from utils.whitelist_util import whitelist_check
|
||||
from utils.role_util import role_required
|
||||
from models.api_model import ApiModel
|
||||
@@ -30,12 +31,12 @@ Response:
|
||||
"message": "API created successfully"
|
||||
}
|
||||
"""
|
||||
@api_router.post("")
|
||||
@api_router.post("",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def create_api(api_data: ApiModel):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
await ApiService.create_api(api_data)
|
||||
return JSONResponse(content={'message': 'API created successfully'}, status_code=201)
|
||||
except ValueError as e:
|
||||
@@ -54,12 +55,12 @@ Response:
|
||||
"api_path": "<string>"
|
||||
}
|
||||
"""
|
||||
@api_router.get("/{api_name}/{api_version}")
|
||||
@api_router.get("/{api_name}/{api_version}",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_api_by_name_version(api_name: str, api_version: str):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
api = await ApiService.get_api_by_name_version(api_name, api_version)
|
||||
if api.get('_id'): del api['_id']
|
||||
return JSONResponse(content=api, status_code=200)
|
||||
@@ -79,12 +80,12 @@ Response:
|
||||
"api_path": "<string>"
|
||||
}
|
||||
"""
|
||||
@api_router.get("/all")
|
||||
@api_router.get("/all",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_all_apis(page: int = 1, page_size: int = 10):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
apis = await ApiService.get_apis(page, page_size)
|
||||
return JSONResponse(content=apis, status_code=200)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -61,10 +61,12 @@ Response:
|
||||
{
|
||||
}
|
||||
"""
|
||||
@authorization_router.get("/authorization/status")
|
||||
async def status(Authorize: AuthJWT = Depends()):
|
||||
@authorization_router.get("/authorization/status",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def status():
|
||||
try:
|
||||
auth_required()
|
||||
return JSONResponse(content={"status": "authorized"}, status_code=200)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
@@ -81,10 +83,12 @@ Response:
|
||||
"message": "Your token has been invalidated"
|
||||
}
|
||||
"""
|
||||
@authorization_router.post("/authorization/invalidate")
|
||||
@authorization_router.post("/authorization/invalidate",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def logout(response: Response, Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
auth_required()
|
||||
jwt_id = Authorize.get_raw_jwt()['jti']
|
||||
user = Authorize.get_jwt_subject()
|
||||
Authorize.unset_jwt_cookies(response)
|
||||
|
||||
@@ -30,12 +30,12 @@ Response:
|
||||
"message": "Endpoint created successfully"
|
||||
}
|
||||
"""
|
||||
@endpoint_router.post("")
|
||||
@endpoint_router.post("",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def create_endpoint(endpoint_data: EndpointModel, Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
await EndpointService.create_endpoint(endpoint_data)
|
||||
return JSONResponse(content={'message': 'Endpoint created successfully'}, status_code=201)
|
||||
except ValueError as e:
|
||||
@@ -59,12 +59,12 @@ Response:
|
||||
]
|
||||
}
|
||||
"""
|
||||
@endpoint_router.get("/api/{api_name}/{api_version}")
|
||||
@endpoint_router.get("/api/{api_name}/{api_version}",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_endpoints_by_name_version(api_name: str, api_version: str, Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
endpoints = await EndpointService.get_endpoints_by_name_version(api_name, api_version)
|
||||
return JSONResponse(content={"endpoints": endpoints}, status_code=200)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -16,13 +16,17 @@ from models.request_model import RequestModel
|
||||
|
||||
gateway_router = APIRouter()
|
||||
|
||||
@gateway_router.api_route("/rest/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
|
||||
async def rest_gateway(path: str, request: Request, Authorize: AuthJWT = Depends()):
|
||||
@gateway_router.api_route(
|
||||
"/rest/{path:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
||||
dependencies=[
|
||||
Depends(auth_required),
|
||||
Depends(subscription_required)
|
||||
]
|
||||
)
|
||||
async def rest_gateway(path: str, request: Request,
|
||||
Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
subscription_required()
|
||||
|
||||
request_model = RequestModel(
|
||||
method=request.method,
|
||||
path=path,
|
||||
@@ -34,4 +38,4 @@ async def rest_gateway(path: str, request: Request, Authorize: AuthJWT = Depends
|
||||
|
||||
return await GatewayService.rest_gateway(request_model)
|
||||
except ValueError as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=400)
|
||||
return JSONResponse(content={"error": str(e)}, status_code=400)
|
||||
|
||||
+13
-13
@@ -4,7 +4,7 @@ Review the Apache License 2.0 for valid authorization of use
|
||||
See https://github.com/pypeople-dev/pygate for more information
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from services.group_service import GroupService
|
||||
@@ -28,12 +28,12 @@ Response:
|
||||
"message": "Group created successfully"
|
||||
}
|
||||
"""
|
||||
@group_router.post("")
|
||||
@group_router.post("",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def create_group(api_data: GroupModel):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
await GroupService.create_group(api_data)
|
||||
return JSONResponse(content={'message': 'Group created successfully'}, status_code=201)
|
||||
except ValueError as e:
|
||||
@@ -56,12 +56,12 @@ Response:
|
||||
]
|
||||
}
|
||||
"""
|
||||
@group_router.get("/all")
|
||||
@group_router.get("/all",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_groups(page: int = 1, page_size: int = 10):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
groups = await GroupService.get_groups(page, page_size)
|
||||
return JSONResponse(content=groups, status_code=200)
|
||||
except ValueError as e:
|
||||
@@ -83,12 +83,12 @@ Response:
|
||||
]
|
||||
}
|
||||
"""
|
||||
@group_router.get("/{group_name}")
|
||||
@group_router.get("/{group_name}",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_group(group_name: str):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
group = await GroupService.get_group(group_name)
|
||||
return JSONResponse(content=group, status_code=200)
|
||||
except ValueError as e:
|
||||
|
||||
+12
-12
@@ -32,12 +32,12 @@ Response:
|
||||
"message": "Role created successfully"
|
||||
}
|
||||
"""
|
||||
@role_router.post("")
|
||||
@role_router.post("",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def create_role(api_data: RoleModel):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
await RoleService.create_role(api_data)
|
||||
return JSONResponse(content={'message': 'Role created successfully'}, status_code=201)
|
||||
except ValueError as e:
|
||||
@@ -64,12 +64,12 @@ Response:
|
||||
]
|
||||
}
|
||||
"""
|
||||
@role_router.get("/all")
|
||||
@role_router.get("/all",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_roles(page: int = 1, page_size: int = 10):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
roles = await RoleService.get_roles(page, page_size)
|
||||
return JSONResponse(content=roles, status_code=200)
|
||||
except ValueError as e:
|
||||
@@ -94,12 +94,12 @@ Response:
|
||||
}
|
||||
}
|
||||
"""
|
||||
@role_router.get("/{role_name}")
|
||||
@role_router.get("/{role_name}",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def get_role(role_name: str):
|
||||
try:
|
||||
auth_required()
|
||||
whitelist_check()
|
||||
role_required(("admin", "dev", "platform"))
|
||||
role = await RoleService.get_role(role_name)
|
||||
return JSONResponse(content=role, status_code=200)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -12,6 +12,7 @@ from services.subscription_service import SubscriptionService
|
||||
from utils.auth_util import auth_required
|
||||
from utils.whitelist_util import whitelist_check
|
||||
from utils.role_util import role_required
|
||||
from models.subscribe_model import SubscribeModel
|
||||
|
||||
subscription_router = APIRouter()
|
||||
|
||||
@@ -28,14 +29,13 @@ Response:
|
||||
"message": "Successfully subscribed to the API"
|
||||
}
|
||||
"""
|
||||
@subscription_router.post("/subscribe")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
@role_required(("admin", "dev", "platform"))
|
||||
async def subscribe_api(request: Request):
|
||||
@subscription_router.post("/subscribe",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def subscribe_api(api_data: SubscribeModel):
|
||||
try:
|
||||
data = await request.json()
|
||||
SubscriptionService.subscribe(data)
|
||||
await SubscriptionService.subscribe(api_data)
|
||||
return JSONResponse(content={'message': 'Successfully subscribed to the API'}, status_code=200)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -54,14 +54,13 @@ Response:
|
||||
"message": "Successfully unsubscribed from the API"
|
||||
}
|
||||
"""
|
||||
@subscription_router.post("/unsubscribe")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
@role_required(("admin", "dev", "platform"))
|
||||
async def unsubscribe_api(request: Request):
|
||||
@subscription_router.post("/unsubscribe",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def unsubscribe_api(api_data: SubscribeModel):
|
||||
try:
|
||||
data = await request.json()
|
||||
SubscriptionService.unsubscribe(data)
|
||||
await SubscriptionService.unsubscribe(api_data)
|
||||
return JSONResponse(content={'message': 'Successfully unsubscribed from the API'}, status_code=200)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -77,10 +76,10 @@ Response:
|
||||
"subscriptions": []
|
||||
}
|
||||
"""
|
||||
@subscription_router.get("/subscriptions/{user_id}")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
@role_required(("admin", "dev", "platform"))
|
||||
@subscription_router.get("/subscriptions/{user_id}",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def subscriptions_for_user_by_id(user_id: str):
|
||||
try:
|
||||
subscriptions = SubscriptionService.get_user_subscriptions(user_id)
|
||||
@@ -98,8 +97,10 @@ Response:
|
||||
"subscriptions": []
|
||||
}
|
||||
"""
|
||||
@subscription_router.get("/subscriptions")
|
||||
@auth_required()
|
||||
@subscription_router.get("/subscriptions",
|
||||
dependencies=[
|
||||
Depends(auth_required)
|
||||
])
|
||||
async def subscriptions_for_current_user(Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
username = Authorize.get_jwt_subject()
|
||||
|
||||
@@ -42,7 +42,7 @@ Response:
|
||||
@user_router.post("")
|
||||
async def create_user(user_data: UserModel, Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
auth_required()
|
||||
|
||||
role_required(["admin", "dev", "platform"])
|
||||
new_user = await UserService.create_user(user_data)
|
||||
return JSONResponse(content={"message": "User created successfully"}, status_code=201)
|
||||
@@ -67,8 +67,7 @@ Response:
|
||||
}
|
||||
"""
|
||||
@user_router.put("/{user_id}")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
|
||||
@role_required(["admin", "dev", "platform"])
|
||||
async def update_user(user_id: str, request: Request):
|
||||
try:
|
||||
@@ -91,8 +90,7 @@ Response:
|
||||
}
|
||||
"""
|
||||
@user_router.put("/{user_id}/update-password")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
|
||||
@role_required(["admin", "dev", "platform"])
|
||||
async def update_user_password(user_id: str, request: Request):
|
||||
try:
|
||||
@@ -123,8 +121,7 @@ Response:
|
||||
}
|
||||
"""
|
||||
@user_router.get("/{username}")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
|
||||
@role_required(["admin", "dev", "platform"])
|
||||
async def get_user_by_username(username: str):
|
||||
try:
|
||||
@@ -148,8 +145,7 @@ Response:
|
||||
}
|
||||
"""
|
||||
@user_router.get("/email/{email}")
|
||||
@auth_required()
|
||||
@whitelist_check()
|
||||
|
||||
@role_required(["admin", "dev", "platform"])
|
||||
async def get_user_by_email(email: str):
|
||||
try:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -12,6 +12,7 @@ import logging
|
||||
|
||||
from utils.database import db
|
||||
from services.cache import pygate_cache
|
||||
import uuid
|
||||
|
||||
class GatewayService:
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
|
||||
@@ -25,9 +26,12 @@ class GatewayService:
|
||||
"""
|
||||
External gateway.
|
||||
"""
|
||||
request_id = uuid.uuid4()
|
||||
start_time = time.time() * 1000
|
||||
gateway_end_time = None
|
||||
backend_start_time = None
|
||||
response = None
|
||||
GatewayService.logger.info(f"REST | {request_id} | Resource: {request.path}")
|
||||
try:
|
||||
match = re.match(r"([^/]+/v\d+)", request.path)
|
||||
api_name_version = '/' + match.group(1) if match else ""
|
||||
@@ -66,12 +70,14 @@ class GatewayService:
|
||||
return JSONResponse("Endpoint does not exists in backend service", status_code=404)
|
||||
return JSONResponse(content=response_content, status_code=response.status_code)
|
||||
except Exception as e:
|
||||
GatewayService.logger.error(f"Error in rest_gateway: {str(e)}")
|
||||
GatewayService.logger.error(f"REST | {request_id} | Error in rest_gateway: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
finally:
|
||||
end_time = time.time() * 1000
|
||||
response.headers['X-Request-Id'] = request_id
|
||||
if gateway_end_time:
|
||||
GatewayService.logger.info(f"Gateway Time: {gateway_end_time - start_time}ms")
|
||||
GatewayService.logger.info(f"REST | {request_id} | Gateway Time: {gateway_end_time - start_time}ms")
|
||||
if backend_start_time:
|
||||
GatewayService.logger.info(f"Backend Time: {end_time - backend_start_time}ms")
|
||||
GatewayService.logger.info(f"Total Time: {end_time - start_time}ms")
|
||||
GatewayService.logger.info(f"REST | {request_id} | Backend Time: {end_time - backend_start_time}ms")
|
||||
GatewayService.logger.info(f"REST | {request_id} | Total Time: {end_time - start_time}ms")
|
||||
GatewayService.logger.info(f"REST | {request_id} | Status Code: {response.status_code}")
|
||||
@@ -9,8 +9,14 @@ from utils.cache import cache_manager
|
||||
from services.cache import pygate_cache
|
||||
from services.api_service import ApiService
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
|
||||
logger = logging.getLogger("pygate.gateway")
|
||||
|
||||
class SubscriptionService:
|
||||
subscriptions_collection = db.subscriptions
|
||||
api_collection = db.apis
|
||||
|
||||
@staticmethod
|
||||
@cache_manager.cached(ttl=300)
|
||||
@@ -18,7 +24,10 @@ class SubscriptionService:
|
||||
"""
|
||||
Check if an API exists.
|
||||
"""
|
||||
return pygate_cache.get_cache('api_cache', f"{api_name}/{api_version}") or ApiService.api_collection.find_one({'api_name': api_name, 'api_version': api_version})
|
||||
api = pygate_cache.get_cache('api_cache', f"{api_name}/{api_version}") or ApiService.api_collection.find_one({'api_name': api_name, 'api_version': api_version})
|
||||
if api and '_id' in api:
|
||||
del api['_id']
|
||||
return api
|
||||
|
||||
@staticmethod
|
||||
@cache_manager.cached(ttl=300)
|
||||
@@ -38,44 +47,43 @@ class SubscriptionService:
|
||||
"""
|
||||
Subscribe to an API.
|
||||
"""
|
||||
username = data.get('username')
|
||||
api_name = data.get('api_name')
|
||||
api_version = data.get('api_version')
|
||||
|
||||
if not await SubscriptionService.api_exists(api_name, api_version):
|
||||
if not SubscriptionService.api_exists(data.api_name, data.api_version):
|
||||
raise ValueError("API does not exist")
|
||||
|
||||
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
|
||||
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', data.username) or SubscriptionService.subscriptions_collection.find_one({'username': data.username})
|
||||
if user_subscriptions is None:
|
||||
user_subscriptions = {
|
||||
'username': username,
|
||||
'apis': []
|
||||
'username': data.username,
|
||||
'apis': [f"{data.api_name}/{data.api_version}"]
|
||||
}
|
||||
SubscriptionService.subscriptions_collection.insert_one(user_subscriptions)
|
||||
elif 'apis' in user_subscriptions and f"{api_name}/{api_version}" in user_subscriptions['apis']:
|
||||
elif 'apis' in user_subscriptions and f"{data.api_name}/{data.api_version}" in user_subscriptions['apis']:
|
||||
raise ValueError("User is already subscribed to the API")
|
||||
|
||||
updated_subscriptions = SubscriptionService.subscriptions_collection.update_one(
|
||||
{'username': username},
|
||||
{'$push': {'apis': f"{api_name}/{api_version}"}}
|
||||
)
|
||||
pygate_cache.set_cache('user_subscription_cache', username, updated_subscriptions)
|
||||
|
||||
else:
|
||||
SubscriptionService.subscriptions_collection.update_one(
|
||||
{'username': data.username},
|
||||
{'$push': {'apis': f"{data.api_name}/{data.api_version}"}}
|
||||
)
|
||||
user_subscriptions = SubscriptionService.subscriptions_collection.find_one({'username': data.username})
|
||||
if user_subscriptions and '_id' in user_subscriptions:
|
||||
del user_subscriptions['_id']
|
||||
pygate_cache.set_cache('user_subscription_cache', data.username, user_subscriptions)
|
||||
|
||||
@staticmethod
|
||||
async def unsubscribe(data):
|
||||
"""
|
||||
Unsubscribe from an API.
|
||||
"""
|
||||
username = data.get('username')
|
||||
api_name = data.get('api_name')
|
||||
api_version = data.get('api_version')
|
||||
if not SubscriptionService.api_exists(api_name, api_version):
|
||||
if not await SubscriptionService.api_exists(data.api_name, data.api_version):
|
||||
raise ValueError("API does not exist")
|
||||
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
|
||||
if not user_subscriptions.contains(
|
||||
f"""{api_name}/{api_version}"""):
|
||||
raise ValueError("User is already not subscribed to the API")
|
||||
updated_subscriptions = SubscriptionService.subscriptions_collection.update_one(
|
||||
{'username': username},
|
||||
{'$pull': {'apis': f"""{api_name}/{api_version}"""}})
|
||||
pygate_cache.get_cache('user_subscription_cache', username, updated_subscriptions)
|
||||
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', data.username) or SubscriptionService.subscriptions_collection.find_one({'username': data.username})
|
||||
if not user_subscriptions or f"{data.api_name}/{data.api_version}" not in user_subscriptions.get('apis', []):
|
||||
raise ValueError("User is not subscribed to the API")
|
||||
user_subscriptions['apis'].remove(f"{data.api_name}/{data.api_version}")
|
||||
update_result = SubscriptionService.subscriptions_collection.update_one(
|
||||
{'username': data.username},
|
||||
{'$set': {'apis': user_subscriptions.get('apis', [])}}
|
||||
)
|
||||
user_subscriptions = SubscriptionService.subscriptions_collection.find_one({'username': data.username})
|
||||
if user_subscriptions and '_id' in user_subscriptions:
|
||||
del user_subscriptions['_id']
|
||||
pygate_cache.set_cache('user_subscription_cache', data.username, user_subscriptions)
|
||||
Binary file not shown.
+177
-107
@@ -1,19 +1,27 @@
|
||||
import json
|
||||
import random
|
||||
import unittest
|
||||
import time
|
||||
import requests
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
class TestPygate(unittest.TestCase):
|
||||
class TestPygate:
|
||||
base_url = "http://localhost:3002"
|
||||
token = None
|
||||
api_name = None
|
||||
endpoint_path = None
|
||||
group_name = None
|
||||
role_name = None
|
||||
username = None
|
||||
email = None
|
||||
password = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@staticmethod
|
||||
def getAccessCookies():
|
||||
return {"access_token_cookie": TestPygate.token}
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def setup_class(cls):
|
||||
for _ in range(5):
|
||||
try:
|
||||
response = requests.get(f"{cls.base_url}/platform/status")
|
||||
@@ -27,150 +35,212 @@ class TestPygate(unittest.TestCase):
|
||||
print("Failed to connect to the server after multiple attempts")
|
||||
raise RuntimeError("pygate is not running")
|
||||
|
||||
def test_01_auth_calls(self):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(1)
|
||||
async def test_auth_calls(self):
|
||||
response = requests.post(f"{self.base_url}/platform/authorization",
|
||||
json={"email": "admin@pygate.org", "password": "password1"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
json={"email": "admin@pygate.org", "password": "password1"})
|
||||
assert response.status_code == 200
|
||||
|
||||
TestPygate.token = response.json().get('access_token')
|
||||
self.assertIsNotNone(TestPygate.token)
|
||||
assert TestPygate.token is not None
|
||||
|
||||
response = requests.get(f"{self.base_url}/platform/authorization/status",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_02_create_user(self):
|
||||
if not TestPygate.token:
|
||||
self.fail("Auth token is missing")
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(2)
|
||||
async def test_create_user(self):
|
||||
TestPygate.username = "newuser" + str(time.time())
|
||||
TestPygate.email = "newuser" + str(time.time()) + "@pygate.org"
|
||||
TestPygate.password = "newpass"
|
||||
response = requests.post(f"{self.base_url}/platform/user",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"},
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"username": "newuser" + str(time.time()),
|
||||
"email": "newuser" + str(time.time()) + "@pygate.org",
|
||||
"password": "newpass",
|
||||
"role": "user"
|
||||
"username": TestPygate.username,
|
||||
"email": TestPygate.email,
|
||||
"password": TestPygate.password,
|
||||
"role": "admin",
|
||||
"groups": ["ALL"]
|
||||
})
|
||||
self.assertEqual(response.status_code, 201)
|
||||
assert response.status_code == 201
|
||||
|
||||
def test_03_onboard_api(self):
|
||||
"""Step 3: Onboard an API"""
|
||||
if not TestPygate.token:
|
||||
self.fail("Auth token is missing")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(3)
|
||||
async def test_create_group(self):
|
||||
TestPygate.group_name = "testgroup" + str(time.time())
|
||||
response = requests.post(f"{self.base_url}/platform/group",
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"group_name": TestPygate.group_name,
|
||||
"group_description": "Test group"
|
||||
})
|
||||
assert response.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(4)
|
||||
async def test_get_groups(self):
|
||||
response = requests.get(f"{self.base_url}/platform/group/all",
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(5)
|
||||
async def test_get_group(self):
|
||||
response = requests.get(f"{self.base_url}/platform/group/" + TestPygate.group_name,
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(6)
|
||||
async def test_create_role(self):
|
||||
TestPygate.role_name = "testrole" + str(time.time())
|
||||
response = requests.post(f"{self.base_url}/platform/role",
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"role_name": TestPygate.role_name,
|
||||
"role_description": "Test role",
|
||||
"manage_users": False,
|
||||
"manage_apis": False,
|
||||
"manage_endpoints": False,
|
||||
"manage_groups": False,
|
||||
"manage_roles": False
|
||||
})
|
||||
assert response.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(7)
|
||||
async def test_get_roles(self):
|
||||
response = requests.get(f"{self.base_url}/platform/role/all",
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(8)
|
||||
async def test_get_role(self):
|
||||
response = requests.get(f"{self.base_url}/platform/role/" + TestPygate.role_name,
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(9)
|
||||
async def test_onboard_api(self):
|
||||
TestPygate.api_name = "test" + "".join(random.sample("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", 8))
|
||||
|
||||
response = requests.post(f"{self.base_url}/platform/api",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"},
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"api_name": TestPygate.api_name,
|
||||
"api_version": "v1",
|
||||
"api_description": "Test API",
|
||||
"api_servers": ["https://fake-json-api.mock.beeceptor.com/"],
|
||||
"api_allowed_roles": [TestPygate.role_name],
|
||||
"api_allowed_groups": [TestPygate.group_name],
|
||||
"api_type": "REST"
|
||||
})
|
||||
self.assertEqual(response.status_code, 201)
|
||||
assert response.status_code == 201
|
||||
|
||||
def test_04_onboard_endpoint(self):
|
||||
if not TestPygate.token:
|
||||
self.fail("Auth token is missing")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(10)
|
||||
async def test_onboard_endpoint(self):
|
||||
|
||||
TestPygate.endpoint_path = "/users"
|
||||
|
||||
response = requests.post(f"{self.base_url}/platform/endpoint",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"},
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"api_name": TestPygate.api_name,
|
||||
"api_version": "v1",
|
||||
"endpoint_uri": TestPygate.endpoint_path,
|
||||
"endpoint_method": "GET"
|
||||
})
|
||||
self.assertEqual(response.status_code, 201)
|
||||
assert response.status_code == 201
|
||||
|
||||
def test_05_gateway_call(self):
|
||||
response = requests.get(f"{self.base_url}/api/rest/" + TestPygate.api_name + "/v1" + TestPygate.endpoint_path.replace("{userId}", "2"))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(11)
|
||||
async def test_subscribe(self):
|
||||
response = requests.post(f"{self.base_url}/platform/subscription/subscribe",
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"username": TestPygate.username,
|
||||
"api_name": TestPygate.api_name,
|
||||
"api_version": "v1"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_06_get_api(self):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(12)
|
||||
async def test_re_auth_calls(self):
|
||||
response = requests.post(f"{self.base_url}/platform/authorization",
|
||||
json={"email": TestPygate.email, "password": TestPygate.password})
|
||||
assert response.status_code == 200
|
||||
|
||||
TestPygate.token = response.json().get('access_token')
|
||||
assert TestPygate.token is not None
|
||||
|
||||
response = requests.get(f"{self.base_url}/platform/authorization/status",
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(13)
|
||||
async def test_gateway_call(self):
|
||||
response = requests.get(f"{self.base_url}/api/rest/" + TestPygate.api_name + "/v1" + TestPygate.endpoint_path.replace("{userId}", "2"),
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(14)
|
||||
async def test_unsubscribe(self):
|
||||
response = requests.post(f"{self.base_url}/platform/subscription/unsubscribe",
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"username": TestPygate.username,
|
||||
"api_name": TestPygate.api_name,
|
||||
"api_version": "v1"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(15)
|
||||
async def test_re_gateway_call(self):
|
||||
response = requests.get(f"{self.base_url}/api/rest/" + TestPygate.api_name + "/v1" + TestPygate.endpoint_path.replace("{userId}", "2"),
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(16)
|
||||
async def test_get_api(self):
|
||||
response = requests.get(f"{self.base_url}/platform/api/" + TestPygate.api_name + "/v1",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_07_get_all_apis(self):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(17)
|
||||
async def test_get_all_apis(self):
|
||||
response = requests.get(f"{self.base_url}/platform/api/all",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"},
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
params={"page": 1, "page_size": 10})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_08_api_endpoints(self):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(18)
|
||||
async def test_api_endpoints(self):
|
||||
response = requests.get(f"{self.base_url}/platform/endpoint/api/" + TestPygate.api_name + "/v1",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
cookies=TestPygate.getAccessCookies())
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_09_create_group(self):
|
||||
TestPygate.group_name = "testgroup" + str(time.time())
|
||||
response = requests.post(f"{self.base_url}/platform/group",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"},
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.order(19)
|
||||
async def test_re_subscribe(self):
|
||||
response = requests.post(f"{self.base_url}/platform/subscription/subscribe",
|
||||
cookies=TestPygate.getAccessCookies(),
|
||||
json={
|
||||
"group_name": TestPygate.group_name,
|
||||
"group_description": "Test group"
|
||||
"username": TestPygate.username,
|
||||
"api_name": TestPygate.api_name,
|
||||
"api_version": "v1"
|
||||
})
|
||||
|
||||
self.assertEqual(response.status_code, 201)
|
||||
|
||||
def test_10_get_groups(self):
|
||||
response = requests.get(f"{self.base_url}/platform/group/all",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_11_get_group(self):
|
||||
response = requests.get(f"{self.base_url}/platform/group/" + TestPygate.group_name,
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_12_create_role(self):
|
||||
TestPygate.role_name = "testrole" + str(time.time())
|
||||
response = requests.post(f"{self.base_url}/platform/role",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"},
|
||||
json={
|
||||
"role_name": TestPygate.role_name,
|
||||
"role_description": "Test role",
|
||||
"manage_users": False,
|
||||
"manage_apis": False,
|
||||
"manage_endpoints": False,
|
||||
"manage_groups": False,
|
||||
"manage_roles": False
|
||||
})
|
||||
self.assertEqual(response.status_code, 201)
|
||||
|
||||
def test_13_get_roles(self):
|
||||
response = requests.get(f"{self.base_url}/platform/role/all",
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_14_get_role(self):
|
||||
response = requests.get(f"{self.base_url}/platform/role/" + TestPygate.role_name,
|
||||
headers={"Authorization": f"Bearer {TestPygate.token}"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def suite():
|
||||
test_suite = unittest.TestSuite()
|
||||
test_suite.addTest(TestPygate("test_01_auth_calls"))
|
||||
test_suite.addTest(TestPygate("test_02_create_user"))
|
||||
test_suite.addTest(TestPygate("test_03_onboard_api"))
|
||||
test_suite.addTest(TestPygate("test_04_onboard_endpoint"))
|
||||
test_suite.addTest(TestPygate("test_05_gateway_call"))
|
||||
test_suite.addTest(TestPygate("test_06_get_api"))
|
||||
test_suite.addTest(TestPygate("test_07_get_all_apis"))
|
||||
test_suite.addTest(TestPygate("test_08_api_endpoints"))
|
||||
test_suite.addTest(TestPygate("test_09_create_group"))
|
||||
test_suite.addTest(TestPygate("test_10_get_groups"))
|
||||
test_suite.addTest(TestPygate("test_11_get_group"))
|
||||
test_suite.addTest(TestPygate("test_12_create_role"))
|
||||
test_suite.addTest(TestPygate("test_13_get_roles"))
|
||||
test_suite.addTest(TestPygate("test_14_get_role"))
|
||||
return test_suite
|
||||
assert response.status_code == 200
|
||||
|
||||
if __name__ == '__main__':
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite())
|
||||
pytest.main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import heapq
|
||||
import asyncio
|
||||
|
||||
jwt_blacklist = {}
|
||||
|
||||
class TimedHeap:
|
||||
def __init__(self):
|
||||
def __init__(self, purge_after=timedelta(hours=1)):
|
||||
self.heap = []
|
||||
self.purge_after = purge_after
|
||||
|
||||
async def push(self, item):
|
||||
expire_time = datetime.now() + self.purge_after
|
||||
|
||||
+17
-24
@@ -4,34 +4,27 @@ Review the Apache License 2.0 for valid authorization of use
|
||||
See https://github.com/pypeople-dev/pygate for more information
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException, Depends
|
||||
from fastapi_jwt_auth import AuthJWT
|
||||
|
||||
from utils.auth_blacklist import jwt_blacklist
|
||||
|
||||
def auth_required():
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def decorated_function(*args, Authorize: AuthJWT = Depends(), **kwargs):
|
||||
try:
|
||||
Authorize.jwt_required()
|
||||
import logging
|
||||
|
||||
jwt_subject = Authorize.get_jwt_subject()
|
||||
jwt_id = Authorize.get_raw_jwt()['jti']
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
|
||||
logger = logging.getLogger("pygate.gateway")
|
||||
|
||||
if jwt_subject in jwt_blacklist and jwt_id in jwt_blacklist[jwt_subject]:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has been revoked"
|
||||
)
|
||||
return await func(*args, Authorize=Authorize, **kwargs)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Could not validate credentials"
|
||||
)
|
||||
return decorated_function
|
||||
return decorator
|
||||
async def auth_required(Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
Authorize.jwt_required()
|
||||
jwt_subject = Authorize.get_jwt_subject()
|
||||
jwt_id = Authorize.get_raw_jwt()['jti']
|
||||
if jwt_subject in jwt_blacklist and jwt_id in jwt_blacklist[jwt_subject]:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Token has been revoked"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Unauthorized access")
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
return Authorize
|
||||
+30
-14
@@ -7,20 +7,36 @@ See https://github.com/pypeople-dev/pygate for more information
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException, Depends, Request
|
||||
from fastapi_jwt_auth import AuthJWT
|
||||
from fastapi_jwt_auth.exceptions import MissingTokenError
|
||||
from services.cache import pygate_cache
|
||||
|
||||
from services.subscription_service import SubscriptionService
|
||||
|
||||
def subscription_required():
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def decorated_function(*args, request: Request, Authorize: AuthJWT = Depends(), **kwargs):
|
||||
Authorize.jwt_required('cookies')
|
||||
username = Authorize.get_jwt_subject()
|
||||
subscriptions = await pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
|
||||
path = kwargs.get('path', '')
|
||||
if not subscriptions or not subscriptions.get('apis') or path not in subscriptions.get('apis'):
|
||||
raise HTTPException(status_code=403, detail="You are not subscribed to this resource")
|
||||
return await f(*args, **kwargs)
|
||||
return decorated_function
|
||||
return decorator
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
|
||||
logger = logging.getLogger("pygate.gateway")
|
||||
|
||||
def subscription_required(request: Request, Authorize: AuthJWT = Depends()):
|
||||
try:
|
||||
Authorize.jwt_required()
|
||||
username = Authorize.get_jwt_subject()
|
||||
full_path = request.url.path
|
||||
prefix = "/api/rest/"
|
||||
if full_path.startswith(prefix):
|
||||
path = full_path[len(prefix):]
|
||||
else:
|
||||
path = full_path
|
||||
api_and_version = '/'.join(path.split('/')[:2])
|
||||
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
|
||||
subscriptions = user_subscriptions.get('apis') if user_subscriptions and 'apis' in user_subscriptions else None
|
||||
if not subscriptions or api_and_version not in subscriptions:
|
||||
logger.info(f"User {username} attempted access to {api_and_version}")
|
||||
raise HTTPException(status_code=403, detail="You are not subscribed to this resource")
|
||||
except MissingTokenError:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
return Authorize
|
||||
Reference in New Issue
Block a user