Files
hatchet/sdks/python/apply_patches.py
T
Trevor Wilson cb68300652 test(rest): add typed REST transport exceptions + diagnostics tests (#2968)
* test(rest): add typed REST transport exceptions + diagnostics tests

Adds typed REST transport exceptions for urllib3/network failures and unit tests
verifying translation, diagnostics, and backward compatibility. Updates
apply_patches.py to patch generated REST client code and exceptions.

Refs: #2443

* chore(tests): remove redundant docstrings and comments

* test(rest): convert transport exception tests to module-level functions

* chore(python-sdk): bump version to 1.23.2 and update changelog
2026-02-11 21:31:33 -05:00

388 lines
13 KiB
Python

import re
from collections.abc import Callable
from copy import deepcopy
from pathlib import Path
def prepend_import(content: str, import_statement: str) -> str:
    """Insert ``import_statement`` ahead of the first import in ``content``.

    Returns ``content`` unchanged when the statement already appears anywhere
    in it. Otherwise every ``from __future__ import ...`` line is lifted out
    and re-emitted immediately before the new statement, so the future
    imports still precede it. When no import line exists, the statement goes
    at the very top.
    """
    if import_statement in content:
        return content

    future_line = re.compile(r"^from __future__ import [^\n]+\n", re.MULTILINE)
    hoisted = "".join(future_line.findall(content))
    remainder = future_line.sub("", content)

    # Locate the first remaining import line; fall back to the file start.
    first_import = re.search(r"^import\s+|^from\s+", remainder, re.MULTILINE)
    cut = 0 if first_import is None else first_import.start()

    pieces = (remainder[:cut], hoisted, import_statement, "\n", remainder[cut:])
    return "".join(pieces)
def apply_patch(content: str, pattern: str, replacement: str) -> str:
    """Return ``content`` with every regex match of ``pattern`` replaced.

    ``replacement`` is a regex replacement template, so backreferences such
    as ``\\1`` are expanded.
    """
    compiled = re.compile(pattern)
    return compiled.sub(replacement, content)
def atomically_patch_file(
    file_path: str, patch_funcs: list[Callable[[str], str]]
) -> None:
    """Apply ``patch_funcs`` in order to a file, all-or-nothing.

    The file is read once, each patch function is applied to the result of
    the previous one, and the file is rewritten only if every function
    succeeded and the final text differs from the original. If any patch
    function raises, the error is printed and the file is left untouched.

    Args:
        file_path: Path of the file to patch.
        patch_funcs: Callables mapping file text to patched file text.
    """
    path = Path(file_path)
    # Read and write explicitly as UTF-8 (the default encoding of Python
    # source files) instead of relying on the platform locale encoding.
    original = path.read_text(encoding="utf-8")
    # Strings are immutable, so no copy is needed before patching.
    modified = original
    try:
        for func in patch_funcs:
            modified = func(modified)
    except Exception as e:
        # Deliberate best-effort: report and skip this file without writing.
        print(f"Error patching {file_path}: {e}")
        return
    if modified != original:
        path.write_text(modified, encoding="utf-8")
        print(f"Patched {file_path}")
    else:
        print(f"No changes made to {file_path}")
def patch_contract_import_paths(content: str) -> str:
    """Rewrite ``from v1...`` imports to ``from hatchet_sdk.contracts.v1...``."""
    bare_v1_import = re.compile(r"\bfrom v1\b")
    return bare_v1_import.sub("from hatchet_sdk.contracts.v1", content)
def patch_grpc_dispatcher_import(content: str) -> str:
    """Make the bare ``dispatcher_pb2`` import package-qualified. Idempotent.

    Rewrites ``import dispatcher_pb2 as dispatcher__pb2`` to
    ``from hatchet_sdk.contracts import dispatcher_pb2 as dispatcher__pb2``.
    """
    # Anchor to the start of the statement (keeping any leading indentation).
    # An unanchored match would also hit the already-rewritten
    # "from ... import dispatcher_pb2 as dispatcher__pb2" line and corrupt it
    # on a second run.
    return re.sub(
        r"(?m)^([ \t]*)import dispatcher_pb2 as dispatcher__pb2\b",
        r"\1from hatchet_sdk.contracts import dispatcher_pb2 as dispatcher__pb2",
        content,
    )
def patch_grpc_events_import(content: str) -> str:
    """Make the bare ``events_pb2`` import package-qualified. Idempotent.

    Rewrites ``import events_pb2 as events__pb2`` to
    ``from hatchet_sdk.contracts import events_pb2 as events__pb2``.
    """
    # Anchor to the start of the statement (keeping any leading indentation).
    # An unanchored match would also hit the already-rewritten
    # "from ... import events_pb2 as events__pb2" line and corrupt it on a
    # second run.
    return re.sub(
        r"(?m)^([ \t]*)import events_pb2 as events__pb2\b",
        r"\1from hatchet_sdk.contracts import events_pb2 as events__pb2",
        content,
    )
def patch_grpc_workflows_import(content: str) -> str:
    """Make the bare ``workflows_pb2`` import package-qualified. Idempotent.

    Rewrites ``import workflows_pb2 as workflows__pb2`` to
    ``from hatchet_sdk.contracts import workflows_pb2 as workflows__pb2``.
    """
    # Anchor to the start of the statement (keeping any leading indentation).
    # An unanchored match would also hit the already-rewritten
    # "from ... import workflows_pb2 as workflows__pb2" line and corrupt it
    # on a second run.
    return re.sub(
        r"(?m)^([ \t]*)import workflows_pb2 as workflows__pb2\b",
        r"\1from hatchet_sdk.contracts import workflows_pb2 as workflows__pb2",
        content,
    )
def patch_grpc_init_signature(content: str) -> str:
    """Add type annotations to untyped ``__init__(self, channel)`` signatures."""
    untyped_signature = r"def __init__\(self, channel\):"
    typed_signature = (
        "def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:"
    )
    return re.sub(untyped_signature, typed_signature, content)
def patch_rest_transport_exceptions(content: str) -> str:
    """Insert typed REST transport exception classes into exceptions.py.

    Adds the exception classes directly above the ``render_path`` function.
    Idempotent: content that already defines ``RestTransportError`` is
    returned unchanged. Content without a ``render_path`` definition is also
    returned unchanged.
    """
    # Check if already patched
    if "class RestTransportError" in content:
        return content

    # The class bodies must be indented, otherwise the patched module does
    # not parse. Leading/trailing blank lines give two blank lines between
    # top-level definitions when the surrounding file is conventionally
    # spaced.
    new_exceptions = '''\

class RestTransportError(ApiException):
    """Base exception for REST transport-level errors (network, timeout, TLS)."""

    pass


class RestTimeoutError(RestTransportError):
    """Raised when a REST request times out (connect or read timeout)."""

    pass


class RestConnectionError(RestTransportError):
    """Raised when a REST request fails to establish a connection."""

    pass


class RestTLSError(RestTransportError):
    """Raised when a REST request fails due to SSL/TLS errors."""

    pass


class RestProtocolError(RestTransportError):
    """Raised when a REST request fails due to protocol-level errors."""

    pass

'''
    # Insert before render_path function (match any arguments)
    pattern = r"(\ndef render_path\([^)]*\):)"
    # A function replacement inserts the class text verbatim, with no
    # replacement-template escape processing applied to it.
    return re.sub(pattern, lambda m: new_exceptions + m.group(1), content)
def patch_rest_imports(content: str) -> str:
    """Update rest.py imports to include typed transport exceptions.

    Handles both single-line and parenthesized import formats. Idempotent:
    content whose ``hatchet_sdk.clients.rest.exceptions`` import already
    names ``RestTLSError`` is returned unchanged. Content with no recognised
    import of ``ApiException`` / ``ApiValueError`` is also returned unchanged.
    """
    # The exceptions we need to ensure are imported
    required_exceptions = [
        "RestConnectionError",
        "RestProtocolError",
        "RestTimeoutError",
        "RestTLSError",
    ]
    import_prefix = r"^from\s+hatchet_sdk\.clients\.rest\.exceptions\s+import"

    # Idempotency check: single-line import already naming RestTLSError.
    if re.search(
        import_prefix + r"[^\n]*\bRestTLSError\b", content, flags=re.MULTILINE
    ):
        return content

    # Idempotency check: parenthesized import block naming RestTLSError.
    # `[^)]*` keeps the search inside this import's own parentheses; a
    # dot-matches-newline wildcard would run past the closing paren and treat
    # any later use of RestTLSError in the file (e.g. a `raise`) as proof the
    # import was already patched.
    if re.search(
        import_prefix + r"\s*\([^)]*\bRestTLSError\b[^)]*\)",
        content,
        flags=re.MULTILINE,
    ):
        return content

    # The target import statement we want with trailing newline to preserve spacing
    new_import = (
        "from hatchet_sdk.clients.rest.exceptions import (\n"
        "    ApiException,\n"
        "    ApiValueError,\n"
        "    RestConnectionError,\n"
        "    RestProtocolError,\n"
        "    RestTimeoutError,\n"
        "    RestTLSError,\n"
        ")\n"
    )

    # Single line import
    # Matches: from hatchet_sdk.clients.rest.exceptions import ApiException, ApiValueError
    single_line_pattern = (
        import_prefix + r"\s+" r"ApiException\s*,\s*ApiValueError\s*$"
    )
    modified = re.sub(single_line_pattern, new_import, content, flags=re.MULTILINE)
    if modified != content:
        return modified

    # Parenthesized import: any order, with or without trailing comma.
    flexible_paren_pattern = import_prefix + r"\s*\(" r"[^)]*?" r"\)"
    match = re.search(flexible_paren_pattern, content, flags=re.MULTILINE)
    if match is None:
        return content

    block = match.group(0)
    # Only replace a block that has none of the new exceptions yet and that
    # really is the ApiException/ApiValueError import.
    if any(exc in block for exc in required_exceptions):
        return content
    if "ApiException" not in block and "ApiValueError" not in block:
        return content
    return content[: match.start()] + new_import + content[match.end() :]
def _typed_handler_template(
    caught: tuple[str, ...], raised: str, note: str = ""
) -> str:
    """Build the ``re.sub`` replacement template for one typed except handler.

    ``caught`` names the ``urllib3.exceptions`` classes the handler catches,
    ``raised`` is the typed exception it raises, and ``note`` is an optional
    comment emitted as the first line of the handler body. ``\\1`` in the
    template is the indentation captured by the matching pattern.
    """
    if len(caught) == 1:
        header = [r"\1except urllib3.exceptions." + caught[0] + r" as e:\n"]
    else:
        header = [r"\1except (\n"]
        header += [r"\1    urllib3.exceptions." + name + r",\n" for name in caught]
        header.append(r"\1) as e:\n")
    if note:
        header.append(r"\1    # " + note + r"\n")
    body = [
        r'\1    msg = "\\n".join(\n',
        r"\1        [\n",
        r"\1            type(e).__name__,\n",
        r"\1            str(e),\n",
        r'\1            f"method={method}",\n',
        r'\1            f"url={url}",\n',
        r'\1            f"timeout={_request_timeout}",\n',
        r"\1        ]\n",
        r"\1    )\n",
        r"\1    raise " + raised + r"(status=0, reason=msg) from e\n",
    ]
    return "".join(header + body)


def patch_rest_error_diagnostics(content: str) -> str:
    """Patch rest.py exception handlers to use typed exceptions.

    Replaces the generic ApiException handler with typed exception handlers
    whose message also records the request method, URL and timeout.
    Idempotent: content already containing ``raise RestTLSError`` is returned
    unchanged.

    Handler ordering is deliberate: the NewConnectionError handler is emitted
    before the ConnectTimeoutError handler because, per the original author,
    NewConnectionError inherits from ConnectTimeoutError in urllib3.
    NOTE(review): confirm this still holds for the urllib3 version in use.
    """
    # Check if already using typed exceptions
    if "raise RestTLSError" in content:
        return content

    # The generated, SSLError-only handler.
    pattern_original = (
        r"(?ms)^([ \t]*)except urllib3\.exceptions\.SSLError as e:\s*\n"
        r"^\1[ \t]*msg = \"\\n\"\.join\(\[type\(e\)\.__name__, str\(e\)\]\)\s*\n"
        r"^\1[ \t]*raise ApiException\(status=0, reason=msg\)\s*\n"
    )
    # The previously patched multi-exception handler raising ApiException.
    pattern_expanded = (
        r"(?ms)^([ \t]*)except \(\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.SSLError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ConnectTimeoutError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ReadTimeoutError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.MaxRetryError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.NewConnectionError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ProtocolError,\s*\n"
        r"^\1\) as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise ApiException\(status=0, reason=msg\)\s*\n"
    )

    # Build typed replacement with proper handler ordering (see docstring).
    replacement = "".join(
        [
            _typed_handler_template(("SSLError",), "RestTLSError"),
            _typed_handler_template(
                ("MaxRetryError", "NewConnectionError"),
                "RestConnectionError",
                note=(
                    "NewConnectionError inherits from ConnectTimeoutError, "
                    "so must be caught first"
                ),
            ),
            _typed_handler_template(
                ("ConnectTimeoutError", "ReadTimeoutError"), "RestTimeoutError"
            ),
            _typed_handler_template(("ProtocolError",), "RestProtocolError"),
        ]
    )

    # Try expanded pattern first. Relevant if previously patched with ApiException
    modified = re.sub(pattern_expanded, replacement, content)
    if modified != content:
        return modified
    # Otherwise try original pattern
    return re.sub(pattern_original, replacement, content)
def apply_patches_to_matching_files(
    root: str, glob: str, patch_funcs: list[Callable[[str], str]]
) -> None:
    """Patch every file under ``root`` (recursively) whose name matches ``glob``.

    Each matching file is handed to ``atomically_patch_file`` together with
    ``patch_funcs``.
    """
    matching_paths = Path(root).rglob(glob)
    for matched in matching_paths:
        atomically_patch_file(str(matched), patch_funcs)
def patch_api_client_datetime_format_on_post(content: str) -> str:
    """Make the API client serialize naive datetimes with an explicit timezone.

    Ensures ``from hatchet_sdk.logger import logger`` is imported, then
    splits the combined ``(datetime.datetime, datetime.date)`` branch into
    two branches. The ``datetime`` branch logs a warning for timezone-naive
    values and attaches a timezone before calling ``isoformat()``; the
    ``date`` branch keeps the plain ``isoformat()`` call.
    """
    content = prepend_import(content, "from hatchet_sdk.logger import logger")
    # Group 1 captures the branch's indentation so the emitted code lines up
    # with the surrounding if/elif chain.
    pattern = r"([ \t]*)elif isinstance\(obj, \(datetime\.datetime, datetime\.date\)\):\s*\n\1[ \t]*return obj\.isoformat\(\)"
    # The nested indentation here is significant: the emitted text is Python
    # source, and the body of `if not obj.tzinfo:` must sit one level deeper
    # than the `if` for the patched file to parse.
    replacement = (
        r"\1## IMPORTANT: Checking `datetime` must come before `date` since `datetime` is a subclass of `date`\n"
        r"\1elif isinstance(obj, datetime.datetime):\n"
        r"\1    if not obj.tzinfo:\n"
        r"\1        current_tz = (datetime.datetime.now(datetime.timezone(datetime.timedelta(0))).astimezone().tzinfo or datetime.timezone.utc)\n"
        r'\1        logger.warning(f"timezone-naive datetime found. assuming {current_tz}.")\n'
        r"\1        obj = obj.replace(tzinfo=current_tz)\n\n"
        r"\1    return obj.isoformat()\n"
        r"\1elif isinstance(obj, datetime.date):\n"
        r"\1    return obj.isoformat()"
    )
    return apply_patch(content, pattern, replacement)
def patch_workflow_run_metrics_counts_return_type(content: str) -> str:
    """Retype the ``counts`` field as ``Optional[WorkflowRunsMetricsCounts]``.

    Ensures the ``WorkflowRunsMetricsCounts`` model import is present, then
    replaces the ``Optional[Dict[str, Any]]`` annotation on ``counts``,
    keeping the field's indentation.
    """
    model_import = "from hatchet_sdk.clients.rest.models.workflow_runs_metrics_counts import WorkflowRunsMetricsCounts"
    with_import = prepend_import(content, model_import)

    untyped_field = r"([ \t]*)counts: Optional\[Dict\[str, Any\]\] = None"
    typed_field = r"\1counts: Optional[WorkflowRunsMetricsCounts] = None"
    return apply_patch(with_import, untyped_field, typed_field)
if __name__ == "__main__":
    # Generated REST client files, each with its own ordered patch list.
    rest_targets: list[tuple[str, list[Callable[[str], str]]]] = [
        (
            "hatchet_sdk/clients/rest/api_client.py",
            [patch_api_client_datetime_format_on_post],
        ),
        (
            "hatchet_sdk/clients/rest/models/workflow_runs_metrics.py",
            [patch_workflow_run_metrics_counts_return_type],
        ),
        (
            "hatchet_sdk/clients/rest/exceptions.py",
            [patch_rest_transport_exceptions],
        ),
        (
            "hatchet_sdk/clients/rest/rest.py",
            [patch_rest_imports, patch_rest_error_diagnostics],
        ),
    ]
    for target_file, target_patches in rest_targets:
        atomically_patch_file(target_file, target_patches)

    grpc_patches: list[Callable[[str], str]] = [
        patch_contract_import_paths,
        patch_grpc_dispatcher_import,
        patch_grpc_events_import,
        patch_grpc_workflows_import,
        patch_grpc_init_signature,
    ]
    pb2_patches: list[Callable[[str], str]] = [
        patch_contract_import_paths,
    ]

    # Generated gRPC / protobuf contract files, selected by filename pattern.
    contract_targets = (
        ("*_grpc.py", grpc_patches),
        ("*_pb2.py", pb2_patches),
        ("*_pb2.pyi", pb2_patches),
    )
    for filename_glob, contract_patches in contract_targets:
        apply_patches_to_matching_files(
            "hatchet_sdk/contracts", filename_glob, contract_patches
        )