mirror of
https://github.com/apidoorman/doorman.git
synced 2026-01-28 04:38:28 -06:00
355 lines
15 KiB
Python
355 lines
15 KiB
Python
"""
|
|
The contents of this file are property of Doorman Dev, LLC
|
|
Review the Apache License 2.0 for valid authorization of use
|
|
See https://github.com/apidoorman/doorman for more information
|
|
"""
|
|
|
|
import re
|
|
import uuid
|
|
from collections.abc import Callable
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from fastapi import HTTPException
|
|
|
|
try:
|
|
from defusedxml import ElementTree as ET
|
|
|
|
_DEFUSED = True
|
|
except Exception:
|
|
import xml.etree.ElementTree as ET
|
|
|
|
_DEFUSED = False
|
|
import grpc
|
|
from graphql import GraphQLError, parse
|
|
|
|
from models.field_validation_model import FieldValidation
|
|
from models.validation_schema_model import ValidationSchema
|
|
from utils.async_db import db_find_one
|
|
from utils.database_async import endpoint_validation_collection
|
|
from utils.doorman_cache_util import doorman_cache
|
|
|
|
|
|
class ValidationError(Exception):
    """Signal that a value violated a field validation rule.

    The exception's string form (and its single positional arg) is the
    human-readable ``message``; ``field_path`` records which field failed.
    """

    def __init__(self, message: str, field_path: str):
        super().__init__(message)
        # Human-readable reason for the failure.
        self.message = message
        # Dotted path of the offending field (e.g. ``user.email``).
        self.field_path = field_path
|
|
|
|
|
|
class _GrpcValidationError(grpc.RpcError):
    """``grpc.RpcError`` raised when a gRPC request fails schema validation.

    ``args`` is ``(status_code, details)``, exactly what the previous bare
    ``grpc.RpcError(code, details)`` carried, so existing handlers keep working.
    It additionally implements ``code()`` and ``details()`` (mirroring the
    ``grpc.Call`` interface), which a bare ``grpc.RpcError`` does not provide.
    """

    def __init__(self, status_code: grpc.StatusCode, details: str):
        super().__init__(status_code, details)
        self._status_code = status_code
        self._details = details

    def code(self) -> grpc.StatusCode:
        """Return the gRPC status code for this failure."""
        return self._status_code

    def details(self) -> str:
        """Return the human-readable validation failure message."""
        return self._details


class ValidationUtil:
    """Validate REST, SOAP, gRPC and GraphQL payloads against per-endpoint schemas.

    Schemas are loaded by :meth:`get_validation_schema` (cache first, then the
    endpoint validation collection). They map dotted field paths such as
    ``user.email``, optionally with list indices (``items[0].sku``,
    ``matrix[1][2]``), to ``FieldValidation`` rules. Internally every failure is
    a ``ValidationError``; each ``validate_*_request`` entry point translates it
    into the error type its protocol expects.
    """

    # One or more list-index suffixes such as ``[0]`` or ``[0][1]``.
    _INDEX_SUFFIX_RE = re.compile(r'(?:\[[0-9]+\])+')
    _EMAIL_RE = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
    _URL_RE = re.compile(
        r'^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
        r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)$'
    )
    # Body element lookups for SOAP 1.1 and SOAP 1.2 envelopes, tried in order.
    _SOAP_BODY_PATHS = (
        './/{http://schemas.xmlsoap.org/soap/envelope/}Body',
        './/{http://www.w3.org/2003/05/soap-envelope}Body',
    )
    # First named operation in a GraphQL document; \b avoids matching e.g. "myquery X".
    _OPERATION_NAME_RE = re.compile(r'\b(?:query|mutation|subscription)\s+(\w+)')

    def __init__(self):
        # Dispatch table: FieldValidation.type -> type checker.
        self.type_validators = {
            'string': self._validate_string,
            'number': self._validate_number,
            'boolean': self._validate_boolean,
            'array': self._validate_array,
            'object': self._validate_object,
        }
        # Dispatch table: FieldValidation.format -> checker (applied to strings only).
        self.format_validators = {
            'email': self._validate_email,
            'url': self._validate_url,
            'date': self._validate_date,
            'datetime': self._validate_datetime,
            'uuid': self._validate_uuid,
        }
        # Extra validators registered at runtime via register_custom_validator().
        self.custom_validators: dict[str, Callable] = {}

    def _reject_unsafe_xml(self, xml_text: str) -> None:
        """
        Reject XML with DOCTYPE/ENTITY declarations to prevent XXE attacks.

        This is a fallback when defusedxml is not available.
        When defusedxml is available (_DEFUSED=True), it handles XXE protection automatically.

        Raises:
            HTTPException: 400 when a DTD or entity declaration is present.
        """
        if _DEFUSED:
            return
        lowered = xml_text.lower()
        if '<!doctype' in lowered or '<!entity' in lowered:
            raise HTTPException(status_code=400, detail='XML DTD/entities are not allowed')

    def register_custom_validator(
        self, name: str, validator: Callable[[Any, FieldValidation], None]
    ) -> None:
        """Register ``validator`` under ``name`` for use via ``FieldValidation.custom_validator``.

        The validator is called as ``validator(value, validation)`` and should raise
        ``ValidationError`` to reject the value; a later registration with the same
        name replaces the earlier one.
        """
        self.custom_validators[name] = validator

    async def get_validation_schema(self, endpoint_id: str) -> ValidationSchema | None:
        """Return the ValidationSchema for an endpoint_id if configured.

        Looks up the in-memory cache first, then falls back to the DB collection
        (caching the document found there, minus its ``_id``).
        Accepts both shapes:
        - { 'validation_schema': {<paths>: FieldValidation} }
        - {<paths>: FieldValidation}

        Returns None when no document exists, ``validation_enabled`` is falsy, or
        the schema is empty / not a mapping.

        Raises:
            ValidationError: if a configured field path is malformed.
        """
        validation_doc = doorman_cache.get_cache('endpoint_validation_cache', endpoint_id)
        if not validation_doc:
            validation_doc = await db_find_one(
                endpoint_validation_collection, {'endpoint_id': endpoint_id}
            )
            if validation_doc:
                try:
                    vdoc = dict(validation_doc)
                    vdoc.pop('_id', None)
                    validation_doc = vdoc
                    doorman_cache.set_cache('endpoint_validation_cache', endpoint_id, vdoc)
                except Exception:
                    # Caching is best-effort; validation proceeds with the DB document.
                    import logging

                    logging.getLogger('doorman.gateway').debug(
                        'Could not cache validation schema for %s', endpoint_id, exc_info=True
                    )
        if not validation_doc or not validation_doc.get('validation_enabled'):
            return None
        raw = validation_doc.get('validation_schema')
        if not raw:
            return None
        # Unwrap the doubly-nested shape {'validation_schema': {...}} if present.
        mapping = (
            raw.get('validation_schema')
            if isinstance(raw, dict) and 'validation_schema' in raw
            else raw
        )
        if not isinstance(mapping, dict):
            return None
        schema = ValidationSchema(validation_schema=mapping)
        self._validate_schema_paths(schema.validation_schema)
        return schema

    def _validate_schema_paths(
        self, schema: dict[str, FieldValidation], parent_path: str = ''
    ) -> None:
        """Recursively check that every path in ``schema`` (and nested schemas) is well-formed.

        Raises:
            ValidationError: naming the first malformed path found.
        """
        for field_path, validation in schema.items():
            full_path = f'{parent_path}.{field_path}' if parent_path else field_path
            if not self._is_valid_field_path(full_path):
                raise ValidationError(f'Invalid field path: {full_path}', full_path)
            if validation.nested_schema:
                self._validate_schema_paths(validation.nested_schema, full_path)

    @classmethod
    def _parse_indexed_part(cls, part: str) -> tuple[str, list[int]] | None:
        """Split one path segment like ``items[0][1]`` into ``('items', [0, 1])``.

        Only meaningful for segments containing ``[``. The name may be empty
        (``[2]`` indexes the current list directly). Returns None when the
        bracket suffix is malformed: non-numeric index, missing ``]``, or
        trailing characters after the last ``]``.
        """
        name, bracket, rest = part.partition('[')
        suffix = bracket + rest
        if not cls._INDEX_SUFFIX_RE.fullmatch(suffix):
            return None
        return name, [int(digits) for digits in re.findall(r'[0-9]+', suffix)]

    def _is_valid_field_path(self, path: str) -> bool:
        """Return True if every dot-separated segment is non-empty and any
        ``[...]`` suffix consists solely of numeric indices."""
        for part in path.split('.'):
            if not part:
                return False
            if '[' in part and self._parse_indexed_part(part) is None:
                return False
        return True

    def _validate_string(self, value: Any, validation: FieldValidation, path: str) -> None:
        """Check type, length bounds, ``pattern`` (anchored at start via re.match) and ``format``."""
        if not isinstance(value, str):
            raise ValidationError(f'Expected string, got {type(value).__name__}', path)
        if validation.min is not None and len(value) < validation.min:
            raise ValidationError(f'String length must be at least {validation.min}', path)
        if validation.max is not None and len(value) > validation.max:
            raise ValidationError(f'String length must be at most {validation.max}', path)
        if validation.pattern and not re.match(validation.pattern, value):
            raise ValidationError(f'String does not match pattern {validation.pattern}', path)
        if validation.format and validation.format in self.format_validators:
            self.format_validators[validation.format](value, validation, path)

    def _validate_number(self, value: Any, validation: FieldValidation, path: str) -> None:
        """Check that ``value`` is an int/float (not bool) within ``min``/``max``."""
        # bool is a subclass of int, so it must be rejected explicitly.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValidationError(f'Expected number, got {type(value).__name__}', path)
        if validation.min is not None and value < validation.min:
            raise ValidationError(f'Value must be at least {validation.min}', path)
        if validation.max is not None and value > validation.max:
            raise ValidationError(f'Value must be at most {validation.max}', path)

    def _validate_boolean(self, value: Any, validation: FieldValidation, path: str) -> None:
        """Check that ``value`` is a real ``bool``."""
        if not isinstance(value, bool):
            raise ValidationError(f'Expected boolean, got {type(value).__name__}', path)

    def _validate_array(self, value: Any, validation: FieldValidation, path: str) -> None:
        """Check list type and length bounds, then validate each item against ``array_items``."""
        if not isinstance(value, list):
            raise ValidationError(f'Expected array, got {type(value).__name__}', path)
        if validation.min is not None and len(value) < validation.min:
            raise ValidationError(f'Array must have at least {validation.min} items', path)
        if validation.max is not None and len(value) > validation.max:
            raise ValidationError(f'Array must have at most {validation.max} items', path)
        if validation.array_items:
            for i, item in enumerate(value):
                self._validate_value(item, validation.array_items, f'{path}[{i}]')

    def _validate_object(self, value: Any, validation: FieldValidation, path: str) -> None:
        """Check dict type, then enforce ``nested_schema`` (required keys and per-key rules)."""
        if not isinstance(value, dict):
            raise ValidationError(f'Expected object, got {type(value).__name__}', path)
        if validation.nested_schema:
            for field_path, field_validation in validation.nested_schema.items():
                if field_validation.required and field_path not in value:
                    raise ValidationError(f'Required field {field_path} is missing', path)
                if field_path in value:
                    self._validate_value(
                        value[field_path], field_validation, f'{path}.{field_path}'
                    )

    def _validate_value(self, value: Any, validation: FieldValidation, field_path: str) -> None:
        """Apply ``validation`` to ``value``; ``field_path`` labels any error raised.

        ``None`` means "absent": it fails only when the rule is ``required`` and
        otherwise skips all remaining checks. Unknown ``type`` values and
        unregistered ``custom_validator`` names are ignored.
        """
        if value is None:
            if validation.required:
                raise ValidationError('Field is required', field_path)
            return
        type_check = self.type_validators.get(validation.type)
        if type_check is not None:
            type_check(value, validation, field_path)
        if validation.enum and value not in validation.enum:
            raise ValidationError(f'Value must be one of {validation.enum}', field_path)
        if validation.custom_validator and validation.custom_validator in self.custom_validators:
            try:
                self.custom_validators[validation.custom_validator](value, validation)
            except ValidationError as e:
                # Re-label with this field's path; custom validators don't know it.
                raise ValidationError(e.message, field_path) from e

    def _validate_email(self, value: str, validation: FieldValidation, path: str) -> None:
        """Check a simple ``local@domain.tld`` shape (not full RFC 5322)."""
        if not self._EMAIL_RE.match(value):
            raise ValidationError('Invalid email format', path)

    def _validate_url(self, value: str, validation: FieldValidation, path: str) -> None:
        """Check an http(s) URL shape."""
        if not self._URL_RE.match(value):
            raise ValidationError('Invalid URL format', path)

    def _validate_date(self, value: str, validation: FieldValidation, path: str) -> None:
        """Check that ``value`` parses with ``%Y-%m-%d``."""
        try:
            datetime.strptime(value, '%Y-%m-%d')
        except ValueError as e:
            raise ValidationError('Invalid date format (YYYY-MM-DD)', path) from e

    def _validate_datetime(self, value: str, validation: FieldValidation, path: str) -> None:
        """Check ISO 8601 via ``datetime.fromisoformat`` (a ``Z`` suffix is mapped to +00:00)."""
        try:
            datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError as e:
            raise ValidationError('Invalid datetime format (ISO 8601)', path) from e

    def _validate_uuid(self, value: str, validation: FieldValidation, path: str) -> None:
        """Check that ``value`` is accepted by ``uuid.UUID``."""
        try:
            uuid.UUID(value)
        except ValueError as e:
            raise ValidationError('Invalid UUID format', path) from e

    async def validate_rest_request(self, endpoint_id: str, request_data: dict[str, Any]) -> None:
        """Validate a REST payload against the endpoint's schema (no-op if none).

        Raises:
            HTTPException: 400 with the validation message on the first failing field.
        """
        schema = await self.get_validation_schema(endpoint_id)
        if not schema:
            return
        for field_path, validation in schema.validation_schema.items():
            try:
                value = self._get_nested_value(request_data, field_path)
                self._validate_value(value, validation, field_path)
            except ValidationError as e:
                import logging

                logging.getLogger('doorman.gateway').error(
                    'Validation failed for %s: %s', field_path, e
                )
                raise HTTPException(status_code=400, detail=str(e)) from e

    def _find_soap_body(self, root: Any) -> Any:
        """Return the SOAP 1.1 or SOAP 1.2 ``Body`` element under ``root``, or None."""
        for xpath in self._SOAP_BODY_PATHS:
            body = root.find(xpath)
            # Compare to None explicitly: an Element with no children is falsy.
            if body is not None:
                return body
        return None

    async def validate_soap_request(self, endpoint_id: str, soap_envelope: str) -> None:
        """Validate the first child of the SOAP Body against the endpoint's schema.

        Raises:
            HTTPException: 400 for unparsable/unsafe XML, a missing or empty Body,
                or the first failing field.
        """
        schema = await self.get_validation_schema(endpoint_id)
        if not schema:
            return
        self._reject_unsafe_xml(soap_envelope)
        try:
            root = ET.fromstring(soap_envelope)
        except (ET.ParseError, ValueError) as e:
            # defusedxml's DTDForbidden/EntitiesForbidden/ExternalReferenceForbidden
            # derive from ValueError, not ParseError.
            raise HTTPException(status_code=400, detail='Invalid SOAP envelope') from e
        try:
            body = self._find_soap_body(root)
            if body is None:
                raise ValidationError('SOAP Body not found', 'Body')
            if len(body) == 0:
                raise ValidationError('SOAP Body is empty', 'Body')
            request_data = self._xml_to_dict(body[0])
            for field_path, validation in schema.validation_schema.items():
                value = self._get_nested_value(request_data, field_path)
                self._validate_value(value, validation, field_path)
        except ValidationError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

    async def validate_grpc_request(self, endpoint_id: str, request: Any) -> None:
        """Validate a gRPC request (protobuf message or plain dict) against the endpoint's schema.

        Raises:
            grpc.RpcError: a subclass whose ``code()`` is INVALID_ARGUMENT and whose
                ``details()`` is the validation message.
        """
        schema = await self.get_validation_schema(endpoint_id)
        if not schema:
            return
        request_data = request if isinstance(request, dict) else self._protobuf_to_dict(request)
        for field_path, validation in schema.validation_schema.items():
            try:
                value = self._get_nested_value(request_data, field_path)
                self._validate_value(value, validation, field_path)
            except ValidationError as e:
                raise _GrpcValidationError(grpc.StatusCode.INVALID_ARGUMENT, str(e)) from e

    async def validate_graphql_request(
        self, endpoint_id: str, query: str, variables: dict[str, Any]
    ) -> None:
        """Validate GraphQL ``variables`` using schema paths of the form ``<OperationName>.<path>``.

        Only paths prefixed by the query's operation name plus ``.`` are checked;
        the remainder is resolved inside ``variables`` (``None`` is treated as {}).

        Raises:
            HTTPException: 400 for syntax errors, validation failures, or any other
                error raised while checking.
        """
        schema = await self.get_validation_schema(endpoint_id)
        if not schema:
            return
        try:
            parse(query)
            operation_name = self._extract_operation_name(query)
            if not operation_name:
                return
            # Require the separator so operation "Create" doesn't claim "CreateUser.*".
            prefix = f'{operation_name}.'
            payload = variables or {}
            for field_path, validation in schema.validation_schema.items():
                if not field_path.startswith(prefix):
                    continue
                value = self._get_nested_value(payload, field_path[len(prefix) :])
                self._validate_value(value, validation, field_path)
        except HTTPException:
            # Already a client error; don't re-wrap (which would mangle the detail).
            raise
        except (GraphQLError, ValidationError) as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

    def _extract_operation_name(self, query: str) -> str | None:
        """Return the first ``query``/``mutation``/``subscription`` operation name, or None."""
        match = self._OPERATION_NAME_RE.search(query)
        return match.group(1) if match else None

    def _get_nested_value(self, data: dict[str, Any], field_path: str) -> Any:
        """Resolve a dotted/indexed ``field_path`` inside ``data``.

        Returns None whenever any step is missing, has the wrong container type
        (e.g. ``data`` itself is None), is out of range, or the segment is malformed.
        """
        current: Any = data
        for part in field_path.split('.'):
            if '[' in part:
                parsed = self._parse_indexed_part(part)
                if parsed is None:
                    return None
                name, indices = parsed
                if name:
                    if not isinstance(current, dict):
                        return None
                    current = current.get(name)
                for index in indices:
                    if not isinstance(current, list) or index >= len(current):
                        return None
                    current = current[index]
            else:
                if not isinstance(current, dict):
                    return None
                current = current.get(part)
            if current is None:
                return None
        return current

    def _strip_ns(self, tag: str) -> str:
        """Drop a ``{namespace}`` prefix from an ElementTree tag."""
        if '}' in tag:
            return tag.split('}', 1)[1]
        return tag

    def _xml_to_dict(self, element: Any) -> dict[str, Any]:
        """Convert an element's children to a dict keyed by namespace-stripped tag.

        Leaf values are the element text (a string or None), so typed rules such
        as ``number``/``boolean`` will see strings. Attributes are ignored, and
        repeated sibling tags overwrite each other (the last one wins).
        """
        result = {}
        for child in element:
            key = self._strip_ns(child.tag)
            if len(child) > 0:
                result[key] = self._xml_to_dict(child)
            else:
                result[key] = child.text
        return result

    def _protobuf_to_dict(self, message: Any) -> dict[str, Any]:
        """Convert a protobuf message into plain dicts/lists for validation.

        Every field on the descriptor is included (unset fields carry their
        protobuf default). Repeated fields become ``list``s and map fields become
        ``dict``s so the array/object validators can inspect them.
        """
        result: dict[str, Any] = {}
        for field in message.DESCRIPTOR.fields:
            value = getattr(message, field.name)
            is_message = field.type == field.TYPE_MESSAGE
            if field.label != field.LABEL_REPEATED:
                result[field.name] = self._protobuf_to_dict(value) if is_message else value
            elif is_message and field.message_type.GetOptions().map_entry:
                # Map fields are repeated synthetic entry messages; iterating the
                # container yields keys, so convert through items() instead.
                value_field = field.message_type.fields_by_name['value']
                if value_field.type == value_field.TYPE_MESSAGE:
                    result[field.name] = {
                        key: self._protobuf_to_dict(item) for key, item in value.items()
                    }
                else:
                    result[field.name] = dict(value)
            elif is_message:
                result[field.name] = [self._protobuf_to_dict(item) for item in value]
            else:
                # Repeated scalar containers are not list instances; normalize them.
                result[field.name] = list(value)
        return result
|
|
|
|
|
|
# Module-level ValidationUtil instance created at import time. Every importer shares
# this object, so validators added via register_custom_validator() are visible to all
# of them (presumably the gateway request handlers — verify against callers).
validation_util = ValidationUtil()
|