feat(sdk): tenacity rest transport retry config (#3003)

* feat(python-sdk): add opt-in retries for REST transport errors (GET/DELETE)

* feat(python-sdk): make REST transport retries configurable via TenacityConfig

* docs(python-sdk): clarify transport retry methods exclude mutating verbs by default

* refactor(sdk): avoid parsing HTTP method from REST transport error message

- add http_method field to RestTransportError
- populate http_method when translating urllib3 transport exceptions
- use http_method for transport retry gating in tenacity utils
- update unit tests to cover the new structured method propagation

* fix(sdk): move REST transport http_method changes into apply_patches

* chore(python-sdk): bump version to 1.26.2 and update changelog

* refactor(python-sdk): type retry_transport_methods as HTTPMethod enum

* refactor(python-sdk): type retry_transport_methods as HTTPMethod enum

* fix(sdk): type rest transport http_method as HTTPMethod

- Update apply_patches to emit HTTPMethod typed http_method in generated REST transport exceptions
- Normalize method values via method.upper() when constructing HTTPMethod
- Simplify tenacity transport retry check to compare enums directly
- Update transport retry tests to use HTTPMethod enums
This commit is contained in:
Trevor Wilson
2026-02-26 14:18:45 -07:00
committed by GitHub
parent deee6e213c
commit 0a9e0dab40
9 changed files with 488 additions and 19 deletions
+10
View File
@@ -5,6 +5,16 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.26.2] - 2026-02-26
### Added
- Adds `retry_transport_errors` and `retry_transport_methods` to `TenacityConfig` to optionally retry REST transport-level failures for configured HTTP methods (default: `GET`, `DELETE`). Default behavior is unchanged.
### Changed
- Uses a structured `http_method` on `RestTransportError` for determining retry eligibility.
## [1.26.1] - 2026-02-25
### Added
+206 -13
View File
@@ -135,17 +135,88 @@ def patch_rest_transport_exceptions(content: str) -> str:
"""Insert typed REST transport exception classes into exceptions.py.
Adds exception classes above render_path function, idempotently.
The RestTransportError class includes an http_method attribute for
tenacity retry logic to use without parsing the reason string.
"""
# Check if already patched
if "class RestTransportError" in content:
# Check if already patched with HTTPMethod enum support
if "Optional[HTTPMethod]" in content and "self.http_method" in content:
return content
content = prepend_import(content, "from hatchet_sdk.config import HTTPMethod")
if "class RestTransportError" in content:
# Pattern for simple classes with no __init__
pattern_simple = (
r"\nclass RestTransportError\(ApiException\):\s*"
r'"""Base exception for REST transport-level errors \(network, timeout, TLS\)\."""\s*'
r"pass\s*"
r"\nclass RestTimeoutError\(RestTransportError\):\s*"
r'"""Raised when a REST request times out \(connect or read timeout\)\."""\s*'
r"pass\s*"
r"\nclass RestConnectionError\(RestTransportError\):\s*"
r'"""Raised when a REST request fails to establish a connection\."""\s*'
r"pass\s*"
r"\nclass RestTLSError\(RestTransportError\):\s*"
r'"""Raised when a REST request fails due to SSL/TLS errors\."""\s*'
r"pass\s*"
r"\nclass RestProtocolError\(RestTransportError\):\s*"
r'"""Raised when a REST request fails due to protocol-level errors\."""\s*'
r"pass\s*"
)
content = re.sub(pattern_simple, "\n", content)
# Pattern for classes with __init__ using Optional[str] http_method
pattern_with_str_http_method = (
r"\nclass RestTransportError\(ApiException\):\s*"
r'"""Base exception for REST transport-level errors \(network, timeout, TLS\)\."""\s*'
r"def __init__\(\s*"
r"self,\s*"
r"status=None,\s*"
r"reason=None,\s*"
r"http_resp=None,\s*"
r"\*,\s*"
r"body: Optional\[str\] = None,\s*"
r"data: Optional\[Any\] = None,\s*"
r"http_method: Optional\[str\] = None,\s*"
r"\) -> None:\s*"
r"super\(\).__init__\(\s*"
r"status=status, reason=reason, http_resp=http_resp, body=body, data=data\s*"
r"\)\s*"
r"self\.http_method: Optional\[str\] = http_method\s*"
r"\nclass RestTimeoutError\(RestTransportError\):\s*"
r'"""Raised when a REST request times out \(connect or read timeout\)\."""\s*'
r"pass\s*"
r"\nclass RestConnectionError\(RestTransportError\):\s*"
r'"""Raised when a REST request fails to establish a connection\."""\s*'
r"pass\s*"
r"\nclass RestTLSError\(RestTransportError\):\s*"
r'"""Raised when a REST request fails due to SSL/TLS errors\."""\s*'
r"pass\s*"
r"\nclass RestProtocolError\(RestTransportError\):\s*"
r'"""Raised when a REST request fails due to protocol-level errors\."""\s*'
r"pass\s*"
)
content = re.sub(pattern_with_str_http_method, "\n", content)
new_exceptions = '''\
class RestTransportError(ApiException):
"""Base exception for REST transport-level errors (network, timeout, TLS)."""
pass
def __init__(
self,
status=None,
reason=None,
http_resp=None,
*,
body: Optional[str] = None,
data: Optional[Any] = None,
http_method: Optional[HTTPMethod] = None,
) -> None:
super().__init__(
status=status, reason=reason, http_resp=http_resp, body=body, data=data
)
self.http_method: Optional[HTTPMethod] = http_method
class RestTimeoutError(RestTransportError):
@@ -171,7 +242,6 @@ class RestProtocolError(RestTransportError):
pass
'''
# Insert before render_path function (match any arguments)
@@ -256,11 +326,12 @@ def patch_rest_imports(content: str) -> str:
def patch_rest_error_diagnostics(content: str) -> str:
"""Patch rest.py exception handlers to use typed exceptions.
"""Patch rest.py exception handlers to use typed exceptions with http_method.
Replaces the generic ApiException handler with typed exception handlers.
Handler ordering is critical: NewConnectionError must be caught before
ConnectTimeoutError because it inherits from ConnectTimeoutError in urllib3.
Each typed exception includes http_method=HTTPMethod(method) for retry logic.
"""
# This pattern matches either the original SSLError only handler or
# the previously patched multi-exception handler raising ApiException
@@ -291,11 +362,123 @@ def patch_rest_error_diagnostics(content: str) -> str:
r"^\1[ \t]*raise ApiException\(status=0, reason=msg\)\s*\n"
)
# Check if already using typed exceptions
if "raise RestTLSError" in content:
# Check if already using typed exceptions with HTTPMethod enum
if "http_method=HTTPMethod(method.upper())" in content:
return content
# Build typed replacement with proper handler ordering
content = prepend_import(content, "from hatchet_sdk.config import HTTPMethod")
# Pattern for typed exceptions without http_method
pattern_typed_no_http_method = (
r"(?ms)^([ \t]*)except urllib3\.exceptions\.SSLError as e:\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestTLSError\(status=0, reason=msg\) from e\s*\n"
r"^\1except \(\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.MaxRetryError,\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.NewConnectionError,\s*\n"
r"^\1\) as e:\s*\n"
r"^\1[ \t]*# NewConnectionError inherits from ConnectTimeoutError, so must be caught first\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestConnectionError\(status=0, reason=msg\) from e\s*\n"
r"^\1except \(\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.ConnectTimeoutError,\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.ReadTimeoutError,\s*\n"
r"^\1\) as e:\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestTimeoutError\(status=0, reason=msg\) from e\s*\n"
r"^\1except urllib3\.exceptions\.ProtocolError as e:\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestProtocolError\(status=0, reason=msg\) from e\s*\n"
)
# Pattern for typed exceptions with string http_method (change to HTTPMethod enum)
pattern_typed_with_string_http_method = (
r"(?ms)^([ \t]*)except urllib3\.exceptions\.SSLError as e:\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestTLSError\(status=0, reason=msg, http_method=method\) from e\s*\n"
r"^\1except \(\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.MaxRetryError,\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.NewConnectionError,\s*\n"
r"^\1\) as e:\s*\n"
r"^\1[ \t]*# NewConnectionError inherits from ConnectTimeoutError, so must be caught first\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestConnectionError\(status=0, reason=msg, http_method=method\) from e\s*\n"
r"^\1except \(\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.ConnectTimeoutError,\s*\n"
r"^\1[ \t]*urllib3\.exceptions\.ReadTimeoutError,\s*\n"
r"^\1\) as e:\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestTimeoutError\(status=0, reason=msg, http_method=method\) from e\s*\n"
r"^\1except urllib3\.exceptions\.ProtocolError as e:\s*\n"
r'^\1[ \t]*msg = "\\n"\.join\(\s*\n'
r"^\1[ \t]*\[\s*\n"
r"^\1[ \t]*type\(e\)\.__name__,\s*\n"
r"^\1[ \t]*str\(e\),\s*\n"
r'^\1[ \t]*f"method=\{method\}",\s*\n'
r'^\1[ \t]*f"url=\{url\}",\s*\n'
r'^\1[ \t]*f"timeout=\{_request_timeout\}",\s*\n'
r"^\1[ \t]*\]\s*\n"
r"^\1[ \t]*\)\s*\n"
r"^\1[ \t]*raise RestProtocolError\(status=0, reason=msg, http_method=method\) from e\s*\n"
)
# Build typed replacement with proper handler ordering and http_method
# NewConnectionError inherits from ConnectTimeoutError, so must be caught first
replacement = (
r"\1except urllib3.exceptions.SSLError as e:\n"
@@ -308,7 +491,7 @@ def patch_rest_error_diagnostics(content: str) -> str:
r'\1 f"timeout={_request_timeout}",\n'
r"\1 ]\n"
r"\1 )\n"
r"\1 raise RestTLSError(status=0, reason=msg) from e\n"
r"\1 raise RestTLSError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
r"\1except (\n"
r"\1 urllib3.exceptions.MaxRetryError,\n"
r"\1 urllib3.exceptions.NewConnectionError,\n"
@@ -323,7 +506,7 @@ def patch_rest_error_diagnostics(content: str) -> str:
r'\1 f"timeout={_request_timeout}",\n'
r"\1 ]\n"
r"\1 )\n"
r"\1 raise RestConnectionError(status=0, reason=msg) from e\n"
r"\1 raise RestConnectionError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
r"\1except (\n"
r"\1 urllib3.exceptions.ConnectTimeoutError,\n"
r"\1 urllib3.exceptions.ReadTimeoutError,\n"
@@ -337,7 +520,7 @@ def patch_rest_error_diagnostics(content: str) -> str:
r'\1 f"timeout={_request_timeout}",\n'
r"\1 ]\n"
r"\1 )\n"
r"\1 raise RestTimeoutError(status=0, reason=msg) from e\n"
r"\1 raise RestTimeoutError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
r"\1except urllib3.exceptions.ProtocolError as e:\n"
r'\1 msg = "\\n".join(\n'
r"\1 [\n"
@@ -348,10 +531,20 @@ def patch_rest_error_diagnostics(content: str) -> str:
r'\1 f"timeout={_request_timeout}",\n'
r"\1 ]\n"
r"\1 )\n"
r"\1 raise RestProtocolError(status=0, reason=msg) from e\n"
r"\1 raise RestProtocolError(status=0, reason=msg, http_method=HTTPMethod(method.upper())) from e\n"
)
# Try expanded pattern first. Relevant if previously patched with ApiException
# Try pattern for typed exceptions with string http_method (change to HTTPMethod enum)
modified = re.sub(pattern_typed_with_string_http_method, replacement, content)
if modified != content:
return modified
# Try pattern for typed exceptions without http_method
modified = re.sub(pattern_typed_no_http_method, replacement, content)
if modified != content:
return modified
# Try expanded pattern. Relevant if previously patched with ApiException
modified = re.sub(pattern_expanded, replacement, content)
if modified != content:
return modified
@@ -11,6 +11,7 @@ Generated by OpenAPI Generator (https://openapi-generator.tech)
Do not edit the class manually.
""" # noqa: E501
from hatchet_sdk.config import HTTPMethod
from typing import Any, Optional
from typing_extensions import Self
@@ -221,7 +222,20 @@ class TooManyRequestsException(ApiException):
class RestTransportError(ApiException):
"""Base exception for REST transport-level errors (network, timeout, TLS)."""
pass
def __init__(
self,
status=None,
reason=None,
http_resp=None,
*,
body: Optional[str] = None,
data: Optional[Any] = None,
http_method: Optional[HTTPMethod] = None,
) -> None:
super().__init__(
status=status, reason=reason, http_resp=http_resp, body=body, data=data
)
self.http_method: Optional[HTTPMethod] = http_method
class RestTimeoutError(RestTransportError):
+13 -4
View File
@@ -11,6 +11,7 @@ Generated by OpenAPI Generator (https://openapi-generator.tech)
Do not edit the class manually.
""" # noqa: E501
from hatchet_sdk.config import HTTPMethod
import io
import json
import re
@@ -256,7 +257,9 @@ class RESTClientObject:
f"timeout={_request_timeout}",
]
)
raise RestTLSError(status=0, reason=msg) from e
raise RestTLSError(
status=0, reason=msg, http_method=HTTPMethod(method.upper())
) from e
except (
urllib3.exceptions.MaxRetryError,
urllib3.exceptions.NewConnectionError,
@@ -271,7 +274,9 @@ class RESTClientObject:
f"timeout={_request_timeout}",
]
)
raise RestConnectionError(status=0, reason=msg) from e
raise RestConnectionError(
status=0, reason=msg, http_method=HTTPMethod(method.upper())
) from e
except (
urllib3.exceptions.ConnectTimeoutError,
urllib3.exceptions.ReadTimeoutError,
@@ -285,7 +290,9 @@ class RESTClientObject:
f"timeout={_request_timeout}",
]
)
raise RestTimeoutError(status=0, reason=msg) from e
raise RestTimeoutError(
status=0, reason=msg, http_method=HTTPMethod(method.upper())
) from e
except urllib3.exceptions.ProtocolError as e:
msg = "\n".join(
[
@@ -296,5 +303,7 @@ class RESTClientObject:
f"timeout={_request_timeout}",
]
)
raise RestProtocolError(status=0, reason=msg) from e
raise RestProtocolError(
status=0, reason=msg, http_method=HTTPMethod(method.upper())
) from e
return RESTResponse(r)
@@ -8,6 +8,7 @@ import tenacity
from hatchet_sdk.clients.rest.exceptions import (
NotFoundException,
RestTransportError,
ServiceException,
TooManyRequestsException,
)
@@ -54,6 +55,7 @@ def tenacity_should_retry(
if isinstance(ex, TooManyRequestsException):
return bool(config and config.retry_429)
# gRPC errors: retry most, except specific permanent failure codes
if isinstance(ex, grpc.aio.AioRpcError | grpc.RpcError):
return ex.code() not in [
grpc.StatusCode.UNIMPLEMENTED,
@@ -64,4 +66,12 @@ def tenacity_should_retry(
grpc.StatusCode.PERMISSION_DENIED,
]
# REST transport errors: opt-in retry for configured HTTP methods
if isinstance(ex, RestTransportError):
if config is not None and config.retry_transport_errors:
method = ex.http_method
if method is not None:
return method in config.retry_transport_methods
return False
return False
+19
View File
@@ -1,5 +1,6 @@
import json
from datetime import timedelta
from enum import Enum
from logging import Logger, getLogger
from typing import overload
@@ -87,6 +88,16 @@ class OpenTelemetryConfig(BaseSettings):
include_task_name_in_start_step_run_span_name: bool = False
class HTTPMethod(str, Enum):
GET = "GET"
DELETE = "DELETE"
POST = "POST"
PUT = "PUT"
PATCH = "PATCH"
HEAD = "HEAD"
OPTIONS = "OPTIONS"
class TenacityConfig(BaseSettings):
model_config = create_settings_config(
env_prefix="HATCHET_CLIENT_TENACITY_",
@@ -98,6 +109,14 @@ class TenacityConfig(BaseSettings):
default=False,
description="Enable retries for HTTP 429 Too Many Requests responses. Default: off.",
)
retry_transport_errors: bool = Field(
default=False,
description="Enable retries for REST transport errors (timeout, connection, TLS). Default: off.",
)
retry_transport_methods: list[HTTPMethod] = Field(
default_factory=lambda: [HTTPMethod.GET, HTTPMethod.DELETE],
description="HTTP methods to retry on transport errors when retry_transport_errors is enabled; excludes POST/PUT/PATCH by default due to idempotency concerns.",
)
DEFAULT_HOST_PORT = "localhost:7070"
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "1.26.1"
version = "1.26.2"
description = "This is the official Python SDK for Hatchet, a distributed, fault-tolerant task queue. The SDK allows you to easily integrate Hatchet's task scheduling and workflow orchestration capabilities into your Python applications."
authors = [
"Alexander Belanger <alexander@hatchet.run>",
@@ -384,3 +384,84 @@ def test_diagnostics__reason_handles_none_timeout(
reason = exc_info.value.reason
assert "timeout=None" in reason
# --- http_method attribute tests ---
def test_http_method__tls_error_has_http_method(
rest_client: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
def mock_request(*args: Any, **kwargs: Any) -> NoReturn:
raise urllib3.exceptions.SSLError("SSL failed")
monkeypatch.setattr(rest_client.pool_manager, "request", mock_request)
with pytest.raises(RestTLSError) as exc_info:
rest_client.request(
method="POST",
url="https://example.com/api",
headers={},
)
assert exc_info.value.http_method == "POST"
def test_http_method__timeout_error_has_http_method(
rest_client: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
def mock_request(*args: Any, **kwargs: Any) -> NoReturn:
raise urllib3.exceptions.ConnectTimeoutError(None, "url", "timeout")
monkeypatch.setattr(rest_client.pool_manager, "request", mock_request)
with pytest.raises(RestTimeoutError) as exc_info:
rest_client.request(
method="GET",
url="http://localhost/test",
headers={},
)
assert exc_info.value.http_method == "GET"
def test_http_method__connection_error_has_http_method(
rest_client: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
def mock_request(*args: Any, **kwargs: Any) -> NoReturn:
raise urllib3.exceptions.NewConnectionError(
cast(Any, None), "connection failed"
)
monkeypatch.setattr(rest_client.pool_manager, "request", mock_request)
with pytest.raises(RestConnectionError) as exc_info:
rest_client.request(
method="DELETE",
url="http://localhost/test",
headers={},
)
assert exc_info.value.http_method == "DELETE"
def test_http_method__protocol_error_has_http_method(
rest_client: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
def mock_request(*args: Any, **kwargs: Any) -> NoReturn:
raise urllib3.exceptions.ProtocolError("protocol error")
monkeypatch.setattr(rest_client.pool_manager, "request", mock_request)
with pytest.raises(RestProtocolError) as exc_info:
rest_client.request(
method="PUT",
url="http://localhost/test",
headers={},
)
assert exc_info.value.http_method == "PUT"
@@ -0,0 +1,133 @@
"""Unit tests for tenacity transport error retry behavior.
Tests verify:
1. Default behavior: RestTransportError is NOT retried (even for GET)
2. Opt-in behavior: RestTransportError retried for configured methods only
3. Existing HTTP error retry behavior unchanged
4. HTTP method is read from exception's http_method attribute
"""
import pytest
from hatchet_sdk.clients.rest.exceptions import (
NotFoundException,
RestTimeoutError,
RestTransportError,
ServiceException,
)
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_should_retry
from hatchet_sdk.config import TenacityConfig, HTTPMethod
# --- Default behavior tests (transport errors NOT retried) ---
@pytest.mark.parametrize(
"exc_class",
[RestTransportError, RestTimeoutError],
ids=["base-class", "subclass"],
)
def test_default__transport_errors_not_retried(exc_class: type) -> None:
"""By default, RestTransportError and subclasses should not be retried."""
exc = exc_class(status=0, reason="timeout", http_method=HTTPMethod.GET)
config = TenacityConfig()
assert tenacity_should_retry(exc, config) is False
# --- Opt-in behavior tests (transport errors retried for allowed methods) ---
@pytest.mark.parametrize(
"method", [HTTPMethod.GET, HTTPMethod.DELETE], ids=["get", "delete"]
)
def test_optin__idempotent_methods_retried(method: HTTPMethod) -> None:
"""When enabled, GET and DELETE requests with transport errors should be retried."""
exc = RestTimeoutError(status=0, reason="timeout", http_method=method)
config = TenacityConfig(retry_transport_errors=True)
assert tenacity_should_retry(exc, config) is True
@pytest.mark.parametrize(
"method",
[HTTPMethod.POST, HTTPMethod.PUT, HTTPMethod.PATCH],
ids=["post", "put", "patch"],
)
def test_optin__non_idempotent_methods_not_retried(method: HTTPMethod) -> None:
"""Non-idempotent requests should not be retried even when transport retry is enabled."""
exc = RestTimeoutError(status=0, reason="timeout", http_method=method)
config = TenacityConfig(retry_transport_errors=True)
assert tenacity_should_retry(exc, config) is False
def test_optin__custom_methods_list() -> None:
"""Custom retry_transport_methods should be honored."""
exc = RestTimeoutError(status=0, reason="timeout", http_method=HTTPMethod.POST)
config = TenacityConfig(
retry_transport_errors=True,
retry_transport_methods=[HTTPMethod.POST],
)
assert tenacity_should_retry(exc, config) is True
def test_optin__custom_methods_excludes_default() -> None:
"""Custom retry_transport_methods can exclude default methods like GET."""
exc = RestTimeoutError(status=0, reason="timeout", http_method=HTTPMethod.GET)
config = TenacityConfig(
retry_transport_errors=True,
retry_transport_methods=[HTTPMethod.DELETE],
)
assert tenacity_should_retry(exc, config) is False
# --- Regression tests: existing HTTP error retry behavior unchanged ---
@pytest.mark.parametrize(
("exc", "desc"),
[
(ServiceException(status=500, reason="Internal Server Error"), "5xx"),
(NotFoundException(status=404, reason="Not Found"), "404"),
],
ids=["service-exception", "not-found"],
)
def test_regression__http_errors_still_retried(exc: Exception, desc: str) -> None:
"""ServiceException (5xx) and NotFoundException (404) should still be retried."""
config = TenacityConfig()
assert tenacity_should_retry(exc, config) is True
def test_regression__backward_compat_no_config() -> None:
"""ServiceException should be retried even without config (backward compat)."""
exc = ServiceException(status=500, reason="Internal Server Error")
assert tenacity_should_retry(exc) is True
# --- Edge cases for retry behavior ---
def test_edge__no_http_method_not_retried() -> None:
"""If http_method is None, should not retry even with retry_transport_errors=True."""
exc = RestTimeoutError(status=0, reason="timeout", http_method=None)
config = TenacityConfig(retry_transport_errors=True)
assert tenacity_should_retry(exc, config) is False
def test_edge__enum_method_matching() -> None:
"""Method matching uses HTTPMethod enum values directly."""
exc = RestTimeoutError(status=0, reason="timeout", http_method=HTTPMethod.GET)
config = TenacityConfig(retry_transport_errors=True)
assert tenacity_should_retry(exc, config) is True
# --- Config defaults tests ---
def test_config__default_retry_transport_errors_is_false() -> None:
"""retry_transport_errors should default to False."""
config = TenacityConfig()
assert config.retry_transport_errors is False
def test_config__default_retry_transport_methods() -> None:
"""retry_transport_methods should default to GET and DELETE."""
config = TenacityConfig()
assert set(config.retry_transport_methods) == {HTTPMethod.GET, HTTPMethod.DELETE}