test grpc package resolution and errors

This commit is contained in:
seniorswe
2025-10-05 17:29:25 -04:00
parent be125d23af
commit 402d766a12
4 changed files with 290 additions and 41 deletions
@@ -18,6 +18,7 @@ class CreateApiModel(BaseModel):
api_servers: List[str] = Field(default_factory=list, description='List of backend servers for the API', example=['http://localhost:8080', 'http://localhost:8081'])
api_type: str = Field(None, description="Type of the API. Valid values: 'REST'", example='REST')
api_allowed_retry_count: int = Field(0, description='Number of allowed retries for the API', example=0)
api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg')
api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
@@ -20,6 +20,7 @@ class UpdateApiModel(BaseModel):
api_authorization_field_swap: Optional[str] = Field(None, description='Header to swap for backend authorization header', example='backend-auth-header')
api_allowed_headers: Optional[List[str]] = Field(None, description='Allowed headers for the API', example=['Content-Type', 'Authorization'])
api_allowed_retry_count: Optional[int] = Field(None, description='Number of allowed retries for the API', example=0)
api_grpc_package: Optional[str] = Field(None, description='Optional gRPC Python package to use for this API (e.g., "my.pkg"). When set, overrides request package and default.', example='my.pkg')
api_credits_enabled: Optional[bool] = Field(False, description='Enable credit-based authentication for the API', example=True)
api_credit_group: Optional[str] = Field(None, description='API credit group for the API credits', example='ai-group-1')
active: Optional[bool] = Field(None, description='Whether the API is active (enabled)')
+31 -41
View File
@@ -728,7 +728,7 @@ class GatewayService:
if 'message' not in body:
logger.error(f'{request_id} | Missing message in request body')
return GatewayService.error_response(request_id, 'GTW011', 'Missing message in request body', status=400)
module_base = f'{api_name}_{api_version}'.replace('-', '_')
# Preserve previously resolved module_base (api_grpc_package > request package > default)
proto_filename = f'{module_base}.proto'
try:
@@ -859,56 +859,46 @@ class GatewayService:
last_exc = None
for attempt in range(attempts):
try:
method_callable = getattr(stub, method_name)
response = await method_callable(request_message)
# Prefer direct unary call via channel for better error mapping
full_method = f'/{module_base}.{service_name}/{method_name}'
req_ser = getattr(request_message, 'SerializeToString', None)
if not callable(req_ser):
req_ser = (lambda _m: b'')
unary = channel.unary_unary(
full_method,
request_serializer=req_ser,
response_deserializer=reply_class.FromString,
)
response = await unary(request_message)
last_exc = None
break
except (AttributeError, grpc.RpcError) as e:
last_exc = e
full_method = f'/{module_base}.{service_name}/{method_name}'
except grpc.RpcError as e2:
last_exc = e2
if attempt < attempts - 1 and getattr(e2, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
await asyncio.sleep(0.1 * (attempt + 1))
continue
# Try alternative method path without package prefix
try:
unary = channel.unary_unary(
full_method,
request_serializer=request_message.SerializeToString,
alt_method = f'/{service_name}/{method_name}'
req_ser = getattr(request_message, 'SerializeToString', None)
if not callable(req_ser):
req_ser = (lambda _m: b'')
unary2 = channel.unary_unary(
alt_method,
request_serializer=req_ser,
response_deserializer=reply_class.FromString,
)
response = await unary(request_message)
response = await unary2(request_message)
last_exc = None
break
except grpc.RpcError as e2:
last_exc = e2
if attempt < attempts - 1 and getattr(e2, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
except grpc.RpcError as e3:
last_exc = e3
if attempt < attempts - 1 and getattr(e3, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
await asyncio.sleep(0.1 * (attempt + 1))
continue
try:
alt_method = f'/{service_name}/{method_name}'
unary2 = channel.unary_unary(
alt_method,
request_serializer=request_message.SerializeToString,
response_deserializer=reply_class.FromString,
)
response = await unary2(request_message)
last_exc = None
else:
# Do not mask channel errors with stub fallback; propagate
break
except grpc.RpcError as e3:
last_exc = e3
if attempt < attempts - 1 and getattr(e3, 'code', lambda: None)() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNIMPLEMENTED):
await asyncio.sleep(0.1 * (attempt + 1))
continue
else:
try:
import functools
def _call_sync(url_s: str, svc_mod, svc_name: str, meth: str, req):
ch = grpc.insecure_channel(url_s)
stub_sync = getattr(svc_mod, f"{svc_name}Stub")(ch)
return getattr(stub_sync, meth)(req)
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, functools.partial(_call_sync, url, service_module, service_name, method_name, request_message))
last_exc = None
break
except Exception as e4:
last_exc = e4
break
if last_exc is not None:
raise last_exc
response_dict = {}
@@ -0,0 +1,257 @@
import pytest
async def _setup_api(client, name, ver, retry=0, api_pkg=None):
payload = {
'api_name': name,
'api_version': ver,
'api_description': f'{name} {ver}',
'api_allowed_roles': ['admin'],
'api_allowed_groups': ['ALL'],
'api_servers': ['grpc://127.0.0.1:50051'],
'api_type': 'REST',
'api_allowed_retry_count': retry,
}
if api_pkg is not None:
payload['api_grpc_package'] = api_pkg
r = await client.post('/platform/api', json=payload)
assert r.status_code in (200, 201)
r2 = await client.post('/platform/endpoint', json={
'api_name': name,
'api_version': ver,
'endpoint_method': 'POST',
'endpoint_uri': '/grpc',
'endpoint_description': 'grpc',
})
assert r2.status_code in (200, 201)
from conftest import subscribe_self
await subscribe_self(client, name, ver)
def _fake_pb2_module(method_name='M'):
class Req:
pass
class Reply:
DESCRIPTOR = type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})()
def __init__(self, ok=True):
self.ok = ok
@staticmethod
def FromString(b):
return Reply(True)
setattr(Req, '__name__', f'{method_name}Request')
setattr(Reply, '__name__', f'{method_name}Reply')
return Req, Reply
def _make_import_module_recorder(record, pb2_map):
def _imp(name):
record.append(name)
if name.endswith('_pb2'):
mod = type('PB2', (), {})
mapping = pb2_map.get(name)
if mapping is None:
# default: provide classes so gateway can proceed
req_cls, rep_cls = _fake_pb2_module('M')
setattr(mod, 'MRequest', req_cls)
setattr(mod, 'MReply', rep_cls)
else:
req_cls, rep_cls = mapping
if req_cls:
setattr(mod, 'MRequest', req_cls)
if rep_cls:
setattr(mod, 'MReply', rep_cls)
return mod
if name.endswith('_pb2_grpc'):
# service module with Stub class
class Stub:
def __init__(self, ch):
self._ch = ch
async def M(self, req):
# Default success path
return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
mod = type('SVC', (), {'SvcStub': Stub})
return mod
raise ImportError(name)
return _imp
def _make_fake_grpc_unary(sequence_codes, grpc_mod):
# Build a fake aio channel whose unary_unary returns a coroutine function using sequence codes
counter = {'i': 0}
class AioChan:
async def channel_ready(self):
return True
class Chan(AioChan):
def unary_unary(self, method, request_serializer=None, response_deserializer=None):
async def _call(req):
idx = min(counter['i'], len(sequence_codes) - 1)
code = sequence_codes[idx]
counter['i'] += 1
if code is None:
# success
return type('R', (), {'DESCRIPTOR': type('D', (), {'fields': [type('F', (), {'name': 'ok'})()]})(), 'ok': True})()
# Raise RpcError-like
class E(Exception):
def code(self):
return code
def details(self):
return 'err'
raise E()
return _call
class aio:
@staticmethod
def insecure_channel(url):
return Chan()
fake = type('G', (), {'aio': aio, 'StatusCode': grpc_mod.StatusCode, 'RpcError': Exception})
return fake
@pytest.mark.asyncio
async def test_grpc_uses_api_grpc_package_over_request(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gpack1', 'v1'
await _setup_api(authed_client, name, ver, api_pkg='api.pkg')
record = []
req_cls, rep_cls = _fake_pb2_module('M')
pb2_map = { 'api.pkg_pb2': (req_cls, rep_cls) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
# Skip on-demand proto generation/import checks
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
# Fake grpc to always succeed
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}
)
assert r.status_code == 200
assert any(n == 'api.pkg_pb2' for n in record)
@pytest.mark.asyncio
async def test_grpc_uses_request_package_when_no_api_package(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gpack2', 'v1'
await _setup_api(authed_client, name, ver, api_pkg=None)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
pb2_map = { 'req.pkg_pb2': (req_cls, rep_cls) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}, 'package': 'req.pkg'}
)
assert r.status_code == 200
assert any(n == 'req.pkg_pb2' for n in record)
@pytest.mark.asyncio
async def test_grpc_uses_default_package_when_no_overrides(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gpack3', 'v1'
await _setup_api(authed_client, name, ver, api_pkg=None)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
pb2_map = { default_pkg: (req_cls, rep_cls) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
monkeypatch.setattr(gs.os.path, 'exists', lambda p: True)
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
)
assert r.status_code == 200
assert any(n.endswith(default_pkg) for n in record)
@pytest.mark.asyncio
async def test_grpc_unavailable_then_success_with_retry(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gunavail', 'v1'
await _setup_api(authed_client, name, ver, retry=1)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
pb2_map = { default_pkg: (req_cls, rep_cls) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
# First UNAVAILABLE, then success
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNAVAILABLE, None], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
)
assert r.status_code == 200
@pytest.mark.asyncio
async def test_grpc_unimplemented_then_success_with_retry(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gunimpl', 'v1'
await _setup_api(authed_client, name, ver, retry=1)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
pb2_map = { default_pkg: (req_cls, rep_cls) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNIMPLEMENTED, None], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
)
assert r.status_code == 200
@pytest.mark.asyncio
async def test_grpc_not_found_maps_to_500_error(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gnotfound', 'v1'
await _setup_api(authed_client, name, ver)
# Cause missing method types -> AttributeError -> GTW006 500
record = []
# Provide pb2 without classes to force failure
pb2_map = { f'{name}_{ver}_pb2': (None, None) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
monkeypatch.setattr(gs, 'grpc', _make_fake_grpc_unary([None], gs.grpc))
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
)
assert r.status_code == 500
body = r.json()
assert body.get('error_code') == 'GTW006'
@pytest.mark.asyncio
async def test_grpc_unknown_maps_to_500_error(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gunk', 'v1'
await _setup_api(authed_client, name, ver)
record = []
req_cls, rep_cls = _fake_pb2_module('M')
default_pkg = f'{name}_{ver}'.replace('-', '_') + '_pb2'
pb2_map = { default_pkg: (req_cls, rep_cls) }
monkeypatch.setattr(gs.importlib, 'import_module', _make_import_module_recorder(record, pb2_map))
# Force UNKNOWN error (maps to 500)
fake_grpc = _make_fake_grpc_unary([gs.grpc.StatusCode.UNKNOWN], gs.grpc)
monkeypatch.setattr(gs, 'grpc', fake_grpc)
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
)
assert r.status_code == 500
@pytest.mark.asyncio
async def test_grpc_proto_missing_returns_404_gtw012(monkeypatch, authed_client):
import services.gateway_service as gs
name, ver = 'gproto404', 'v1'
await _setup_api(authed_client, name, ver)
# Make on-demand proto generation fail by raising on import grpc_tools
def _imp_fail(name):
if name.startswith('grpc_tools'):
raise ImportError('no tools')
raise ImportError(name)
monkeypatch.setattr(gs.importlib, 'import_module', _imp_fail)
r = await authed_client.post(
f'/api/grpc/{name}', headers={'X-API-Version': ver, 'Content-Type': 'application/json'}, json={'method': 'Svc.M', 'message': {}}
)
assert r.status_code == 404
body = r.json()
assert body.get('error_code') == 'GTW012'