Added subscription support. Expanded test cases with new framework

This commit is contained in:
seniorswe
2025-03-24 22:03:59 -04:00
parent 2ade5e517a
commit a808e4ef3f
36 changed files with 393 additions and 268 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
+2
View File
@@ -12,6 +12,8 @@ class ApiModel(BaseModel):
api_name: str = Field(..., min_length=1, max_length=25)
api_version: str = Field(..., min_length=1, max_length=2)
api_description: str = Field(None, min_length=1, max_length=127)
api_allowed_roles: List[str] = Field(default_factory=list)
api_allowed_groups: List[str] = Field(default_factory=list)
api_servers: List[str] = Field(default_factory=list)
api_type: str = None
+16
View File
@@ -0,0 +1,16 @@
"""
The contents of this file are property of pygate.org
Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/pygate for more information
"""
from pydantic import BaseModel, Field
class SubscribeModel(BaseModel):
username: str = Field(..., min_length=3, max_length=50)
api_name: str = Field(..., min_length=3, max_length=50)
api_version: str = Field(..., min_length=1, max_length=5)
class Config:
arbitrary_types_allowed = True
+5
View File
@@ -4,6 +4,7 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/pygate for more information
"""
from datetime import timedelta
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
@@ -65,6 +66,10 @@ class Settings(BaseSettings):
authjwt_cookie_domain: str = domain
authjwt_cookie_path: str = "/"
authjwt_cookie_samesite: str = 'lax'
authjwt_cookie_csrf_protect: bool = False
authjwt_access_token_expires: timedelta = timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRES_MINUTES", 15)))
authjwt_refresh_token_expires: timedelta = timedelta(days=int(os.getenv("REFRESH_TOKEN_EXPIRES_DAYS", 30)))
class Config:
env_file = ".env"
+3
View File
@@ -0,0 +1,3 @@
[pytest]
markers =
order: mark test to run in a specific order
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+14 -13
View File
@@ -4,11 +4,12 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/pygate for more information
"""
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from services.api_service import ApiService
from utils.auth_util import auth_required
from utils.subscription_util import subscription_required
from utils.whitelist_util import whitelist_check
from utils.role_util import role_required
from models.api_model import ApiModel
@@ -30,12 +31,12 @@ Response:
"message": "API created successfully"
}
"""
@api_router.post("")
@api_router.post("",
dependencies=[
Depends(auth_required)
])
async def create_api(api_data: ApiModel):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
await ApiService.create_api(api_data)
return JSONResponse(content={'message': 'API created successfully'}, status_code=201)
except ValueError as e:
@@ -54,12 +55,12 @@ Response:
"api_path": "<string>"
}
"""
@api_router.get("/{api_name}/{api_version}")
@api_router.get("/{api_name}/{api_version}",
dependencies=[
Depends(auth_required)
])
async def get_api_by_name_version(api_name: str, api_version: str):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
api = await ApiService.get_api_by_name_version(api_name, api_version)
if api.get('_id'): del api['_id']
return JSONResponse(content=api, status_code=200)
@@ -79,12 +80,12 @@ Response:
"api_path": "<string>"
}
"""
@api_router.get("/all")
@api_router.get("/all",
dependencies=[
Depends(auth_required)
])
async def get_all_apis(page: int = 1, page_size: int = 10):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
apis = await ApiService.get_apis(page, page_size)
return JSONResponse(content=apis, status_code=200)
except ValueError as e:
+9 -5
View File
@@ -61,10 +61,12 @@ Response:
{
}
"""
@authorization_router.get("/authorization/status")
async def status(Authorize: AuthJWT = Depends()):
@authorization_router.get("/authorization/status",
dependencies=[
Depends(auth_required)
])
async def status():
try:
auth_required()
return JSONResponse(content={"status": "authorized"}, status_code=200)
except Exception as e:
raise HTTPException(status_code=401, detail="Invalid token")
@@ -81,10 +83,12 @@ Response:
"message": "Your token has been invalidated"
}
"""
@authorization_router.post("/authorization/invalidate")
@authorization_router.post("/authorization/invalidate",
dependencies=[
Depends(auth_required)
])
async def logout(response: Response, Authorize: AuthJWT = Depends()):
try:
auth_required()
jwt_id = Authorize.get_raw_jwt()['jti']
user = Authorize.get_jwt_subject()
Authorize.unset_jwt_cookies(response)
+8 -8
View File
@@ -30,12 +30,12 @@ Response:
"message": "Endpoint created successfully"
}
"""
@endpoint_router.post("")
@endpoint_router.post("",
dependencies=[
Depends(auth_required)
])
async def create_endpoint(endpoint_data: EndpointModel, Authorize: AuthJWT = Depends()):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
await EndpointService.create_endpoint(endpoint_data)
return JSONResponse(content={'message': 'Endpoint created successfully'}, status_code=201)
except ValueError as e:
@@ -59,12 +59,12 @@ Response:
]
}
"""
@endpoint_router.get("/api/{api_name}/{api_version}")
@endpoint_router.get("/api/{api_name}/{api_version}",
dependencies=[
Depends(auth_required)
])
async def get_endpoints_by_name_version(api_name: str, api_version: str, Authorize: AuthJWT = Depends()):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
endpoints = await EndpointService.get_endpoints_by_name_version(api_name, api_version)
return JSONResponse(content={"endpoints": endpoints}, status_code=200)
except ValueError as e:
+11 -7
View File
@@ -16,13 +16,17 @@ from models.request_model import RequestModel
gateway_router = APIRouter()
@gateway_router.api_route("/rest/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def rest_gateway(path: str, request: Request, Authorize: AuthJWT = Depends()):
@gateway_router.api_route(
"/rest/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
dependencies=[
Depends(auth_required),
Depends(subscription_required)
]
)
async def rest_gateway(path: str, request: Request,
Authorize: AuthJWT = Depends()):
try:
auth_required()
whitelist_check()
subscription_required()
request_model = RequestModel(
method=request.method,
path=path,
@@ -34,4 +38,4 @@ async def rest_gateway(path: str, request: Request, Authorize: AuthJWT = Depends
return await GatewayService.rest_gateway(request_model)
except ValueError as e:
return JSONResponse(content={"error": str(e)}, status_code=400)
return JSONResponse(content={"error": str(e)}, status_code=400)
+13 -13
View File
@@ -4,7 +4,7 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/pygate for more information
"""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from services.group_service import GroupService
@@ -28,12 +28,12 @@ Response:
"message": "Group created successfully"
}
"""
@group_router.post("")
@group_router.post("",
dependencies=[
Depends(auth_required)
])
async def create_group(api_data: GroupModel):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
await GroupService.create_group(api_data)
return JSONResponse(content={'message': 'Group created successfully'}, status_code=201)
except ValueError as e:
@@ -56,12 +56,12 @@ Response:
]
}
"""
@group_router.get("/all")
@group_router.get("/all",
dependencies=[
Depends(auth_required)
])
async def get_groups(page: int = 1, page_size: int = 10):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
groups = await GroupService.get_groups(page, page_size)
return JSONResponse(content=groups, status_code=200)
except ValueError as e:
@@ -83,12 +83,12 @@ Response:
]
}
"""
@group_router.get("/{group_name}")
@group_router.get("/{group_name}",
dependencies=[
Depends(auth_required)
])
async def get_group(group_name: str):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
group = await GroupService.get_group(group_name)
return JSONResponse(content=group, status_code=200)
except ValueError as e:
+12 -12
View File
@@ -32,12 +32,12 @@ Response:
"message": "Role created successfully"
}
"""
@role_router.post("")
@role_router.post("",
dependencies=[
Depends(auth_required)
])
async def create_role(api_data: RoleModel):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
await RoleService.create_role(api_data)
return JSONResponse(content={'message': 'Role created successfully'}, status_code=201)
except ValueError as e:
@@ -64,12 +64,12 @@ Response:
]
}
"""
@role_router.get("/all")
@role_router.get("/all",
dependencies=[
Depends(auth_required)
])
async def get_roles(page: int = 1, page_size: int = 10):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
roles = await RoleService.get_roles(page, page_size)
return JSONResponse(content=roles, status_code=200)
except ValueError as e:
@@ -94,12 +94,12 @@ Response:
}
}
"""
@role_router.get("/{role_name}")
@role_router.get("/{role_name}",
dependencies=[
Depends(auth_required)
])
async def get_role(role_name: str):
try:
auth_required()
whitelist_check()
role_required(("admin", "dev", "platform"))
role = await RoleService.get_role(role_name)
return JSONResponse(content=role, status_code=200)
except ValueError as e:
+21 -20
View File
@@ -12,6 +12,7 @@ from services.subscription_service import SubscriptionService
from utils.auth_util import auth_required
from utils.whitelist_util import whitelist_check
from utils.role_util import role_required
from models.subscribe_model import SubscribeModel
subscription_router = APIRouter()
@@ -28,14 +29,13 @@ Response:
"message": "Successfully subscribed to the API"
}
"""
@subscription_router.post("/subscribe")
@auth_required()
@whitelist_check()
@role_required(("admin", "dev", "platform"))
async def subscribe_api(request: Request):
@subscription_router.post("/subscribe",
dependencies=[
Depends(auth_required)
])
async def subscribe_api(api_data: SubscribeModel):
try:
data = await request.json()
SubscriptionService.subscribe(data)
await SubscriptionService.subscribe(api_data)
return JSONResponse(content={'message': 'Successfully subscribed to the API'}, status_code=200)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -54,14 +54,13 @@ Response:
"message": "Successfully unsubscribed from the API"
}
"""
@subscription_router.post("/unsubscribe")
@auth_required()
@whitelist_check()
@role_required(("admin", "dev", "platform"))
async def unsubscribe_api(request: Request):
@subscription_router.post("/unsubscribe",
dependencies=[
Depends(auth_required)
])
async def unsubscribe_api(api_data: SubscribeModel):
try:
data = await request.json()
SubscriptionService.unsubscribe(data)
await SubscriptionService.unsubscribe(api_data)
return JSONResponse(content={'message': 'Successfully unsubscribed from the API'}, status_code=200)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -77,10 +76,10 @@ Response:
"subscriptions": []
}
"""
@subscription_router.get("/subscriptions/{user_id}")
@auth_required()
@whitelist_check()
@role_required(("admin", "dev", "platform"))
@subscription_router.get("/subscriptions/{user_id}",
dependencies=[
Depends(auth_required)
])
async def subscriptions_for_user_by_id(user_id: str):
try:
subscriptions = SubscriptionService.get_user_subscriptions(user_id)
@@ -98,8 +97,10 @@ Response:
"subscriptions": []
}
"""
@subscription_router.get("/subscriptions")
@auth_required()
@subscription_router.get("/subscriptions",
dependencies=[
Depends(auth_required)
])
async def subscriptions_for_current_user(Authorize: AuthJWT = Depends()):
try:
username = Authorize.get_jwt_subject()
+5 -9
View File
@@ -42,7 +42,7 @@ Response:
@user_router.post("")
async def create_user(user_data: UserModel, Authorize: AuthJWT = Depends()):
try:
auth_required()
role_required(["admin", "dev", "platform"])
new_user = await UserService.create_user(user_data)
return JSONResponse(content={"message": "User created successfully"}, status_code=201)
@@ -67,8 +67,7 @@ Response:
}
"""
@user_router.put("/{user_id}")
@auth_required()
@whitelist_check()
@role_required(["admin", "dev", "platform"])
async def update_user(user_id: str, request: Request):
try:
@@ -91,8 +90,7 @@ Response:
}
"""
@user_router.put("/{user_id}/update-password")
@auth_required()
@whitelist_check()
@role_required(["admin", "dev", "platform"])
async def update_user_password(user_id: str, request: Request):
try:
@@ -123,8 +121,7 @@ Response:
}
"""
@user_router.get("/{username}")
@auth_required()
@whitelist_check()
@role_required(["admin", "dev", "platform"])
async def get_user_by_username(username: str):
try:
@@ -148,8 +145,7 @@ Response:
}
"""
@user_router.get("/email/{email}")
@auth_required()
@whitelist_check()
@role_required(["admin", "dev", "platform"])
async def get_user_by_email(email: str):
try:
Binary file not shown.
+10 -4
View File
@@ -12,6 +12,7 @@ import logging
from utils.database import db
from services.cache import pygate_cache
import uuid
class GatewayService:
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
@@ -25,9 +26,12 @@ class GatewayService:
"""
External gateway.
"""
request_id = uuid.uuid4()
start_time = time.time() * 1000
gateway_end_time = None
backend_start_time = None
response = None
GatewayService.logger.info(f"REST | {request_id} | Resource: {request.path}")
try:
match = re.match(r"([^/]+/v\d+)", request.path)
api_name_version = '/' + match.group(1) if match else ""
@@ -66,12 +70,14 @@ class GatewayService:
return JSONResponse("Endpoint does not exists in backend service", status_code=404)
return JSONResponse(content=response_content, status_code=response.status_code)
except Exception as e:
GatewayService.logger.error(f"Error in rest_gateway: {str(e)}")
GatewayService.logger.error(f"REST | {request_id} | Error in rest_gateway: {str(e)}")
return {"error": str(e)}
finally:
end_time = time.time() * 1000
response.headers['X-Request-Id'] = request_id
if gateway_end_time:
GatewayService.logger.info(f"Gateway Time: {gateway_end_time - start_time}ms")
GatewayService.logger.info(f"REST | {request_id} | Gateway Time: {gateway_end_time - start_time}ms")
if backend_start_time:
GatewayService.logger.info(f"Backend Time: {end_time - backend_start_time}ms")
GatewayService.logger.info(f"Total Time: {end_time - start_time}ms")
GatewayService.logger.info(f"REST | {request_id} | Backend Time: {end_time - backend_start_time}ms")
GatewayService.logger.info(f"REST | {request_id} | Total Time: {end_time - start_time}ms")
GatewayService.logger.info(f"REST | {request_id} | Status Code: {response.status_code}")
+38 -30
View File
@@ -9,8 +9,14 @@ from utils.cache import cache_manager
from services.cache import pygate_cache
from services.api_service import ApiService
import logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger("pygate.gateway")
class SubscriptionService:
subscriptions_collection = db.subscriptions
api_collection = db.apis
@staticmethod
@cache_manager.cached(ttl=300)
@@ -18,7 +24,10 @@ class SubscriptionService:
"""
Check if an API exists.
"""
return pygate_cache.get_cache('api_cache', f"{api_name}/{api_version}") or ApiService.api_collection.find_one({'api_name': api_name, 'api_version': api_version})
api = pygate_cache.get_cache('api_cache', f"{api_name}/{api_version}") or ApiService.api_collection.find_one({'api_name': api_name, 'api_version': api_version})
if api and '_id' in api:
del api['_id']
return api
@staticmethod
@cache_manager.cached(ttl=300)
@@ -38,44 +47,43 @@ class SubscriptionService:
"""
Subscribe to an API.
"""
username = data.get('username')
api_name = data.get('api_name')
api_version = data.get('api_version')
if not await SubscriptionService.api_exists(api_name, api_version):
if not SubscriptionService.api_exists(data.api_name, data.api_version):
raise ValueError("API does not exist")
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', data.username) or SubscriptionService.subscriptions_collection.find_one({'username': data.username})
if user_subscriptions is None:
user_subscriptions = {
'username': username,
'apis': []
'username': data.username,
'apis': [f"{data.api_name}/{data.api_version}"]
}
SubscriptionService.subscriptions_collection.insert_one(user_subscriptions)
elif 'apis' in user_subscriptions and f"{api_name}/{api_version}" in user_subscriptions['apis']:
elif 'apis' in user_subscriptions and f"{data.api_name}/{data.api_version}" in user_subscriptions['apis']:
raise ValueError("User is already subscribed to the API")
updated_subscriptions = SubscriptionService.subscriptions_collection.update_one(
{'username': username},
{'$push': {'apis': f"{api_name}/{api_version}"}}
)
pygate_cache.set_cache('user_subscription_cache', username, updated_subscriptions)
else:
SubscriptionService.subscriptions_collection.update_one(
{'username': data.username},
{'$push': {'apis': f"{data.api_name}/{data.api_version}"}}
)
user_subscriptions = SubscriptionService.subscriptions_collection.find_one({'username': data.username})
if user_subscriptions and '_id' in user_subscriptions:
del user_subscriptions['_id']
pygate_cache.set_cache('user_subscription_cache', data.username, user_subscriptions)
@staticmethod
async def unsubscribe(data):
"""
Unsubscribe from an API.
"""
username = data.get('username')
api_name = data.get('api_name')
api_version = data.get('api_version')
if not SubscriptionService.api_exists(api_name, api_version):
if not await SubscriptionService.api_exists(data.api_name, data.api_version):
raise ValueError("API does not exist")
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
if not user_subscriptions.contains(
f"""{api_name}/{api_version}"""):
raise ValueError("User is already not subscribed to the API")
updated_subscriptions = SubscriptionService.subscriptions_collection.update_one(
{'username': username},
{'$pull': {'apis': f"""{api_name}/{api_version}"""}})
pygate_cache.get_cache('user_subscription_cache', username, updated_subscriptions)
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', data.username) or SubscriptionService.subscriptions_collection.find_one({'username': data.username})
if not user_subscriptions or f"{data.api_name}/{data.api_version}" not in user_subscriptions.get('apis', []):
raise ValueError("User is not subscribed to the API")
user_subscriptions['apis'].remove(f"{data.api_name}/{data.api_version}")
update_result = SubscriptionService.subscriptions_collection.update_one(
{'username': data.username},
{'$set': {'apis': user_subscriptions.get('apis', [])}}
)
user_subscriptions = SubscriptionService.subscriptions_collection.find_one({'username': data.username})
if user_subscriptions and '_id' in user_subscriptions:
del user_subscriptions['_id']
pygate_cache.set_cache('user_subscription_cache', data.username, user_subscriptions)
+177 -107
View File
@@ -1,19 +1,27 @@
import json
import random
import unittest
import time
import requests
import pytest
import asyncio
class TestPygate(unittest.TestCase):
class TestPygate:
base_url = "http://localhost:3002"
token = None
api_name = None
endpoint_path = None
group_name = None
role_name = None
username = None
email = None
password = None
@classmethod
def setUpClass(cls):
@staticmethod
def getAccessCookies():
return {"access_token_cookie": TestPygate.token}
@pytest.fixture(scope="class", autouse=True)
def setup_class(cls):
for _ in range(5):
try:
response = requests.get(f"{cls.base_url}/platform/status")
@@ -27,150 +35,212 @@ class TestPygate(unittest.TestCase):
print("Failed to connect to the server after multiple attempts")
raise RuntimeError("pygate is not running")
def test_01_auth_calls(self):
@pytest.mark.asyncio
@pytest.mark.order(1)
async def test_auth_calls(self):
response = requests.post(f"{self.base_url}/platform/authorization",
json={"email": "admin@pygate.org", "password": "password1"})
self.assertEqual(response.status_code, 200)
json={"email": "admin@pygate.org", "password": "password1"})
assert response.status_code == 200
TestPygate.token = response.json().get('access_token')
self.assertIsNotNone(TestPygate.token)
assert TestPygate.token is not None
response = requests.get(f"{self.base_url}/platform/authorization/status",
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
def test_02_create_user(self):
if not TestPygate.token:
self.fail("Auth token is missing")
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(2)
async def test_create_user(self):
TestPygate.username = "newuser" + str(time.time())
TestPygate.email = "newuser" + str(time.time()) + "@pygate.org"
TestPygate.password = "newpass"
response = requests.post(f"{self.base_url}/platform/user",
headers={"Authorization": f"Bearer {TestPygate.token}"},
cookies=TestPygate.getAccessCookies(),
json={
"username": "newuser" + str(time.time()),
"email": "newuser" + str(time.time()) + "@pygate.org",
"password": "newpass",
"role": "user"
"username": TestPygate.username,
"email": TestPygate.email,
"password": TestPygate.password,
"role": "admin",
"groups": ["ALL"]
})
self.assertEqual(response.status_code, 201)
assert response.status_code == 201
def test_03_onboard_api(self):
"""Step 3: Onboard an API"""
if not TestPygate.token:
self.fail("Auth token is missing")
@pytest.mark.asyncio
@pytest.mark.order(3)
async def test_create_group(self):
TestPygate.group_name = "testgroup" + str(time.time())
response = requests.post(f"{self.base_url}/platform/group",
cookies=TestPygate.getAccessCookies(),
json={
"group_name": TestPygate.group_name,
"group_description": "Test group"
})
assert response.status_code == 201
@pytest.mark.asyncio
@pytest.mark.order(4)
async def test_get_groups(self):
response = requests.get(f"{self.base_url}/platform/group/all",
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(5)
async def test_get_group(self):
response = requests.get(f"{self.base_url}/platform/group/" + TestPygate.group_name,
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(6)
async def test_create_role(self):
TestPygate.role_name = "testrole" + str(time.time())
response = requests.post(f"{self.base_url}/platform/role",
cookies=TestPygate.getAccessCookies(),
json={
"role_name": TestPygate.role_name,
"role_description": "Test role",
"manage_users": False,
"manage_apis": False,
"manage_endpoints": False,
"manage_groups": False,
"manage_roles": False
})
assert response.status_code == 201
@pytest.mark.asyncio
@pytest.mark.order(7)
async def test_get_roles(self):
response = requests.get(f"{self.base_url}/platform/role/all",
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(8)
async def test_get_role(self):
response = requests.get(f"{self.base_url}/platform/role/" + TestPygate.role_name,
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(9)
async def test_onboard_api(self):
TestPygate.api_name = "test" + "".join(random.sample("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", 8))
response = requests.post(f"{self.base_url}/platform/api",
headers={"Authorization": f"Bearer {TestPygate.token}"},
cookies=TestPygate.getAccessCookies(),
json={
"api_name": TestPygate.api_name,
"api_version": "v1",
"api_description": "Test API",
"api_servers": ["https://fake-json-api.mock.beeceptor.com/"],
"api_allowed_roles": [TestPygate.role_name],
"api_allowed_groups": [TestPygate.group_name],
"api_type": "REST"
})
self.assertEqual(response.status_code, 201)
assert response.status_code == 201
def test_04_onboard_endpoint(self):
if not TestPygate.token:
self.fail("Auth token is missing")
@pytest.mark.asyncio
@pytest.mark.order(10)
async def test_onboard_endpoint(self):
TestPygate.endpoint_path = "/users"
response = requests.post(f"{self.base_url}/platform/endpoint",
headers={"Authorization": f"Bearer {TestPygate.token}"},
cookies=TestPygate.getAccessCookies(),
json={
"api_name": TestPygate.api_name,
"api_version": "v1",
"endpoint_uri": TestPygate.endpoint_path,
"endpoint_method": "GET"
})
self.assertEqual(response.status_code, 201)
assert response.status_code == 201
def test_05_gateway_call(self):
response = requests.get(f"{self.base_url}/api/rest/" + TestPygate.api_name + "/v1" + TestPygate.endpoint_path.replace("{userId}", "2"))
self.assertEqual(response.status_code, 200)
@pytest.mark.asyncio
@pytest.mark.order(11)
async def test_subscribe(self):
response = requests.post(f"{self.base_url}/platform/subscription/subscribe",
cookies=TestPygate.getAccessCookies(),
json={
"username": TestPygate.username,
"api_name": TestPygate.api_name,
"api_version": "v1"
})
assert response.status_code == 200
def test_06_get_api(self):
@pytest.mark.asyncio
@pytest.mark.order(12)
async def test_re_auth_calls(self):
response = requests.post(f"{self.base_url}/platform/authorization",
json={"email": TestPygate.email, "password": TestPygate.password})
assert response.status_code == 200
TestPygate.token = response.json().get('access_token')
assert TestPygate.token is not None
response = requests.get(f"{self.base_url}/platform/authorization/status",
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(13)
async def test_gateway_call(self):
response = requests.get(f"{self.base_url}/api/rest/" + TestPygate.api_name + "/v1" + TestPygate.endpoint_path.replace("{userId}", "2"),
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(14)
async def test_unsubscribe(self):
response = requests.post(f"{self.base_url}/platform/subscription/unsubscribe",
cookies=TestPygate.getAccessCookies(),
json={
"username": TestPygate.username,
"api_name": TestPygate.api_name,
"api_version": "v1"
})
assert response.status_code == 200
@pytest.mark.asyncio
@pytest.mark.order(15)
async def test_re_gateway_call(self):
response = requests.get(f"{self.base_url}/api/rest/" + TestPygate.api_name + "/v1" + TestPygate.endpoint_path.replace("{userId}", "2"),
cookies=TestPygate.getAccessCookies())
assert response.status_code == 403
@pytest.mark.asyncio
@pytest.mark.order(16)
async def test_get_api(self):
response = requests.get(f"{self.base_url}/platform/api/" + TestPygate.api_name + "/v1",
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
def test_07_get_all_apis(self):
@pytest.mark.asyncio
@pytest.mark.order(17)
async def test_get_all_apis(self):
response = requests.get(f"{self.base_url}/platform/api/all",
headers={"Authorization": f"Bearer {TestPygate.token}"},
cookies=TestPygate.getAccessCookies(),
params={"page": 1, "page_size": 10})
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
def test_08_api_endpoints(self):
@pytest.mark.asyncio
@pytest.mark.order(18)
async def test_api_endpoints(self):
response = requests.get(f"{self.base_url}/platform/endpoint/api/" + TestPygate.api_name + "/v1",
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
cookies=TestPygate.getAccessCookies())
assert response.status_code == 200
def test_09_create_group(self):
TestPygate.group_name = "testgroup" + str(time.time())
response = requests.post(f"{self.base_url}/platform/group",
headers={"Authorization": f"Bearer {TestPygate.token}"},
@pytest.mark.asyncio
@pytest.mark.order(19)
async def test_re_subscribe(self):
response = requests.post(f"{self.base_url}/platform/subscription/subscribe",
cookies=TestPygate.getAccessCookies(),
json={
"group_name": TestPygate.group_name,
"group_description": "Test group"
"username": TestPygate.username,
"api_name": TestPygate.api_name,
"api_version": "v1"
})
self.assertEqual(response.status_code, 201)
def test_10_get_groups(self):
response = requests.get(f"{self.base_url}/platform/group/all",
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
def test_11_get_group(self):
response = requests.get(f"{self.base_url}/platform/group/" + TestPygate.group_name,
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
def test_12_create_role(self):
TestPygate.role_name = "testrole" + str(time.time())
response = requests.post(f"{self.base_url}/platform/role",
headers={"Authorization": f"Bearer {TestPygate.token}"},
json={
"role_name": TestPygate.role_name,
"role_description": "Test role",
"manage_users": False,
"manage_apis": False,
"manage_endpoints": False,
"manage_groups": False,
"manage_roles": False
})
self.assertEqual(response.status_code, 201)
def test_13_get_roles(self):
response = requests.get(f"{self.base_url}/platform/role/all",
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
def test_14_get_role(self):
response = requests.get(f"{self.base_url}/platform/role/" + TestPygate.role_name,
headers={"Authorization": f"Bearer {TestPygate.token}"})
self.assertEqual(response.status_code, 200)
def suite():
test_suite = unittest.TestSuite()
test_suite.addTest(TestPygate("test_01_auth_calls"))
test_suite.addTest(TestPygate("test_02_create_user"))
test_suite.addTest(TestPygate("test_03_onboard_api"))
test_suite.addTest(TestPygate("test_04_onboard_endpoint"))
test_suite.addTest(TestPygate("test_05_gateway_call"))
test_suite.addTest(TestPygate("test_06_get_api"))
test_suite.addTest(TestPygate("test_07_get_all_apis"))
test_suite.addTest(TestPygate("test_08_api_endpoints"))
test_suite.addTest(TestPygate("test_09_create_group"))
test_suite.addTest(TestPygate("test_10_get_groups"))
test_suite.addTest(TestPygate("test_11_get_group"))
test_suite.addTest(TestPygate("test_12_create_role"))
test_suite.addTest(TestPygate("test_13_get_roles"))
test_suite.addTest(TestPygate("test_14_get_role"))
return test_suite
assert response.status_code == 200
if __name__ == '__main__':
runner = unittest.TextTestRunner()
runner.run(suite())
pytest.main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+2 -2
View File
@@ -1,13 +1,13 @@
import asyncio
from datetime import datetime, timedelta
import heapq
import asyncio
jwt_blacklist = {}
class TimedHeap:
def __init__(self):
def __init__(self, purge_after=timedelta(hours=1)):
self.heap = []
self.purge_after = purge_after
async def push(self, item):
expire_time = datetime.now() + self.purge_after
+17 -24
View File
@@ -4,34 +4,27 @@ Review the Apache License 2.0 for valid authorization of use
See https://github.com/pypeople-dev/pygate for more information
"""
from functools import wraps
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from utils.auth_blacklist import jwt_blacklist
def auth_required():
def decorator(func):
@wraps(func)
async def decorated_function(*args, Authorize: AuthJWT = Depends(), **kwargs):
try:
Authorize.jwt_required()
import logging
jwt_subject = Authorize.get_jwt_subject()
jwt_id = Authorize.get_raw_jwt()['jti']
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger("pygate.gateway")
if jwt_subject in jwt_blacklist and jwt_id in jwt_blacklist[jwt_subject]:
raise HTTPException(
status_code=401,
detail="Token has been revoked"
)
return await func(*args, Authorize=Authorize, **kwargs)
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=401,
detail="Could not validate credentials"
)
return decorated_function
return decorator
async def auth_required(Authorize: AuthJWT = Depends()):
try:
Authorize.jwt_required()
jwt_subject = Authorize.get_jwt_subject()
jwt_id = Authorize.get_raw_jwt()['jti']
if jwt_subject in jwt_blacklist and jwt_id in jwt_blacklist[jwt_subject]:
raise HTTPException(
status_code=401,
detail="Token has been revoked"
)
except Exception as e:
logger.error("Unauthorized access")
raise HTTPException(status_code=401, detail="Unauthorized")
return Authorize
+30 -14
View File
@@ -7,20 +7,36 @@ See https://github.com/pypeople-dev/pygate for more information
from functools import wraps
from fastapi import HTTPException, Depends, Request
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import MissingTokenError
from services.cache import pygate_cache
from services.subscription_service import SubscriptionService
def subscription_required():
def decorator(f):
@wraps(f)
async def decorated_function(*args, request: Request, Authorize: AuthJWT = Depends(), **kwargs):
Authorize.jwt_required('cookies')
username = Authorize.get_jwt_subject()
subscriptions = await pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
path = kwargs.get('path', '')
if not subscriptions or not subscriptions.get('apis') or path not in subscriptions.get('apis'):
raise HTTPException(status_code=403, detail="You are not subscribed to this resource")
return await f(*args, **kwargs)
return decorated_function
return decorator
import logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger("pygate.gateway")
def subscription_required(request: Request, Authorize: AuthJWT = Depends()):
try:
Authorize.jwt_required()
username = Authorize.get_jwt_subject()
full_path = request.url.path
prefix = "/api/rest/"
if full_path.startswith(prefix):
path = full_path[len(prefix):]
else:
path = full_path
api_and_version = '/'.join(path.split('/')[:2])
user_subscriptions = pygate_cache.get_cache('user_subscription_cache', username) or SubscriptionService.subscriptions_collection.find_one({'username': username})
subscriptions = user_subscriptions.get('apis') if user_subscriptions and 'apis' in user_subscriptions else None
if not subscriptions or api_and_version not in subscriptions:
logger.info(f"User {username} attempted access to {api_and_version}")
raise HTTPException(status_code=403, detail="You are not subscribed to this resource")
except MissingTokenError:
raise HTTPException(status_code=401, detail="Missing token")
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Unexpected error: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return Authorize