Files
hatchet/sdks/python/apply_patches.py
matt 43e7466d0f Feat: Durable Execution Revamp (#2954)
* Feat: Durable Execution Revamp

---------

Co-authored-by: Gabe Ruttner <gabriel.ruttner@gmail.com>
2026-03-16 11:24:29 -04:00

670 lines
26 KiB
Python

import re
from collections.abc import Callable
from copy import deepcopy
from pathlib import Path
def prepend_import(content: str, import_statement: str) -> str:
    """Insert *import_statement* ahead of the first import in *content*.

    If the statement text already occurs anywhere in *content*, the input is
    returned untouched. Otherwise every ``from __future__ import ...`` line is
    lifted out of the module and re-emitted immediately before the new import,
    so the future imports still come first. When the module has no
    ``import``/``from`` line at all, everything is inserted at offset 0.

    Args:
        content: Source text of the module being patched.
        import_statement: Import line to add, without a trailing newline.

    Returns:
        The module text with the import added (or unchanged if present).
    """
    if import_statement in content:
        return content

    future_line = re.compile(r"^from __future__ import [^\n]+\n", re.MULTILINE)
    hoisted = future_line.findall(content)
    remainder = future_line.sub("", content)

    # Locate the first regular import once the __future__ lines are gone.
    first_import = re.search(r"^import\s+|^from\s+", remainder, re.MULTILINE)
    split_at = 0 if first_import is None else first_import.start()

    pieces = [
        remainder[:split_at],
        "".join(hoisted),
        import_statement,
        "\n",
        remainder[split_at:],
    ]
    return "".join(pieces)
def apply_patch(content: str, pattern: str, replacement: str) -> str:
    """Return *content* with every match of regex *pattern* substituted.

    *replacement* is used as an ``re.sub`` template, so group backreferences
    in it are expanded.
    """
    compiled = re.compile(pattern)
    return compiled.sub(replacement, content)
def atomically_patch_file(
    file_path: str, patch_funcs: list[Callable[[str], str]]
) -> None:
    """Apply a chain of patch functions to one file, all-or-nothing.

    The file is read once, each function in *patch_funcs* is applied to the
    result of the previous one, and the file is rewritten only if every
    function succeeded and the final text differs from the original. If any
    patch function raises, the error is printed and the file is left exactly
    as it was (best-effort: the exception is not propagated).

    The file is read and written as UTF-8 explicitly, rather than with the
    platform's default encoding, because the targets are Python sources.

    Args:
        file_path: Path of the file to patch.
        patch_funcs: Functions mapping source text to patched source text,
            applied in order.
    """
    path = Path(file_path)
    original = path.read_text(encoding="utf-8")

    # Strings are immutable, so no copy is needed before patching.
    modified = original
    try:
        for func in patch_funcs:
            modified = func(modified)
    except Exception as e:
        # Deliberate best-effort: report and skip this file without writing.
        print(f"Error patching {file_path}: {e}")
        return

    if modified != original:
        path.write_text(modified, encoding="utf-8")
        print(f"Patched {file_path}")
    else:
        print(f"No changes made to {file_path}")
def patch_contract_import_paths(content: str) -> str:
    """Rewrite bare ``from v1`` / ``from workflows`` imports to package paths.

    Both prefixes are redirected under ``hatchet_sdk.contracts``; the two
    substitutions are applied one after the other.
    """
    rewrites = (
        (r"\bfrom v1\b", "from hatchet_sdk.contracts.v1"),
        (r"\bfrom workflows\b", "from hatchet_sdk.contracts.workflows"),
    )
    for bare_prefix, package_prefix in rewrites:
        content = apply_patch(content, bare_prefix, package_prefix)
    return content
def patch_grpc_dispatcher_import(content: str) -> str:
    """Rewrite the bare ``dispatcher_pb2`` import to a package-qualified one.

    ``import dispatcher_pb2 as dispatcher__pb2`` becomes
    ``from hatchet_sdk.contracts import dispatcher_pb2 as dispatcher__pb2``.

    The rewritten line still ends with the text being searched for, so the
    negative lookbehind skips occurrences already preceded by
    ``from hatchet_sdk.contracts ``. This makes the patch idempotent: running
    it on already-patched text returns the text unchanged.
    """
    return re.sub(
        r"(?<!from hatchet_sdk\.contracts )\bimport dispatcher_pb2 as dispatcher__pb2\b",
        "from hatchet_sdk.contracts import dispatcher_pb2 as dispatcher__pb2",
        content,
    )
def patch_grpc_events_import(content: str) -> str:
    """Rewrite the bare ``events_pb2`` import to a package-qualified one.

    ``import events_pb2 as events__pb2`` becomes
    ``from hatchet_sdk.contracts import events_pb2 as events__pb2``.

    The rewritten line still ends with the text being searched for, so the
    negative lookbehind skips occurrences already preceded by
    ``from hatchet_sdk.contracts ``. This makes the patch idempotent: running
    it on already-patched text returns the text unchanged.
    """
    return re.sub(
        r"(?<!from hatchet_sdk\.contracts )\bimport events_pb2 as events__pb2\b",
        "from hatchet_sdk.contracts import events_pb2 as events__pb2",
        content,
    )
def patch_grpc_workflows_import(content: str) -> str:
    """Rewrite the bare ``workflows_pb2`` import to a package-qualified one.

    ``import workflows_pb2 as workflows__pb2`` becomes
    ``from hatchet_sdk.contracts import workflows_pb2 as workflows__pb2``.

    The rewritten line still ends with the text being searched for, so the
    negative lookbehind skips occurrences already preceded by
    ``from hatchet_sdk.contracts ``. This makes the patch idempotent: running
    it on already-patched text returns the text unchanged.
    """
    return re.sub(
        r"(?<!from hatchet_sdk\.contracts )\bimport workflows_pb2 as workflows__pb2\b",
        "from hatchet_sdk.contracts import workflows_pb2 as workflows__pb2",
        content,
    )
def patch_grpc_init_signature(content: str) -> str:
    """Add type annotations to untyped ``__init__(self, channel)`` signatures.

    Every ``def __init__(self, channel):`` in *content* gets a channel type
    of ``grpc.Channel | grpc.aio.Channel`` and a ``-> None`` return type.
    """
    untyped_signature = r"def __init__\(self, channel\):"
    typed_signature = (
        "def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:"
    )
    return apply_patch(
        content=content,
        pattern=untyped_signature,
        replacement=typed_signature,
    )
def patch_rest_429_exception(content: str) -> str:
    """Add TooManyRequestsException and its 429 mapping to exceptions.py.

    Two independent, idempotent insertions are made:

    1. The ``TooManyRequestsException`` class is placed directly before
       ``class RestTransportError`` unless the class already exists.
    2. An ``http_resp.status == 429`` check is placed directly before the
       5xx range check unless a 429 check already exists.

    Raises:
        ValueError: If an insertion is needed but its anchor line is missing.
    """
    patched = content

    # Step 1: the exception class itself.
    if "class TooManyRequestsException" not in patched:
        class_source = "".join(
            (
                "class TooManyRequestsException(ApiException):\n",
                ' """Exception for HTTP 429 Too Many Requests."""\n',
                " pass\n",
            )
        )
        patched, inserted = re.subn(
            r"(?m)^class RestTransportError\b",
            class_source + "\n\nclass RestTransportError",
            patched,
            count=1,
        )
        if inserted != 1:
            raise ValueError(
                "patch_rest_429_exception: expected 'class RestTransportError' anchor not found"
            )

    # Step 2: the status-code mapping; nothing to do if it is already there.
    if "http_resp.status == 429" in patched:
        return patched

    def _with_429_check(found: re.Match[str]) -> str:
        # Reuse the anchor line's own indentation for the inserted check.
        indent = found.group("indent")
        return (
            f"{indent}if http_resp.status == 429:\n"
            f"{indent} raise TooManyRequestsException(http_resp=http_resp, body=body, data=data)\n"
            f"\n"
            f"{indent}if 500 <= http_resp.status <= 599:"
        )

    patched, inserted = re.subn(
        r"(?m)^(?P<indent>[ \t]*)if 500 <= http_resp\.status <= 599:",
        _with_429_check,
        patched,
        count=1,
    )
    if inserted != 1:
        raise ValueError(
            "patch_rest_429_exception: expected 5xx mapping anchor not found"
        )
    return patched
def patch_rest_transport_exceptions(content: str) -> str:
    """Insert typed REST transport exception classes into exceptions.py.

    Adds the exception classes above the ``render_path`` function,
    idempotently. The ``RestTransportError`` class carries an ``http_method``
    attribute so retry logic can use it without parsing the reason string.

    Steps:
        1. Return early if the HTTPMethod-typed classes are already present.
        2. Ensure ``HTTPMethod`` is imported.
        3. Remove previously inserted class blocks (the no-``__init__`` form
           and the ``Optional[str]`` ``http_method`` form).
        4. Insert the current class definitions before ``def render_path(``.

    Args:
        content: Source text of the generated exceptions module.

    Returns:
        The patched source. If no ``def render_path(...)`` line is found,
        nothing is inserted (the import may still have been added).
    """
    # Check if already patched with HTTPMethod enum support
    if "Optional[HTTPMethod]" in content and "self.http_method" in content:
        return content

    content = prepend_import(content, "from hatchet_sdk.config import HTTPMethod")

    if "class RestTransportError" in content:
        # Pattern for simple classes with no __init__
        pattern_simple = (
            r"\nclass RestTransportError\(ApiException\):\s*"
            r'"""Base exception for REST transport-level errors \(network, timeout, TLS\)\."""\s*'
            r"pass\s*"
            r"\nclass RestTimeoutError\(RestTransportError\):\s*"
            r'"""Raised when a REST request times out \(connect or read timeout\)\."""\s*'
            r"pass\s*"
            r"\nclass RestConnectionError\(RestTransportError\):\s*"
            r'"""Raised when a REST request fails to establish a connection\."""\s*'
            r"pass\s*"
            r"\nclass RestTLSError\(RestTransportError\):\s*"
            r'"""Raised when a REST request fails due to SSL/TLS errors\."""\s*'
            r"pass\s*"
            r"\nclass RestProtocolError\(RestTransportError\):\s*"
            r'"""Raised when a REST request fails due to protocol-level errors\."""\s*'
            r"pass\s*"
        )
        content = re.sub(pattern_simple, "\n", content)

        # Pattern for classes with __init__ using Optional[str] http_method
        pattern_with_str_http_method = (
            r"\nclass RestTransportError\(ApiException\):\s*"
            r'"""Base exception for REST transport-level errors \(network, timeout, TLS\)\."""\s*'
            r"def __init__\(\s*"
            r"self,\s*"
            r"status=None,\s*"
            r"reason=None,\s*"
            r"http_resp=None,\s*"
            r"\*,\s*"
            r"body: Optional\[str\] = None,\s*"
            r"data: Optional\[Any\] = None,\s*"
            r"http_method: Optional\[str\] = None,\s*"
            r"\) -> None:\s*"
            r"super\(\).__init__\(\s*"
            r"status=status, reason=reason, http_resp=http_resp, body=body, data=data\s*"
            r"\)\s*"
            r"self\.http_method: Optional\[str\] = http_method\s*"
            r"\nclass RestTimeoutError\(RestTransportError\):\s*"
            r'"""Raised when a REST request times out \(connect or read timeout\)\."""\s*'
            r"pass\s*"
            r"\nclass RestConnectionError\(RestTransportError\):\s*"
            r'"""Raised when a REST request fails to establish a connection\."""\s*'
            r"pass\s*"
            r"\nclass RestTLSError\(RestTransportError\):\s*"
            r'"""Raised when a REST request fails due to SSL/TLS errors\."""\s*'
            r"pass\s*"
            r"\nclass RestProtocolError\(RestTransportError\):\s*"
            r'"""Raised when a REST request fails due to protocol-level errors\."""\s*'
            r"pass\s*"
        )
        content = re.sub(pattern_with_str_http_method, "\n", content)

    # The template is emitted verbatim into the target module, so class and
    # method bodies must carry real indentation to be valid Python.
    new_exceptions = '''\
class RestTransportError(ApiException):
    """Base exception for REST transport-level errors (network, timeout, TLS)."""

    def __init__(
        self,
        status=None,
        reason=None,
        http_resp=None,
        *,
        body: Optional[str] = None,
        data: Optional[Any] = None,
        http_method: Optional[HTTPMethod] = None,
    ) -> None:
        super().__init__(
            status=status, reason=reason, http_resp=http_resp, body=body, data=data
        )
        self.http_method: Optional[HTTPMethod] = http_method


class RestTimeoutError(RestTransportError):
    """Raised when a REST request times out (connect or read timeout)."""

    pass


class RestConnectionError(RestTransportError):
    """Raised when a REST request fails to establish a connection."""

    pass


class RestTLSError(RestTransportError):
    """Raised when a REST request fails due to SSL/TLS errors."""

    pass


class RestProtocolError(RestTransportError):
    """Raised when a REST request fails due to protocol-level errors."""

    pass

'''
    # Insert before render_path function (match any arguments)
    pattern = r"(\ndef render_path\([^)]*\):)"
    replacement = new_exceptions + r"\1"
    return re.sub(pattern, replacement, content)
def patch_rest_imports(content: str) -> str:
    """Update rest.py imports to include typed transport exceptions.

    Handles both single-line and parenthesized import formats. Idempotent:
    if ``RestTLSError`` is already imported from
    ``hatchet_sdk.clients.rest.exceptions``, *content* is returned unchanged.

    Args:
        content: Source text of the generated rest module.

    Returns:
        The source with the exceptions import replaced by the full
        parenthesized import, or *content* unchanged when no known import
        form is found or the import is already up to date.
    """
    # The exceptions we need to ensure are imported
    required_exceptions = [
        "RestConnectionError",
        "RestProtocolError",
        "RestTimeoutError",
        "RestTLSError",
    ]

    # Idempotency check: if RestTLSError is already imported from this module
    # on a single line, do nothing.
    if re.search(
        r"(?m)^from\s+hatchet_sdk\.clients\.rest\.exceptions\s+import[^\n]*\bRestTLSError\b",
        content,
    ):
        return content

    # Parenthesized import block includes RestTLSError. ``[^)]*`` keeps the
    # search inside the import's own parentheses; a dot-matches-newline
    # wildcard here would also accept any later ``RestTLSError(...)`` usage
    # in the file and wrongly report the import as already present.
    if re.search(
        r"^from\s+hatchet_sdk\.clients\.rest\.exceptions\s+import\s*\([^)]*\bRestTLSError\b[^)]*\)",
        content,
        flags=re.MULTILINE,
    ):
        return content

    # The target import statement we want with trailing newline to preserve spacing
    new_import = (
        "from hatchet_sdk.clients.rest.exceptions import (\n"
        " ApiException,\n"
        " ApiValueError,\n"
        " RestConnectionError,\n"
        " RestProtocolError,\n"
        " RestTimeoutError,\n"
        " RestTLSError,\n"
        ")\n"
    )

    # Single line import
    # Matches: from hatchet_sdk.clients.rest.exceptions import ApiException, ApiValueError
    single_line_pattern = (
        r"^from\s+hatchet_sdk\.clients\.rest\.exceptions\s+import\s+"
        r"ApiException\s*,\s*ApiValueError\s*$"
    )
    modified = re.sub(single_line_pattern, new_import, content, flags=re.MULTILINE)
    if modified != content:
        return modified

    # More flexible parenthesized import which matches any order, with or
    # without trailing comma. This handles cases where ApiException and
    # ApiValueError might be in different orders.
    flexible_paren_pattern = (
        r"^from\s+hatchet_sdk\.clients\.rest\.exceptions\s+import\s*\("
        r"[^)]*?"  # Non-greedy match of contents (only ApiException/ApiValueError expected)
        r"\)"
    )
    # Only apply if the block contains just ApiException and/or ApiValueError (no Rest* yet)
    match = re.search(flexible_paren_pattern, content, flags=re.MULTILINE | re.DOTALL)
    if match:
        block = match.group(0)
        # Verify it only has ApiException/ApiValueError, not our new exceptions
        if not any(exc in block for exc in required_exceptions):
            if "ApiException" in block or "ApiValueError" in block:
                return content[: match.start()] + new_import + content[match.end() :]
    return content
def patch_rest_error_diagnostics(content: str) -> str:
    """Patch rest.py exception handlers to use typed exceptions with http_method.

    Replaces the generic ApiException handler with typed exception handlers.
    Handler ordering is critical: NewConnectionError must be caught before
    ConnectTimeoutError because it inherits from ConnectTimeoutError in urllib3.
    Each typed exception is raised with
    ``http_method=HTTPMethod(method.upper())`` for retry logic.

    The function is idempotent: if the HTTPMethod-based raise is already
    present, *content* is returned unchanged. Otherwise the HTTPMethod import
    is ensured and four known handler shapes are tried in order (typed with
    string http_method, typed without http_method, expanded multi-exception
    ApiException handler, original SSLError-only handler); the first pattern
    that changes the text wins.

    Args:
        content: Source text of the generated rest module.

    Returns:
        The patched source. If none of the patterns match, the text is
        returned with only the import added.
    """
    # This pattern matches either the original SSLError only handler or
    # the previously patched multi-exception handler raising ApiException
    # Group 1 in every pattern captures the indentation of the ``except``
    # line; the replacement re-emits it so the handlers stay aligned.
    pattern_original = (
        r"(?ms)^([ \t]*)except urllib3\.exceptions\.SSLError as e:\s*\n"
        r"^\1[ \t]*msg = \"\\n\"\.join\(\[type\(e\)\.__name__, str\(e\)\]\)\s*\n"
        r"^\1[ \t]*raise ApiException\(status=0, reason=msg\)\s*\n"
    )
    pattern_expanded = (
        r"(?ms)^([ \t]*)except \(\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.SSLError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ConnectTimeoutError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ReadTimeoutError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.MaxRetryError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.NewConnectionError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ProtocolError,\s*\n"
        r"^\1\) as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise ApiException\(status=0, reason=msg\)\s*\n"
    )
    # Check if already using typed exceptions with HTTPMethod enum
    if "http_method=HTTPMethod(method.upper())" in content:
        return content
    # The replacement below references HTTPMethod, so make sure it is imported.
    content = prepend_import(content, "from hatchet_sdk.config import HTTPMethod")
    # Pattern for typed exceptions without http_method
    pattern_typed_no_http_method = (
        r"(?ms)^([ \t]*)except urllib3\.exceptions\.SSLError as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestTLSError\(status=0, reason=msg\) from e\s*\n"
        r"^\1except \(\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.MaxRetryError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.NewConnectionError,\s*\n"
        r"^\1\) as e:\s*\n"
        r"^\1[ \t]*# NewConnectionError inherits from ConnectTimeoutError, so must be caught first\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestConnectionError\(status=0, reason=msg\) from e\s*\n"
        r"^\1except \(\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ConnectTimeoutError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ReadTimeoutError,\s*\n"
        r"^\1\) as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestTimeoutError\(status=0, reason=msg\) from e\s*\n"
        r"^\1except urllib3\.exceptions\.ProtocolError as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestProtocolError\(status=0, reason=msg\) from e\s*\n"
    )
    # Pattern for typed exceptions with string http_method (change to HTTPMethod enum)
    pattern_typed_with_string_http_method = (
        r"(?ms)^([ \t]*)except urllib3\.exceptions\.SSLError as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestTLSError\(status=0, reason=msg, http_method=method\) from e\s*\n"
        r"^\1except \(\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.MaxRetryError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.NewConnectionError,\s*\n"
        r"^\1\) as e:\s*\n"
        r"^\1[ \t]*# NewConnectionError inherits from ConnectTimeoutError, so must be caught first\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestConnectionError\(status=0, reason=msg, http_method=method\) from e\s*\n"
        r"^\1except \(\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ConnectTimeoutError,\s*\n"
        r"^\1[ \t]*urllib3\.exceptions\.ReadTimeoutError,\s*\n"
        r"^\1\) as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestTimeoutError\(status=0, reason=msg, http_method=method\) from e\s*\n"
        r"^\1except urllib3\.exceptions\.ProtocolError as e:\s*\n"
        r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
        r"^\1[ \t]*\[\s*\n"
        r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
        r"^\1[ \t]*str\(e\),\s*\n"
        r'^\1[ \t]*f"method=\{method\}",\s*\n'
        r'^\1[ \t]*f"url=\{url\}",\s*\n'
        r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
        r"^\1[ \t]*\]\s*\n"
        r"^\1[ \t]*\)\s*\n"
        r"^\1[ \t]*raise RestProtocolError\(status=0, reason=msg, http_method=method\) from e\s*\n"
    )
    # Build typed replacement with proper handler ordering and http_method
    # NewConnectionError inherits from ConnectTimeoutError, so must be caught first
    # In this re.sub template, ``\1`` re-emits the captured indentation and
    # ``\\n`` yields a literal backslash-n inside the generated source.
    replacement = (
        r"\1except urllib3.exceptions.SSLError as e:\n"
        r'\1 msg = "\\n".join(\n'
        r"\1 [\n"
        r"\1 type(e).__name__,\n"
        r"\1 str(e),\n"
        r'\1 f"method={method}",\n'
        r'\1 f"url={url}",\n'
        r'\1 f"timeout={_request_timeout}",\n'
        r"\1 ]\n"
        r"\1 )\n"
        r"\1 raise RestTLSError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
        r"\1except (\n"
        r"\1 urllib3.exceptions.MaxRetryError,\n"
        r"\1 urllib3.exceptions.NewConnectionError,\n"
        r"\1) as e:\n"
        r"\1 # NewConnectionError inherits from ConnectTimeoutError, so must be caught first\n"
        r'\1 msg = "\\n".join(\n'
        r"\1 [\n"
        r"\1 type(e).__name__,\n"
        r"\1 str(e),\n"
        r'\1 f"method={method}",\n'
        r'\1 f"url={url}",\n'
        r'\1 f"timeout={_request_timeout}",\n'
        r"\1 ]\n"
        r"\1 )\n"
        r"\1 raise RestConnectionError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
        r"\1except (\n"
        r"\1 urllib3.exceptions.ConnectTimeoutError,\n"
        r"\1 urllib3.exceptions.ReadTimeoutError,\n"
        r"\1) as e:\n"
        r'\1 msg = "\\n".join(\n'
        r"\1 [\n"
        r"\1 type(e).__name__,\n"
        r"\1 str(e),\n"
        r'\1 f"method={method}",\n'
        r'\1 f"url={url}",\n'
        r'\1 f"timeout={_request_timeout}",\n'
        r"\1 ]\n"
        r"\1 )\n"
        r"\1 raise RestTimeoutError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
        r"\1except urllib3.exceptions.ProtocolError as e:\n"
        r'\1 msg = "\\n".join(\n'
        r"\1 [\n"
        r"\1 type(e).__name__,\n"
        r"\1 str(e),\n"
        r'\1 f"method={method}",\n'
        r'\1 f"url={url}",\n'
        r'\1 f"timeout={_request_timeout}",\n'
        r"\1 ]\n"
        r"\1 )\n"
        r"\1 raise RestProtocolError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
    )
    # Try pattern for typed exceptions with string http_method (change to HTTPMethod enum)
    modified = re.sub(pattern_typed_with_string_http_method, replacement, content)
    if modified != content:
        return modified
    # Try pattern for typed exceptions without http_method
    modified = re.sub(pattern_typed_no_http_method, replacement, content)
    if modified != content:
        return modified
    # Try expanded pattern. Relevant if previously patched with ApiException
    modified = re.sub(pattern_expanded, replacement, content)
    if modified != content:
        return modified
    # Otherwise try original pattern
    return re.sub(pattern_original, replacement, content)
def apply_patches_to_matching_files(
    root: str, glob: str, patch_funcs: list[Callable[[str], str]]
) -> None:
    """Patch every file under *root* whose name matches *glob*.

    The search is recursive (``Path.rglob``); each hit is handed to
    ``atomically_patch_file`` together with *patch_funcs*.
    """
    matches = Path(root).rglob(glob)
    for matched_path in matches:
        atomically_patch_file(str(matched_path), patch_funcs)
def patch_api_client_datetime_format_on_post(content: str) -> str:
    """Make the API client's datetime serialization timezone-aware.

    The combined ``isinstance(obj, (datetime.datetime, datetime.date))``
    branch is split in two. The ``datetime.datetime`` branch attaches the
    local timezone (falling back to UTC) to naive values and logs a warning
    before calling ``isoformat()``; the ``datetime.date`` branch keeps the
    plain ``isoformat()`` call. A ``logger`` import is added for the warning.

    The replacement template is emitted verbatim after the captured
    indentation (``\\1``), so the body of the generated ``if not obj.tzinfo:``
    must be indented one level deeper than the ``if`` itself; otherwise the
    patched module does not parse.

    Args:
        content: Source text of the generated API client module.

    Returns:
        The patched source (only the import is added if the combined branch
        is not found).
    """
    content = prepend_import(content, "from hatchet_sdk.logger import logger")
    pattern = r"([ \t]*)elif isinstance\(obj, \(datetime\.datetime, datetime\.date\)\):\s*\n\1[ \t]*return obj\.isoformat\(\)"
    replacement = (
        r"\1## IMPORTANT: Checking `datetime` must come before `date` since `datetime` is a subclass of `date`\n"
        r"\1elif isinstance(obj, datetime.datetime):\n"
        r"\1    if not obj.tzinfo:\n"
        r"\1        current_tz = (datetime.datetime.now(datetime.timezone(datetime.timedelta(0))).astimezone().tzinfo or datetime.timezone.utc)\n"
        r'\1        logger.warning(f"timezone-naive datetime found. assuming {current_tz}.")\n'
        r"\1        obj = obj.replace(tzinfo=current_tz)\n\n"
        r"\1    return obj.isoformat()\n"
        r"\1elif isinstance(obj, datetime.date):\n"
        r"\1    return obj.isoformat()"
    )
    return apply_patch(content, pattern, replacement)
def patch_workflow_run_metrics_counts_return_type(content: str) -> str:
    """Retype the ``counts`` field to ``Optional[WorkflowRunsMetricsCounts]``.

    Adds the ``WorkflowRunsMetricsCounts`` import and replaces the
    ``Optional[Dict[str, Any]]`` annotation on ``counts``, keeping the
    line's original indentation.
    """
    with_import = prepend_import(
        content,
        "from hatchet_sdk.clients.rest.models.workflow_runs_metrics_counts import WorkflowRunsMetricsCounts",
    )
    loose_field = r"([ \t]*)counts: Optional\[Dict\[str, Any\]\] = None"
    typed_field = r"\1counts: Optional[WorkflowRunsMetricsCounts] = None"
    return apply_patch(with_import, loose_field, typed_field)
def patch_workflows_pb2_reexport(content: str) -> str:
    """Append backwards-compatibility re-exports from ``trigger_pb2``.

    Three names are re-exported, each with a ``noqa: F401`` marker. If the
    ``TriggerWorkflowRequest`` re-export is already present, *content* is
    returned unchanged.
    """
    if "TriggerWorkflowRequest as TriggerWorkflowRequest" in content:
        return content

    import_prefix = "from hatchet_sdk.contracts.v1.shared.trigger_pb2 import "
    aliases = (
        "TriggerWorkflowRequest as TriggerWorkflowRequest # noqa: F401\n",
        "DesiredWorkerLabels as DesiredWorkerLabels # noqa: F401\n",
        "WorkerLabelComparator as WorkerLabelComparator # noqa: F401\n",
    )
    pieces = [content, "\n# Re-export for backwards compatibility\n"]
    for alias in aliases:
        pieces.append(import_prefix + alias)
    return "".join(pieces)
def patch_workflows_pyi_reexport(content: str) -> str:
    """Append type-stub re-exports from ``trigger_pb2``.

    Three names are re-exported. If the ``TriggerWorkflowRequest`` re-export
    is already present, *content* is returned unchanged.
    """
    if "TriggerWorkflowRequest as TriggerWorkflowRequest" in content:
        return content

    import_prefix = "from hatchet_sdk.contracts.v1.shared.trigger_pb2 import "
    aliases = (
        "TriggerWorkflowRequest as TriggerWorkflowRequest\n",
        "DesiredWorkerLabels as DesiredWorkerLabels\n",
        "WorkerLabelComparator as WorkerLabelComparator\n",
    )
    pieces = [content, "\n# Re-export for backwards compatibility\n"]
    for alias in aliases:
        pieces.append(import_prefix + alias)
    return "".join(pieces)
if __name__ == "__main__":
    # Patch pipelines shared by the generated protobuf / gRPC contract modules.
    grpc_patches: list[Callable[[str], str]] = [
        patch_contract_import_paths,
        patch_grpc_dispatcher_import,
        patch_grpc_events_import,
        patch_grpc_workflows_import,
        patch_grpc_init_signature,
    ]
    pb2_patches: list[Callable[[str], str]] = [
        patch_contract_import_paths,
    ]

    # 1) REST client files, each with its own list of patches.
    rest_jobs: list[tuple[str, list[Callable[[str], str]]]] = [
        (
            "hatchet_sdk/clients/rest/api_client.py",
            [patch_api_client_datetime_format_on_post],
        ),
        (
            "hatchet_sdk/clients/rest/models/workflow_runs_metrics.py",
            [patch_workflow_run_metrics_counts_return_type],
        ),
        (
            "hatchet_sdk/clients/rest/exceptions.py",
            [patch_rest_transport_exceptions, patch_rest_429_exception],
        ),
        (
            "hatchet_sdk/clients/rest/rest.py",
            [patch_rest_imports, patch_rest_error_diagnostics],
        ),
    ]
    for target, funcs in rest_jobs:
        atomically_patch_file(target, funcs)

    # 2) Every generated contract module, selected by filename pattern.
    contract_jobs = (
        ("*_grpc.py", grpc_patches),
        ("*_pb2.py", pb2_patches),
        ("*_pb2.pyi", pb2_patches),
    )
    for name_glob, funcs in contract_jobs:
        apply_patches_to_matching_files("hatchet_sdk/contracts", name_glob, funcs)

    # 3) Re-exports; these run after the import-path rewrite above.
    reexport_jobs = (
        ("hatchet_sdk/contracts/workflows_pb2.py", [patch_workflows_pb2_reexport]),
        ("hatchet_sdk/contracts/workflows_pb2.pyi", [patch_workflows_pyi_reexport]),
    )
    for target, funcs in reexport_jobs:
        atomically_patch_file(target, funcs)