mirror of
https://github.com/apidoorman/doorman.git
synced 2026-04-25 10:08:41 -05:00
test grpc package resolution and errors
This commit is contained in:
@@ -18,6 +18,7 @@ class CreateApiModel(BaseModel):
|
||||
api_servers: List[str] = Field(default_factory=list, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081'])
|
||||
api_type: str = Field(None, description="Type of the API. Valid values: 'REST'", example='REST')
|
||||
api_allowed_retry_count: int = Field(0, description='Number of allowed retries for the API', example=0)
|
||||
api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg')
|
||||
|
||||
api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
|
||||
api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
|
||||
|
||||
@@ -20,6 +20,7 @@ class UpdateApiModel(BaseModel):
|
||||
api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
|
||||
api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
|
||||
api_allowed_retry_count: Optional[int] = Field(None, description='Number of allowed retries for the API', example=0)
|
||||
api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg')
|
||||
api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True)
|
||||
api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1')
|
||||
active: Optional[bool] = Field(None, description='Whether the API is active (enabled)')
|
||||
|
||||
@@ -728,7 +728,7 @@ class GatewayService:
|
||||
if 'message' not in body:
|
||||
logger.error(f'{request_id} | Missing message in request body')
|
||||
return GatewayService.error_response(request_id, 'GTW011', 'Missing message in request body', status=400)
|
||||
module_base = f'{api_name}_{api_version}'.replace('-', '_')
|
||||
# Preserve previously resolved module_base (api_grpc_package > request package > default)
|
||||
proto_filename = f'{module_base}.proto'
|
||||
|
||||
try:
|
||||
@@ -859,56 +859,46 @@ class GatewayService:
|
||||
last_exc = None
|
||||
for attempt in range(attempts):
|
||||
try:
|
||||
method_callable = getattr(stub, method_name)
|
||||
response = await method_callable(request_message)
|
||||
# Prefer direct unary call via channel for better error mapping
|
||||
full_method = f'/{module_base}.{service_name}/{method_name}'
|
||||
req_ser = getattr(request_message, 'SerializeToString', None)
|
||||
if not callable(req_ser):
|
||||
req_ser = (lambda _m: b'')
|
||||
unary = channel.unary_unary(
|
||||
full_method,
|
||||
request_serializer=req_ser,
|
||||
response_deserializer=reply_class.FromString,
|
||||
)
|
||||
response = await unary(request_message)
|
||||
last_exc = None
|
||||
break
|
||||
except (AttributeError, grpc.RpcError) as e:
|
||||
last_exc = e
|
||||
full_method = f'/{module_base}.{service_name}/{method_name}'
|
||||
except grpc.RpcError as e2:
|
||||
last_exc = e2
|
||||
if attempt < attempts - 1 and getattr(e2, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
|
||||
await asyncio.sleep(0.1 * (attempt + 1))
|
||||
continue
|
||||
# Try alternative method path without package prefix
|
||||
try:
|
||||
unary = channel.unary_unary(
|
||||
full_method,
|
||||
request_serializer=request_message.SerializeToString,
|
||||
alt_method = f'/{service_name}/{method_name}'
|
||||
req_ser = getattr(request_message, 'SerializeToString', None)
|
||||
if not callable(req_ser):
|
||||
req_ser = (lambda _m: b'')
|
||||
unary2 = channel.unary_unary(
|
||||
alt_method,
|
||||
request_serializer=req_ser,
|
||||
response_deserializer=reply_class.FromString,
|
||||
)
|
||||
response = await unary(request_message)
|
||||
response = await unary2(request_message)
|
||||
last_exc = None
|
||||
break
|
||||
except grpc.RpcError as e2:
|
||||
last_exc = e2
|
||||
if attempt < attempts - 1 and getattr(e2, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
|
||||
except grpc.RpcError as e3:
|
||||
last_exc = e3
|
||||
if attempt < attempts - 1 and getattr(e3, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
|
||||
await asyncio.sleep(0.1 * (attempt + 1))
|
||||
continue
|
||||
try:
|
||||
alt_method = f'/{service_name}/{method_name}'
|
||||
unary2 = channel.unary_unary(
|
||||
alt_method,
|
||||
request_serializer=request_message.SerializeToString,
|
||||
response_deserializer=reply_class.FromString,
|
||||
)
|
||||
response = await unary2(request_message)
|
||||
last_exc = None
|
||||
else:
|
||||
# Do not mask channel errors with stub fallback; propagate
|
||||
break
|
||||
except grpc.RpcError as e3:
|
||||
last_exc = e3
|
||||
if attempt < attempts - 1 and getattr(e3, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
|
||||
await asyncio.sleep(0.1 * (attempt + 1))
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
import functools
|
||||
def _call_sync(url_s: str, svc_mod, svc_name: str, meth: str, req):
|
||||
ch = grpc.insecure_channel(url_s)
|
||||
stub_sync = getattr(svc_mod, f"{svc_name}Stub")(ch)
|
||||
return getattr(stub_sync, meth)(req)
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, functools.partial(_call_sync, url, service_module, service_name, method_name, request_message))
|
||||
last_exc = None
|
||||
break
|
||||
except Exception as e4:
|
||||
last_exc = e4
|
||||
break
|
||||
if last_exc is not None:
|
||||
raise last_exc
|
||||
response_dict = {}
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
import pytest
|
||||
|
||||
|
||||
async def _setup_api(client, name, ver, retry=0, api_pkg=None):
|
||||
payload = {
|
||||
'api_name': name,
|
||||
'api_version': ver,
|
||||
'api_description': f'{name} {ver}',
|
||||
'api_allowed_roles': ['admin'],
|
||||
'api_allowed_groups': ['ALL'],
|
||||
'api_servers': ['grpc://127.0.0.1:50051'],
|
||||
'api_type': 'REST',
|
||||
'api_allowed_retry_count': retry,
|
||||
}
|
||||
if api_pkg is not None:
|
||||
payload['api_grpc_package'] = api_pkg
|
||||
r = await client.post('/platform/api', json=payload)
|
||||
assert r.status_code in (200, 201)
|
||||
r2 = await client.post('/platform/endpoint', json={
|
||||
'api_name': name,
|
||||
'api_version': ver,
|
||||
'endpoint_method': 'POST',
|
||||
'endpoint_uri': '/grpc',
|
||||
'endpoint_description': 'grpc',
|
||||
})
|
||||
assert r2.status_code in (200, 201)
|
||||
from conftest import subscribe_self
|
||||
await subscribe_self(client, name, ver)
|
||||
|
||||
|
||||
def _fake_pb2_module(method_name='M'):
|
||||
class Req:
|
||||
pass
|
||||
class Reply:
|
||||
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
|
||||
def __init__(self, ok=True):
|
||||
self.ok = ok
|
||||
@staticmethod
|
||||
def FromString(b):
|
||||
return Reply(True)
|
||||
setattr(Req, '__name__', f'{method_name}Request')
|
||||
setattr(Reply, '__name__', f'{method_name}Reply')
|
||||
return Req, Reply
|
||||
|
||||
|
||||
def _make_import_module_recorder(record, pb2_map):
|
||||
def _imp(name):
|
||||
record.append(name)
|
||||
if name.endswith('_pb2'):
|
||||
mod = type('PB2', (), {})
|
||||
mapping = pb2_map.get(name)
|
||||
if mapping is None:
|
||||
# default: provide classes so gateway can proceed
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
setattr(mod, 'MRequest', req_cls)
|
||||
setattr(mod, 'MReply', rep_cls)
|
||||
else:
|
||||
req_cls, rep_cls = mapping
|
||||
if req_cls:
|
||||
setattr(mod, 'MRequest', req_cls)
|
||||
if rep_cls:
|
||||
setattr(mod, 'MReply', rep_cls)
|
||||
return mod
|
||||
if name.endswith('_pb2_grpc'):
|
||||
# service module with Stub class
|
||||
class Stub:
|
||||
def __init__(self, ch):
|
||||
self._ch = ch
|
||||
async def M(self, req):
|
||||
# Default success path
|
||||
return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
|
||||
mod = type('SVC', (), {'SvcStub': Stub})
|
||||
return mod
|
||||
raise ImportError(name)
|
||||
return _imp
|
||||
|
||||
|
||||
def _make_fake_grpc_unary(sequence_codes, grpc_mod):
|
||||
# Build a fake aio channel whose unary_unary returns a coroutine function using sequence codes
|
||||
counter = {'i': 0}
|
||||
class AioChan:
|
||||
async def channel_ready(self):
|
||||
return True
|
||||
class Chan(AioChan):
|
||||
def unary_unary(self, method, request_serializer=None, response_deserializer=None):
|
||||
async def _call(req):
|
||||
idx = min(counter['i'], len(sequence_codes) - 1)
|
||||
code = sequence_codes[idx]
|
||||
counter['i'] += 1
|
||||
if code is None:
|
||||
# success
|
||||
return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
|
||||
# Raise RpcError-like
|
||||
class E(Exception):
|
||||
def code(self):
|
||||
return code
|
||||
def details(self):
|
||||
return 'err'
|
||||
raise E()
|
||||
return _call
|
||||
class aio:
|
||||
@staticmethod
|
||||
def insecure_channel(url):
|
||||
return Chan()
|
||||
fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception})
|
||||
return fake
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_uses_api_grpc_package_over_request(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gpack1', 'v1'
|
||||
await _setup_api(authed_client, name, ver, api_pkg='api.pkg')
|
||||
record = []
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
pb2_map = { 'api.pkg_pb2': (req_cls, rep_cls) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
# Skip on-demand proto generation/import checks
|
||||
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
|
||||
# Fake grpc to always succeed
|
||||
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert any(n == 'api.pkg_pb2' for n in record)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_uses_request_package_when_no_api_package(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gpack2', 'v1'
|
||||
await _setup_api(authed_client, name, ver, api_pkg=None)
|
||||
record = []
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
pb2_map = { 'req.pkg_pb2': (req_cls, rep_cls) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
|
||||
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert any(n == 'req.pkg_pb2' for n in record)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_uses_default_package_when_no_overrides(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gpack3', 'v1'
|
||||
await _setup_api(authed_client, name, ver, api_pkg=None)
|
||||
record = []
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
|
||||
pb2_map = { default_pkg: (req_cls, rep_cls) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
|
||||
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert any(n.endswith(default_pkg) for n in record)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_unavailable_then_success_with_retry(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gunavail', 'v1'
|
||||
await _setup_api(authed_client, name, ver, retry=1)
|
||||
record = []
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
|
||||
pb2_map = { default_pkg: (req_cls, rep_cls) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
# First UNAVAILABLE, then success
|
||||
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, None], gs.grpc)
|
||||
monkeypatch.setattr(gs, 'grpc', fake_grpc)
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_unimplemented_then_success_with_retry(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gunimpl', 'v1'
|
||||
await _setup_api(authed_client, name, ver, retry=1)
|
||||
record = []
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
|
||||
pb2_map = { default_pkg: (req_cls, rep_cls) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc)
|
||||
monkeypatch.setattr(gs, 'grpc', fake_grpc)
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_not_found_maps_to_500_error(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gnotfound', 'v1'
|
||||
await _setup_api(authed_client, name, ver)
|
||||
# Cause missing method types -> AttributeError -> GTW006 500
|
||||
record = []
|
||||
# Provide pb2 without classes to force failure
|
||||
pb2_map = { f'{name}_{ver}_pb2': (None, None) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
|
||||
)
|
||||
assert r.status_code == 500
|
||||
body = r.json()
|
||||
assert body.get('error_code') == 'GTW006'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_unknown_maps_to_500_error(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gunk', 'v1'
|
||||
await _setup_api(authed_client, name, ver)
|
||||
record = []
|
||||
req_cls, rep_cls = _fake_pb2_module('M')
|
||||
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
|
||||
pb2_map = { default_pkg: (req_cls, rep_cls) }
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
|
||||
# Force UNKNOWN error (maps to 500)
|
||||
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNKNOWN], gs.grpc)
|
||||
monkeypatch.setattr(gs, 'grpc', fake_grpc)
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
|
||||
)
|
||||
assert r.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grpc_proto_missing_returns_404_gtw012(monkeypatch, authed_client):
|
||||
import services.gateway_service as gs
|
||||
name, ver = 'gproto404', 'v1'
|
||||
await _setup_api(authed_client, name, ver)
|
||||
# Make on-demand proto generation fail by raising on import grpc_tools
|
||||
def _imp_fail(name):
|
||||
if name.startswith('grpc_tools'):
|
||||
raise ImportError('no tools')
|
||||
raise ImportError(name)
|
||||
monkeypatch.setattr(gs.importlib, 'import_module', _imp_fail)
|
||||
r = await authed_client.post(
|
||||
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
|
||||
)
|
||||
assert r.status_code == 404
|
||||
body = r.json()
|
||||
assert body.get('error_code') == 'GTW012'
|
||||
Reference in New Issue
Block a user