mirror of
https://github.com/apidoorman/doorman.git
synced 2026-02-13 04:58:47 -06:00
test rest methods and 405
This commit is contained in:
@@ -191,7 +191,8 @@ async def gateway(request: Request, path: str):
|
||||
endpoints = await api_util.get_api_endpoints(resolved_api.get('api_id'))
|
||||
import re as _re
|
||||
regex_pattern = _re.compile(r'\{[^/]+\}')
|
||||
composite = request.method + endpoint_uri
|
||||
method_to_match = 'GET' if str(request.method).upper() == 'HEAD' else request.method
|
||||
composite = method_to_match + endpoint_uri
|
||||
if not any(_re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in (endpoints or [])):
|
||||
return process_response(ResponseModel(
|
||||
status_code=404,
|
||||
@@ -263,6 +264,10 @@ async def rest_patch(request: Request, path: str):
|
||||
async def rest_delete(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
@gateway_router.head('/rest/{path:path}', description='REST gateway endpoint (HEAD)', response_model=ResponseModel, operation_id='rest_head')
|
||||
async def rest_head(request: Request, path: str):
|
||||
return await gateway(request, path)
|
||||
|
||||
"""
|
||||
Endpoint
|
||||
|
||||
@@ -297,9 +302,28 @@ async def rest_preflight(request: Request, path: str):
|
||||
name_ver = f'/{parts[0]}/{parts[1]}'
|
||||
api_key = _cache.get_cache('api_id_cache', name_ver)
|
||||
api = await _api_util.get_api(api_key, name_ver)
|
||||
endpoint_uri = '/' + '/'.join(parts[2:]) if len(parts) > 2 else '/'
|
||||
if not api:
|
||||
from fastapi.responses import Response as StarletteResponse
|
||||
return StarletteResponse(status_code=204, headers={'request_id': request_id})
|
||||
# If endpoint is not registered for any method, return 405
|
||||
try:
|
||||
endpoints = await _api_util.get_api_endpoints(api.get('api_id'))
|
||||
import re as _re
|
||||
regex_pattern = _re.compile(r'\{[^/]+\}')
|
||||
# Try matching against any method
|
||||
methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD']
|
||||
exists = False
|
||||
for ep in endpoints or []:
|
||||
pat = regex_pattern.sub(r'([^/]+)', ep)
|
||||
if any(_re.fullmatch(pat, m + endpoint_uri) for m in methods):
|
||||
exists = True
|
||||
break
|
||||
if not exists:
|
||||
from fastapi.responses import Response as StarletteResponse
|
||||
return StarletteResponse(status_code=405, headers={'request_id': request_id})
|
||||
except Exception:
|
||||
pass
|
||||
origin = request.headers.get('origin') or request.headers.get('Origin')
|
||||
req_method = request.headers.get('access-control-request-method') or request.headers.get('Access-Control-Request-Method')
|
||||
req_headers = request.headers.get('access-control-request-headers') or request.headers.get('Access-Control-Request-Headers')
|
||||
|
||||
@@ -165,7 +165,9 @@ class GatewayService:
|
||||
if not endpoints:
|
||||
return GatewayService.error_response(request_id, 'GTW002', 'No endpoints found for the requested API')
|
||||
regex_pattern = re.compile(r'\{[^/]+\}')
|
||||
composite = request.method + '/' + endpoint_uri
|
||||
# Treat HEAD like GET for endpoint registration matching
|
||||
match_method = 'GET' if str(request.method).upper() == 'HEAD' else request.method
|
||||
composite = match_method + '/' + endpoint_uri
|
||||
if not any(re.fullmatch(regex_pattern.sub(r'([^/]+)', ep), composite) for ep in endpoints):
|
||||
logger.error(f'{endpoints} | REST gateway failed with code GTW003')
|
||||
return GatewayService.error_response(request_id, 'GTW003', 'Endpoint does not exist for the requested API')
|
||||
@@ -223,7 +225,8 @@ class GatewayService:
|
||||
pass
|
||||
|
||||
try:
|
||||
endpoint_doc = await api_util.get_endpoint(api, method, '/' + endpoint_uri.lstrip('/')) if api else None
|
||||
lookup_method = 'GET' if str(method).upper() == 'HEAD' else method
|
||||
endpoint_doc = await api_util.get_endpoint(api, lookup_method, '/' + endpoint_uri.lstrip('/')) if api else None
|
||||
endpoint_id = endpoint_doc.get('endpoint_id') if endpoint_doc else None
|
||||
if endpoint_id:
|
||||
if 'JSON' in content_type:
|
||||
@@ -239,6 +242,8 @@ class GatewayService:
|
||||
try:
|
||||
if method == 'GET':
|
||||
http_response = await client.get(url, params=query_params, headers=headers)
|
||||
elif method == 'HEAD':
|
||||
http_response = await client.head(url, params=query_params, headers=headers)
|
||||
elif method in ('POST', 'PUT', 'DELETE', 'PATCH'):
|
||||
cl_header = request.headers.get('content-length') or request.headers.get('Content-Length')
|
||||
try:
|
||||
@@ -269,10 +274,13 @@ class GatewayService:
|
||||
await client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
if 'application/json' in http_response.headers.get('Content-Type', '').lower():
|
||||
response_content = http_response.json()
|
||||
if str(method).upper() == 'HEAD':
|
||||
response_content = ''
|
||||
else:
|
||||
response_content = http_response.text
|
||||
if 'application/json' in http_response.headers.get('Content-Type', '').lower():
|
||||
response_content = http_response.json()
|
||||
else:
|
||||
response_content = http_response.text
|
||||
backend_end_time = time.time() * 1000
|
||||
if http_response.status_code in [500, 502, 503, 504] and retry > 0:
|
||||
logger.error(f'{request_id} | REST gateway failed retrying')
|
||||
|
||||
115
backend-services/tests/test_rest_methods_and_405.py
Normal file
115
backend-services/tests/test_rest_methods_and_405.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeHTTPResponse:
|
||||
def __init__(self, status_code=200, json_body=None, text_body=None, headers=None):
|
||||
self.status_code = status_code
|
||||
self._json_body = json_body
|
||||
self.text = text_body if text_body is not None else ('' if json_body is not None else 'OK')
|
||||
base_headers = {'Content-Type': 'application/json'}
|
||||
if headers:
|
||||
base_headers.update(headers)
|
||||
self.headers = base_headers
|
||||
|
||||
def json(self):
|
||||
import json as _json
|
||||
if self._json_body is None:
|
||||
return _json.loads(self.text or '{}')
|
||||
return self._json_body
|
||||
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def patch(self, url, json=None, params=None, headers=None, content=None):
|
||||
body = json if json is not None else (content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content)
|
||||
return _FakeHTTPResponse(200, json_body={'method': 'PATCH', 'url': url, 'body': body, 'headers': headers or {}}, headers={'X-Upstream': 'yes'})
|
||||
|
||||
async def head(self, url, params=None, headers=None):
|
||||
# Simulate a successful HEAD when called
|
||||
return _FakeHTTPResponse(200, json_body=None, headers={'X-Upstream': 'yes'})
|
||||
|
||||
|
||||
async def _setup_api(client, name, ver, endpoint_method='GET', endpoint_uri='/p'):
|
||||
r = await client.post('/platform/api', json={
|
||||
'api_name': name,
|
||||
'api_version': ver,
|
||||
'api_description': f'{name} {ver}',
|
||||
'api_allowed_roles': ['admin'],
|
||||
'api_allowed_groups': ['ALL'],
|
||||
'api_servers': ['http://up.methods'],
|
||||
'api_type': 'REST',
|
||||
'api_allowed_retry_count': 0,
|
||||
})
|
||||
assert r.status_code in (200, 201)
|
||||
r2 = await client.post('/platform/endpoint', json={
|
||||
'api_name': name,
|
||||
'api_version': ver,
|
||||
'endpoint_method': endpoint_method,
|
||||
'endpoint_uri': endpoint_uri,
|
||||
'endpoint_description': endpoint_method.lower(),
|
||||
})
|
||||
assert r2.status_code in (200, 201)
|
||||
# Subscribe admin
|
||||
rme = await client.get('/platform/user/me')
|
||||
username = (rme.json().get('username') if rme.status_code == 200 else 'admin')
|
||||
sr = await client.post('/platform/subscription/subscribe', json={'username': username, 'api_name': name, 'api_version': ver})
|
||||
assert sr.status_code in (200, 201)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_head_supported_when_upstream_allows(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'headok', 'v1'
|
||||
await _setup_api(authed_client, name, ver, endpoint_method='GET', endpoint_uri='/p')
|
||||
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
|
||||
r = await authed_client.request('HEAD', f'/api/rest/{name}/{ver}/p')
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_patch_supported_when_registered(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'patchok', 'v1'
|
||||
await _setup_api(authed_client, name, ver, endpoint_method='PATCH', endpoint_uri='/edit')
|
||||
monkeypatch.setattr(gs.httpx, 'AsyncClient', _FakeAsyncClient)
|
||||
r = await authed_client.patch(f'/api/rest/{name}/{ver}/edit', json={'x': 1})
|
||||
assert r.status_code == 200
|
||||
j = r.json().get('response', r.json())
|
||||
assert j.get('method') == 'PATCH'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_options_unregistered_endpoint_returns_405(authed_client):
|
||||
# Create API without registering the specific endpoint
|
||||
name, ver = 'optunreg', 'v1'
|
||||
r = await authed_client.post('/platform/api', json={
|
||||
'api_name': name,
|
||||
'api_version': ver,
|
||||
'api_description': f'{name} {ver}',
|
||||
'api_allowed_roles': ['admin'],
|
||||
'api_allowed_groups': ['ALL'],
|
||||
'api_servers': ['http://up.methods'],
|
||||
'api_type': 'REST',
|
||||
'api_allowed_retry_count': 0,
|
||||
})
|
||||
assert r.status_code in (200, 201)
|
||||
# OPTIONS for unregistered endpoint currently yields 204 due to preflight handler
|
||||
resp = await authed_client.options(f'/api/rest/{name}/{ver}/not-made')
|
||||
assert resp.status_code == 405
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_unsupported_method_returns_405(authed_client):
|
||||
# Register a GET endpoint but call TRACE which is unsupported
|
||||
name, ver = 'unsup', 'v1'
|
||||
await _setup_api(authed_client, name, ver, endpoint_method='GET', endpoint_uri='/p')
|
||||
r = await authed_client.request('TRACE', f'/api/rest/{name}/{ver}/p')
|
||||
assert r.status_code == 405
|
||||
Reference in New Issue
Block a user