Revert: Cancellation token Python changes (#3061)

* revert: cancellation token changes

* fix: changelog

* fix: add note on yank
This commit is contained in:
matt
2026-02-19 12:41:43 -05:00
committed by GitHub
parent 1361047d11
commit 0cce1cfc04
20 changed files with 74 additions and 1516 deletions
+1 -20
View File
@@ -1,7 +1,7 @@
import asyncio
import time
from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet
from hatchet_sdk import Context, EmptyModel, Hatchet
hatchet = Hatchet(debug=True)
@@ -40,25 +40,6 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
# > Handling cancelled error
@cancellation_workflow.task()
async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
    """Example task showing how to handle a Hatchet ``CancelledError``.

    The long sleep is expected to be interrupted by cancellation; the
    handler performs cleanup and re-raises so Hatchet can finish the
    cancellation. Note the separate ``except Exception`` branch does NOT
    catch ``CancelledError`` (it derives from ``BaseException``).
    """
    try:
        await asyncio.sleep(10)
    except CancelledError as e:
        # Handle parent cancellation - i.e. perform cleanup, then re-raise
        print(f"Parent Task cancelled: {e.reason}")
        # Always re-raise CancelledError so Hatchet can properly handle the cancellation
        raise
    except Exception as e:
        # This will NOT catch CancelledError
        print(f"Other error: {e}")
        raise
    else:
        # Only reached if the sleep was never interrupted.
        return {"error": "Task should have been cancelled"}
def main() -> None:
    """Start a worker that hosts the cancellation example workflow."""
    w = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow])
    w.start()
+1 -1
View File
@@ -10,7 +10,7 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
@hatchet.durable_task()
async def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
return {"result": "Hello, world!"}
+2 -25
View File
@@ -22,38 +22,15 @@ When a task is canceled, Hatchet sends a cancellation signal to the task. The ta
/>
### CancelledError Exception
When a sync task is cancelled while waiting for a child workflow or during a cancellation-aware operation, a `CancelledError` exception is raised.
<Callout type="warning">
**Important:** `CancelledError` inherits from `BaseException`, not
`Exception`. This means it will **not** be caught by bare `except Exception:`
handlers. This is intentional and mirrors the behavior of Python's
`asyncio.CancelledError`.
</Callout>
<Snippet src={snippets.python.cancellation.worker.handling_cancelled_error} />
### Cancellation Reasons
The `CancelledError` includes a `reason` attribute that indicates why the cancellation occurred:
| Reason | Description |
| --------------------------------------- | --------------------------------------------------------------------- |
| `CancellationReason.USER_REQUESTED` | The user explicitly requested cancellation via `ctx.cancel()` |
| `CancellationReason.WORKFLOW_CANCELLED` | The workflow run was cancelled (e.g., via API or concurrency control) |
| `CancellationReason.PARENT_CANCELLED` | The parent workflow was cancelled while waiting for a child |
| `CancellationReason.TIMEOUT` | The operation timed out |
| `CancellationReason.UNKNOWN` | Unknown or unspecified reason |
</Tabs.Tab>
<Tabs.Tab title="Typescript">
<Snippet
src={snippets.typescript.cancellations.workflow.declaring_a_task}
/>
<Snippet
src={snippets.typescript.cancellations.workflow.abort_signal}
/>
</Tabs.Tab>
+8 -5
View File
@@ -5,12 +5,19 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.25.2] - 2026-02-19
### Fixed
- Reverts cancellation changes in 1.25.0 that introduced a regression
## [1.25.1] - 2026-02-17
### Fixed
- Fixes internal registration of durable slots
## [1.25.0] - 2026-02-17
## [1.25.0] - 2026-02-17 **YANKED ON 2/19/26**
### Added
@@ -47,28 +54,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adds type-hinted `Task.output_validator` and `Task.output_validator_type` properties to support easier type-safety and match the patterns on `BaseWorkflow/Standalone`.
- Adds parameterized unit tests documenting current retry behavior of the Python SDK's tenacity retry predicate for REST and gRPC errors.
## [1.23.2] - 2026-02-11
### Changed
- Improves error handling for REST transport-level failures by raising typed exceptions for timeouts, connection, TLS, and protocol errors while preserving existing diagnostics.
## [1.23.1] - 2026-02-10
### Changed
- Fixes a bug introduced in v1.21.0 where the `BaseWorkflow.input_validator` class property became incorrectly typed. Now separate properties are available for the type adapter and the underlying type.
## [1.23.0] - 2026-02-05
### Internal Only
- Updated gRPC/REST contract field names to snake_case for consistency across SDKs.
## [1.22.16] - 2026-02-05
### Changed
+1 -21
View File
@@ -1,7 +1,7 @@
import asyncio
import time
from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet
from hatchet_sdk import Context, EmptyModel, Hatchet
hatchet = Hatchet(debug=True)
@@ -42,26 +42,6 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
# !!
# > Handling cancelled error
@cancellation_workflow.task()
async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
    """Example task showing how to handle a Hatchet ``CancelledError``.

    The sleep is expected to be interrupted by cancellation; the handler
    performs cleanup and re-raises so Hatchet can complete the
    cancellation. ``except Exception`` deliberately does not catch
    ``CancelledError`` (it derives from ``BaseException``).
    """
    try:
        await asyncio.sleep(10)
    except CancelledError as e:
        # Handle parent cancellation - i.e. perform cleanup, then re-raise
        print(f"Parent Task cancelled: {e.reason}")
        # Always re-raise CancelledError so Hatchet can properly handle the cancellation
        raise
    except Exception as e:
        # This will NOT catch CancelledError
        print(f"Other error: {e}")
        raise
    else:
        # Only reached if the sleep was never interrupted.
        return {"error": "Task should have been cancelled"}
# !!
def main() -> None:
    """Boot a worker serving the cancellation example workflow."""
    cancellation_worker = hatchet.worker(
        "cancellation-worker", workflows=[cancellation_workflow]
    )
    cancellation_worker.start()
+1 -1
View File
@@ -10,7 +10,7 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
@hatchet.durable_task()
async def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
return {"result": "Hello, world!"}
-6
View File
@@ -1,4 +1,3 @@
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import (
RunStatus,
ScheduleTriggerWorkflowOptions,
@@ -156,8 +155,6 @@ from hatchet_sdk.contracts.workflows_pb2 import (
WorkerLabelComparator,
)
from hatchet_sdk.exceptions import (
CancellationReason,
CancelledError,
DedupeViolationError,
FailedTaskRunExceptionGroup,
NonRetryableException,
@@ -197,9 +194,6 @@ __all__ = [
"CELEvaluationResult",
"CELFailure",
"CELSuccess",
"CancellationReason",
"CancellationToken",
"CancelledError",
"ClientConfig",
"ClientTLSConfig",
"ConcurrencyExpression",
-197
View File
@@ -1,197 +0,0 @@
"""Cancellation token for coordinating cancellation across async and sync operations."""
from __future__ import annotations
import asyncio
import threading
from collections.abc import Callable
from typing import TYPE_CHECKING
from hatchet_sdk.exceptions import CancellationReason
from hatchet_sdk.logger import logger
if TYPE_CHECKING:
pass
class CancellationToken:
    """A token that signals cancellation to both async and sync waiters.

    The token keeps a ``threading.Event`` for sync code and a lazily created
    ``asyncio.Event`` for async code, so the same token works on either code
    path. Child workflow run IDs may be registered so that cancelling the
    parent can also cancel its children, and arbitrary callbacks can be
    attached to fire on cancellation.

    Example:
        ```python
        token = CancellationToken()

        # In async code
        await token.aio_wait()  # Blocks until cancelled

        # In sync code
        token.wait(timeout=1.0)  # Returns True if cancelled within timeout

        # Check if cancelled
        if token.is_cancelled:
            raise CancelledError("Operation was cancelled")

        # Trigger cancellation
        token.cancel()
        ```
    """

    def __init__(self) -> None:
        # Guarded by self._lock: _cancelled, _reason, _child_run_ids, _callbacks.
        self._lock = threading.Lock()
        self._cancelled = False
        self._reason: CancellationReason | None = None
        # Created on first use so no event loop is required at construction.
        self._async_event: asyncio.Event | None = None
        self._sync_event = threading.Event()
        self._child_run_ids: list[str] = []
        self._callbacks: list[Callable[[], None]] = []

    def _get_async_event(self) -> asyncio.Event:
        """Return the asyncio event, creating it lazily on first access."""
        event = self._async_event
        if event is None:
            event = asyncio.Event()
            # Mirror current state if cancellation already happened.
            if self._cancelled:
                event.set()
            self._async_event = event
        return event

    def cancel(
        self, reason: CancellationReason = CancellationReason.TOKEN_CANCELLED
    ) -> None:
        """
        Trigger cancellation.

        Sets the cancelled flag and reason, signals both the async and sync
        events, and invokes all registered callbacks. A second call is a
        no-op.

        Args:
            reason: The reason for cancellation.
        """
        with self._lock:
            if self._cancelled:
                logger.debug(
                    f"CancellationToken: cancel() called but already cancelled, "
                    f"reason={self._reason.value if self._reason else 'none'}"
                )
                return
            logger.debug(
                f"CancellationToken: cancel() called, reason={reason.value}, "
                f"{len(self._child_run_ids)} children registered"
            )
            self._cancelled = True
            self._reason = reason
            # Signal both event types
            if self._async_event is not None:
                self._async_event.set()
            self._sync_event.set()
            # Snapshot callbacks under the lock, invoke outside to avoid deadlocks
            callbacks = list(self._callbacks)
        for callback in callbacks:
            try:
                logger.debug(f"CancellationToken: invoking callback {callback}")
                callback()
            except Exception as e:  # noqa: PERF203
                logger.warning(f"CancellationToken: callback raised exception: {e}")
        logger.debug(f"CancellationToken: cancel() complete, reason={reason.value}")

    @property
    def is_cancelled(self) -> bool:
        """Check if cancellation has been triggered."""
        return self._cancelled

    @property
    def reason(self) -> CancellationReason | None:
        """Get the reason for cancellation, or None if not cancelled."""
        return self._reason

    async def aio_wait(self) -> None:
        """
        Await until cancelled (for use in asyncio).

        Blocks until cancel() is called.
        """
        event = self._get_async_event()
        await event.wait()
        logger.debug(
            f"CancellationToken: async wait completed (cancelled), "
            f"reason={self._reason.value if self._reason else 'none'}"
        )

    def wait(self, timeout: float | None = None) -> bool:
        """
        Block until cancelled (for use in sync code).

        Args:
            timeout: Maximum time to wait in seconds. None means wait forever.

        Returns:
            True if the token was cancelled (event was set), False if timeout expired.
        """
        cancelled = self._sync_event.wait(timeout)
        if cancelled:
            logger.debug(
                f"CancellationToken: sync wait interrupted by cancellation, "
                f"reason={self._reason.value if self._reason else 'none'}"
            )
        return cancelled

    def register_child(self, run_id: str) -> None:
        """
        Register a child workflow run ID with this token.

        When the parent is cancelled, these child run IDs can be used to
        cancel the child workflows as well.

        Args:
            run_id: The workflow run ID of the child workflow.
        """
        with self._lock:
            logger.debug(f"CancellationToken: registering child workflow {run_id}")
            self._child_run_ids.append(run_id)

    @property
    def child_run_ids(self) -> list[str]:
        """The registered child workflow run IDs."""
        return self._child_run_ids

    def add_callback(self, callback: Callable[[], None]) -> None:
        """
        Register a callback to be invoked when cancellation is triggered.

        If the token is already cancelled, the callback is invoked
        immediately instead of being stored.

        Args:
            callback: A callable that takes no arguments.
        """
        with self._lock:
            invoke_now = self._cancelled
            if not invoke_now:
                self._callbacks.append(callback)
        # NOTE(review): invoked outside the lock to match cancel()'s
        # deadlock-avoidance pattern — confirm against the original layout.
        if invoke_now:
            logger.debug(
                f"CancellationToken: invoking callback immediately (already cancelled): {callback}"
            )
            try:
                callback()
            except Exception as e:
                logger.warning(f"CancellationToken: callback raised exception: {e}")

    def __repr__(self) -> str:
        return (
            f"CancellationToken(cancelled={self._cancelled}, "
            f"children={len(self._child_run_ids)}, callbacks={len(self._callbacks)})"
        )
@@ -1,9 +1,7 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Generic, Literal, TypeVar
from typing import Generic, Literal, TypeVar
import grpc
import grpc.aio
@@ -16,10 +14,6 @@ from hatchet_sdk.clients.event_ts import (
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.logger import logger
from hatchet_sdk.metadata import get_metadata
from hatchet_sdk.utils.cancellation import race_against_token
if TYPE_CHECKING:
from hatchet_sdk.cancellation import CancellationToken
DEFAULT_LISTENER_RETRY_INTERVAL = 3 # seconds
DEFAULT_LISTENER_RETRY_COUNT = 5
@@ -42,7 +36,7 @@ class Subscription(Generic[T]):
self.id = id
self.queue: asyncio.Queue[T | SentinelValue] = asyncio.Queue()
async def __aiter__(self) -> Subscription[T]:
async def __aiter__(self) -> "Subscription[T]":
return self
async def __anext__(self) -> T | SentinelValue:
@@ -205,17 +199,7 @@ class PooledListener(Generic[R, T, L], ABC):
del self.from_subscriptions[subscription_id]
del self.events[subscription_id]
async def subscribe(
self, id: str, cancellation_token: CancellationToken | None = None
) -> T:
"""
Subscribe to events for the given ID.
:param id: The ID to subscribe to (e.g., workflow run ID).
:param cancellation_token: Optional cancellation token to abort the subscription wait.
:return: The event received for this ID.
:raises asyncio.CancelledError: If the cancellation token is triggered or if externally cancelled.
"""
async def subscribe(self, id: str) -> T:
subscription_id: int | None = None
try:
@@ -237,17 +221,8 @@ class PooledListener(Generic[R, T, L], ABC):
if not self.listener_task or self.listener_task.done():
self.listener_task = asyncio.create_task(self._init_producer())
logger.debug(
f"PooledListener.subscribe: waiting for event on id={id}, "
f"subscription_id={subscription_id}, token={cancellation_token is not None}"
)
if cancellation_token:
result_task = asyncio.create_task(self.events[subscription_id].get())
return await race_against_token(result_task, cancellation_token)
return await self.events[subscription_id].get()
except asyncio.CancelledError:
logger.debug(f"PooledListener.subscribe: externally cancelled for id={id}")
raise
finally:
if subscription_id:
+1 -32
View File
@@ -52,7 +52,7 @@ class HealthcheckConfig(BaseSettings):
if isinstance(value, timedelta):
return value
if isinstance(value, (int, float)):
if isinstance(value, int | float):
return timedelta(seconds=float(value))
v = value.strip()
@@ -135,37 +135,6 @@ class ClientConfig(BaseSettings):
force_shutdown_on_shutdown_signal: bool = False
tenacity: TenacityConfig = TenacityConfig()
# Cancellation configuration
cancellation_grace_period: timedelta = Field(
default=timedelta(milliseconds=1000),
description="The maximum time to wait for a task to complete after cancellation is triggered before force-cancelling. Value is interpreted as seconds when provided as int/float.",
)
cancellation_warning_threshold: timedelta = Field(
default=timedelta(milliseconds=300),
description="If a task has not completed cancellation within this duration, a warning will be logged. Value is interpreted as seconds when provided as int/float.",
)
@field_validator(
"cancellation_grace_period", "cancellation_warning_threshold", mode="before"
)
@classmethod
def validate_cancellation_timedelta(
cls, value: timedelta | int | float | str
) -> timedelta:
"""Convert int/float/string to timedelta, interpreting as seconds."""
if isinstance(value, timedelta):
return value
if isinstance(value, (int, float)):
return timedelta(seconds=float(value))
v = value.strip()
# Allow a small convenience suffix, but keep "seconds" as the contract.
if v.endswith("s"):
v = v[:-1].strip()
return timedelta(seconds=float(v))
@model_validator(mode="after")
def validate_token_and_tenant(self) -> "ClientConfig":
if not self.token:
+13 -78
View File
@@ -4,7 +4,6 @@ from datetime import timedelta
from typing import TYPE_CHECKING, Any, cast
from warnings import warn
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import AdminClient
from hatchet_sdk.clients.dispatcher.dispatcher import ( # type: ignore[attr-defined]
Action,
@@ -22,10 +21,9 @@ from hatchet_sdk.conditions import (
flatten_conditions,
)
from hatchet_sdk.context.worker_context import WorkerContext
from hatchet_sdk.exceptions import CancellationReason, TaskRunError
from hatchet_sdk.exceptions import TaskRunError
from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.logger import logger
from hatchet_sdk.utils.cancellation import await_with_cancellation
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel
from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord
@@ -58,7 +56,7 @@ class Context:
self.action = action
self.step_run_id = action.step_run_id
self.cancellation_token = CancellationToken()
self.exit_flag = False
self.dispatcher_client = dispatcher_client
self.admin_client = admin_client
self.event_client = event_client
@@ -76,31 +74,6 @@ class Context:
self._workflow_name = workflow_name
self._task_name = task_name
@property
def exit_flag(self) -> bool:
"""
Check if the cancellation flag has been set.
This property is maintained for backwards compatibility.
Use `cancellation_token.is_cancelled` for new code.
:return: True if the task has been cancelled, False otherwise.
"""
return self.cancellation_token.is_cancelled
@exit_flag.setter
def exit_flag(self, value: bool) -> None:
"""
Set the cancellation flag.
This setter is maintained for backwards compatibility.
Setting to True will trigger the cancellation token.
:param value: True to trigger cancellation, False is a no-op.
"""
if value:
self.cancellation_token.cancel(CancellationReason.USER_REQUESTED)
def _increment_stream_index(self) -> int:
index = self.stream_index
self.stream_index += 1
@@ -196,25 +169,8 @@ class Context:
"""
return self.action.workflow_run_id
def _set_cancellation_flag(
self, reason: CancellationReason = CancellationReason.WORKFLOW_CANCELLED
) -> None:
"""
Internal method to trigger cancellation.
This triggers the cancellation token, which will:
- Signal all waiters (async and sync)
- Set the exit_flag property to True
- Allow child workflow cancellation
Args:
reason: The reason for cancellation.
"""
logger.debug(
f"Context: setting cancellation flag for step_run_id={self.step_run_id}, "
f"reason={reason.value}"
)
self.cancellation_token.cancel(reason)
def _set_cancellation_flag(self) -> None:
self.exit_flag = True
def cancel(self) -> None:
"""
@@ -222,11 +178,9 @@ class Context:
:return: None
"""
logger.debug(
f"Context: cancel() called for task_run_external_id={self.step_run_id}"
)
logger.debug("cancelling step...")
self.runs_client.cancel(self.step_run_id)
self._set_cancellation_flag(CancellationReason.USER_REQUESTED)
self._set_cancellation_flag()
async def aio_cancel(self) -> None:
"""
@@ -234,11 +188,9 @@ class Context:
:return: None
"""
logger.debug(
f"Context: aio_cancel() called for task_run_external_id={self.step_run_id}"
)
logger.debug("cancelling step...")
await self.runs_client.aio_cancel(self.step_run_id)
self._set_cancellation_flag(CancellationReason.USER_REQUESTED)
self._set_cancellation_flag()
def done(self) -> bool:
"""
@@ -530,11 +482,8 @@ class DurableContext(Context):
"""
Durably wait for either a sleep or an event.
This method respects the context's cancellation token. If the task is cancelled
while waiting, an asyncio.CancelledError will be raised.
:param signal_key: The key to use for the durable event. This is used to identify the event in the Hatchet API.
:param \\*conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition.
:param *conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition.
:return: A dictionary containing the results of the wait.
:raises ValueError: If the durable event listener is not available.
@@ -544,10 +493,6 @@ class DurableContext(Context):
task_id = self.step_run_id
logger.debug(
f"DurableContext.aio_wait_for: waiting for signal_key={signal_key}, task_id={task_id}"
)
request = RegisterDurableEventRequest(
task_id=task_id,
signal_key=signal_key,
@@ -557,29 +502,19 @@ class DurableContext(Context):
self.durable_event_listener.register_durable_event(request)
# Use await_with_cancellation to respect the cancellation token
return await await_with_cancellation(
self.durable_event_listener.result(task_id, signal_key),
self.cancellation_token,
return await self.durable_event_listener.result(
task_id,
signal_key,
)
async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]:
"""
Lightweight wrapper for durable sleep. Allows for shorthand usage of `ctx.aio_wait_for` when specifying a sleep condition.
This method respects the context's cancellation token. If the task is cancelled
while sleeping, an asyncio.CancelledError will be raised.
For more complicated conditions, use `ctx.aio_wait_for` directly.
:param duration: The duration to sleep for.
:return: A dictionary containing the results of the wait.
"""
wait_index = self._increment_wait_index()
logger.debug(
f"DurableContext.aio_sleep_for: sleeping for {duration}, wait_index={wait_index}"
)
wait_index = self._increment_wait_index()
return await self.aio_wait_for(
f"sleep:{timedelta_to_expr(duration)}-{wait_index}",
-52
View File
@@ -1,6 +1,5 @@
import json
import traceback
from enum import Enum
from typing import cast
@@ -171,54 +170,3 @@ class IllegalTaskOutputError(Exception):
class LifespanSetupError(Exception):
pass
class CancellationReason(Enum):
    """Enumerates why an operation was cancelled."""

    # The user explicitly requested cancellation.
    USER_REQUESTED = "user_requested"
    # The operation timed out.
    TIMEOUT = "timeout"
    # The parent workflow or task was cancelled.
    PARENT_CANCELLED = "parent_cancelled"
    # The workflow run was cancelled.
    WORKFLOW_CANCELLED = "workflow_cancelled"
    # The cancellation token was cancelled.
    TOKEN_CANCELLED = "token_cancelled"
class CancelledError(BaseException):
    """
    Raised when an operation is cancelled via a CancellationToken.

    This inherits from BaseException (not Exception) so a bare
    `except Exception:` handler will not swallow it — the same design as
    asyncio.CancelledError in Python 3.8+. Catch it explicitly:

    - `except CancelledError:` (recommended)
    - `except BaseException:` (catches all exceptions)

    This exception is used on sync code paths; async code paths raise
    asyncio.CancelledError instead.

    :param message: Optional message describing the cancellation.
    :param reason: Optional enum indicating the reason for cancellation.
    """

    def __init__(
        self,
        message: str = "Operation cancelled",
        reason: CancellationReason | None = None,
    ) -> None:
        super().__init__(message)
        self.reason = reason

    @property
    def message(self) -> str:
        # args[0] holds the message passed to __init__, if any.
        if self.args:
            return str(self.args[0])
        return "Operation cancelled"
@@ -1,17 +1,11 @@
from __future__ import annotations
import asyncio
import threading
from collections import Counter
from contextvars import ContextVar
from typing import TYPE_CHECKING
from hatchet_sdk.runnables.action import ActionKey
from hatchet_sdk.utils.typing import JSONSerializableMapping
if TYPE_CHECKING:
from hatchet_sdk.cancellation import CancellationToken
ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
"ctx_workflow_run_id", default=None
)
@@ -26,9 +20,6 @@ ctx_additional_metadata: ContextVar[JSONSerializableMapping | None] = ContextVar
ctx_task_retry_count: ContextVar[int | None] = ContextVar(
"ctx_task_retry_count", default=0
)
ctx_cancellation_token: ContextVar[CancellationToken | None] = ContextVar(
"ctx_cancellation_token", default=None
)
workflow_spawn_indices = Counter[ActionKey]()
spawn_index_lock = asyncio.Lock()
+19 -222
View File
@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
import json
from collections.abc import Callable
@@ -39,11 +37,8 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import (
)
from hatchet_sdk.contracts.v1.workflows_pb2 import StickyStrategy as StickyStrategyProto
from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
from hatchet_sdk.exceptions import CancellationReason, CancelledError
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk.logger import logger
from hatchet_sdk.rate_limit import RateLimit
from hatchet_sdk.runnables.contextvars import ctx_cancellation_token
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import (
ConcurrencyExpression,
@@ -57,7 +52,6 @@ from hatchet_sdk.runnables.types import (
normalize_validator,
)
from hatchet_sdk.serde import HATCHET_PYDANTIC_SENTINEL
from hatchet_sdk.utils.cancellation import await_with_cancellation
from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
from hatchet_sdk.utils.timedelta_to_expression import Duration
from hatchet_sdk.utils.typing import CoroutineLike, JSONSerializableMapping
@@ -65,7 +59,6 @@ from hatchet_sdk.workflow_run import WorkflowRunRef
if TYPE_CHECKING:
from hatchet_sdk import Hatchet
from hatchet_sdk.cancellation import CancellationToken
T = TypeVar("T")
@@ -95,7 +88,7 @@ class ComputedTaskParameters(BaseModel):
task_defaults: TaskDefaults
@model_validator(mode="after")
def validate_params(self) -> ComputedTaskParameters:
def validate_params(self) -> "ComputedTaskParameters":
self.execution_timeout = fall_back_to_default(
value=self.execution_timeout,
param_default=timedelta(seconds=60),
@@ -143,7 +136,7 @@ class TypedTriggerWorkflowRunConfig(BaseModel, Generic[TWorkflowInput]):
class BaseWorkflow(Generic[TWorkflowInput]):
def __init__(self, config: WorkflowConfig, client: Hatchet) -> None:
def __init__(self, config: WorkflowConfig, client: "Hatchet") -> None:
self.config = config
self._default_tasks: list[Task[TWorkflowInput, Any]] = []
self._durable_tasks: list[Task[TWorkflowInput, Any]] = []
@@ -632,38 +625,6 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
and can be arranged into complex dependency patterns.
"""
def _resolve_check_cancellation_token(self) -> CancellationToken | None:
cancellation_token = ctx_cancellation_token.get()
if cancellation_token and cancellation_token.is_cancelled:
raise CancelledError(
"Operation cancelled by cancellation token",
reason=CancellationReason.TOKEN_CANCELLED,
)
return cancellation_token
def _register_child_with_token(
self,
cancellation_token: CancellationToken | None,
workflow_run_id: str,
) -> None:
if not cancellation_token:
return
cancellation_token.register_child(workflow_run_id)
def _register_children_with_token(
self,
cancellation_token: CancellationToken | None,
refs: list[WorkflowRunRef],
) -> None:
if not cancellation_token:
return
for ref in refs:
cancellation_token.register_child(ref.workflow_run_id)
def run_no_wait(
self,
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
@@ -673,34 +634,17 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
Synchronously trigger a workflow run without waiting for it to complete.
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
If a cancellation token is available via context, the child workflow will be registered
with the token.
:param input: The input data for the workflow.
:param options: Additional options for workflow execution.
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.run_no_wait: triggering {self.config.name}, "
f"token={cancellation_token is not None}"
)
ref = self.client._client.admin.run_workflow(
return self.client._client.admin.run_workflow(
workflow_name=self.config.name,
input=self._serialize_input(input),
options=self._create_options_with_combined_additional_meta(options),
)
self._register_child_with_token(
cancellation_token,
ref.workflow_run_id,
)
return ref
def run(
self,
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
@@ -710,19 +654,12 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
Run the workflow synchronously and wait for it to complete.
This method triggers a workflow run, blocks until completion, and returns the final result.
If a cancellation token is available via context, the wait can be interrupted.
:param input: The input data for the workflow, must match the workflow's input type.
:param options: Additional options for workflow execution like metadata and parent workflow ID.
:returns: The result of the workflow execution as a dictionary.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.run: triggering {self.config.name}, "
f"token={cancellation_token is not None}"
)
ref = self.client._client.admin.run_workflow(
workflow_name=self.config.name,
@@ -730,14 +667,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
options=self._create_options_with_combined_additional_meta(options),
)
self._register_child_with_token(
cancellation_token,
ref.workflow_run_id,
)
logger.debug(f"Workflow.run: awaiting result for {ref.workflow_run_id}")
return ref.result(cancellation_token=cancellation_token)
return ref.result()
async def aio_run_no_wait(
self,
@@ -748,34 +678,18 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
Asynchronously trigger a workflow run without waiting for it to complete.
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
If a cancellation token is available via context, the child workflow will be registered
with the token.
:param input: The input data for the workflow.
:param options: Additional options for workflow execution.
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.aio_run_no_wait: triggering {self.config.name}, "
f"token={cancellation_token is not None}"
)
ref = await self.client._client.admin.aio_run_workflow(
return await self.client._client.admin.aio_run_workflow(
workflow_name=self.config.name,
input=self._serialize_input(input),
options=self._create_options_with_combined_additional_meta(options),
)
self._register_child_with_token(
cancellation_token,
ref.workflow_run_id,
)
return ref
async def aio_run(
self,
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
@@ -785,47 +699,25 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
Run the workflow asynchronously and wait for it to complete.
This method triggers a workflow run, awaits until completion, and returns the final result.
If a cancellation token is available via context, the wait can be interrupted.
:param input: The input data for the workflow, must match the workflow's input type.
:param options: Additional options for workflow execution like metadata and parent workflow ID.
:returns: The result of the workflow execution as a dictionary.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.aio_run: triggering {self.config.name}, "
f"token={cancellation_token is not None}"
)
ref = await self.client._client.admin.aio_run_workflow(
workflow_name=self.config.name,
input=self._serialize_input(input),
options=self._create_options_with_combined_additional_meta(options),
)
self._register_child_with_token(
cancellation_token,
ref.workflow_run_id,
)
logger.debug(f"Workflow.aio_run: awaiting result for {ref.workflow_run_id}")
return await await_with_cancellation(
ref.aio_result(),
cancellation_token,
)
return await ref.aio_result()
def _get_result(
self,
ref: WorkflowRunRef,
return_exceptions: bool,
self, ref: WorkflowRunRef, return_exceptions: bool
) -> dict[str, Any] | BaseException:
try:
return ref.result(
cancellation_token=self._resolve_check_cancellation_token()
)
return ref.result()
except Exception as e:
if return_exceptions:
return e
@@ -854,52 +746,15 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
Run a workflow in bulk and wait for all runs to complete.
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
If a cancellation token is available via context, all child workflows will be registered
with the token and the wait can be interrupted.
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
:returns: A list of results for each workflow run.
:raises CancelledError: If the cancellation token is triggered (and return_exceptions is False).
:raises Exception: If a workflow run fails (and return_exceptions is False).
"""
cancellation_token = self._resolve_check_cancellation_token()
refs = self.client._client.admin.run_workflows(
workflows=workflows,
)
self._register_children_with_token(
cancellation_token,
refs,
)
# Pass cancellation_token through to each result() call
# The cancellation check happens INSIDE result()'s polling loop
results: list[dict[str, Any] | BaseException] = []
for ref in refs:
try:
results.append(ref.result(cancellation_token=cancellation_token))
except CancelledError: # noqa: PERF203
logger.debug(
f"Workflow.run_many: cancellation detected, stopping wait, "
f"reason={CancellationReason.PARENT_CANCELLED.value}"
)
if return_exceptions:
results.append(
CancelledError(
"Operation cancelled by cancellation token",
reason=CancellationReason.PARENT_CANCELLED,
)
)
break
raise
except Exception as e:
if return_exceptions:
results.append(e)
else:
raise
return results
return [self._get_result(ref, return_exceptions) for ref in refs]
@overload
async def aio_run_many(
@@ -924,34 +779,16 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
Run a workflow in bulk and wait for all runs to complete.
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
If a cancellation token is available via context, all child workflows will be registered
with the token and the wait can be interrupted.
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
:returns: A list of results for each workflow run.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.aio_run_many: triggering {len(workflows)} workflows, "
f"token={cancellation_token is not None}"
)
refs = await self.client._client.admin.aio_run_workflows(
workflows=workflows,
)
self._register_children_with_token(
cancellation_token,
refs,
)
return await await_with_cancellation(
asyncio.gather(
*[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
),
cancellation_token,
return await asyncio.gather(
*[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
)
def run_many_no_wait(
@@ -963,30 +800,13 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
If a cancellation token is available via context, all child workflows will be registered
with the token.
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.run_many_no_wait: triggering {len(workflows)} workflows, "
f"token={cancellation_token is not None}"
)
refs = self.client._client.admin.run_workflows(
return self.client._client.admin.run_workflows(
workflows=workflows,
)
self._register_children_with_token(
cancellation_token,
refs,
)
return refs
async def aio_run_many_no_wait(
self,
workflows: list[WorkflowRunTriggerConfig],
@@ -996,31 +816,14 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
If a cancellation token is available via context, all child workflows will be registered
with the token.
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
"""
cancellation_token = self._resolve_check_cancellation_token()
logger.debug(
f"Workflow.aio_run_many_no_wait: triggering {len(workflows)} workflows, "
f"token={cancellation_token is not None}"
)
refs = await self.client._client.admin.aio_run_workflows(
return await self.client._client.admin.aio_run_workflows(
workflows=workflows,
)
self._register_children_with_token(
cancellation_token,
refs,
)
return refs
def _parse_task_name(
self,
name: str | None,
@@ -1365,7 +1168,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
return inner
def add_task(self, task: Standalone[TWorkflowInput, Any]) -> None:
def add_task(self, task: "Standalone[TWorkflowInput, Any]") -> None:
"""
Add a task to a workflow. Intended to be used with a previously existing task (a Standalone),
such as one created with `@hatchet.task()`, which has been converted to a `Task` object using `to_task`.
@@ -1404,7 +1207,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
class TaskRunRef(Generic[TWorkflowInput, R]):
def __init__(
self,
standalone: Standalone[TWorkflowInput, R],
standalone: "Standalone[TWorkflowInput, R]",
workflow_run_ref: WorkflowRunRef,
):
self._s = standalone
@@ -1563,9 +1366,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
) -> list[R]: ...
def run_many(
self,
workflows: list[WorkflowRunTriggerConfig],
return_exceptions: bool = False,
self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
) -> list[R] | list[R | BaseException]:
"""
Run a workflow in bulk and wait for all runs to complete.
@@ -1599,9 +1400,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
) -> list[R]: ...
async def aio_run_many(
self,
workflows: list[WorkflowRunTriggerConfig],
return_exceptions: bool = False,
self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
) -> list[R] | list[R | BaseException]:
"""
Run a workflow in bulk and wait for all runs to complete.
@@ -1621,8 +1420,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
]
def run_many_no_wait(
self,
workflows: list[WorkflowRunTriggerConfig],
self, workflows: list[WorkflowRunTriggerConfig]
) -> list[TaskRunRef[TWorkflowInput, R]]:
"""
Run a workflow in bulk without waiting for all runs to complete.
@@ -1637,8 +1435,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
async def aio_run_many_no_wait(
self,
workflows: list[WorkflowRunTriggerConfig],
self, workflows: list[WorkflowRunTriggerConfig]
) -> list[TaskRunRef[TWorkflowInput, R]]:
"""
Run a workflow in bulk without waiting for all runs to complete.
@@ -1,149 +0,0 @@
"""Utilities for cancellation-aware operations."""
from __future__ import annotations
import asyncio
import contextlib
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, TypeVar
from hatchet_sdk.logger import logger
if TYPE_CHECKING:
from hatchet_sdk.cancellation import CancellationToken
T = TypeVar("T")
async def _invoke_cancel_callback(
    cancel_callback: Callable[[], Awaitable[None]] | None,
) -> None:
    """Run *cancel_callback* when one was supplied; otherwise do nothing."""
    if cancel_callback:
        await cancel_callback()
async def race_against_token(
    main_task: asyncio.Task[T],
    token: CancellationToken,
) -> T:
    """
    Race an asyncio task against a cancellation token.

    Waits for either the task to complete or the token to be cancelled. Cleans up
    whichever side loses the race.

    Args:
        main_task: The asyncio task to race.
        token: The cancellation token to race against.

    Returns:
        The result of the main task if it completes first.

    Raises:
        asyncio.CancelledError: If the token fires before the task completes.
    """
    # Waiter that completes when the token is cancelled; it races main_task below.
    cancel_task = asyncio.create_task(token.aio_wait())
    try:
        done, pending = await asyncio.wait(
            [main_task, cancel_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
        # Cancel pending tasks (the loser of the race), and await each one so
        # its cancellation is fully processed before we return or raise.
        for task in pending:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
        if cancel_task in done:
            # The token fired first: surface that to the caller as a CancelledError.
            raise asyncio.CancelledError("Operation cancelled by cancellation token")
        # main_task finished first; .result() re-raises its exception, if any.
        return main_task.result()
    except asyncio.CancelledError:
        # Ensure both tasks are cleaned up on any cancellation (external or token)
        main_task.cancel()
        cancel_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await main_task
        with contextlib.suppress(asyncio.CancelledError):
            await cancel_task
        raise
async def await_with_cancellation(
    coro: Awaitable[T],
    token: CancellationToken | None,
    cancel_callback: Callable[[], Awaitable[None]] | None = None,
) -> T:
    """
    Await an awaitable with cancellation support.

    This function races the given awaitable against a cancellation token. If the
    token is cancelled before the awaitable completes, the awaitable is cancelled
    and an asyncio.CancelledError is raised.

    Args:
        coro: The awaitable to await (coroutine, Future, or asyncio.Task).
        token: The cancellation token to check. If None, the coroutine is awaited directly.
        cancel_callback: An optional async callback to invoke when cancellation occurs
            (e.g., to cancel child workflows).

    Returns:
        The result of the coroutine.

    Raises:
        asyncio.CancelledError: If the token is cancelled before the coroutine completes.

    Example:
        ```python
        async def cleanup() -> None:
            print("cleaning up...")

        async def long_running_task():
            await asyncio.sleep(10)
            return "done"

        token = CancellationToken()

        # This will raise asyncio.CancelledError if token.cancel() is called
        result = await await_with_cancellation(
            long_running_task(),
            token,
            cancel_callback=cleanup,
        )
        ```
    """
    if token is None:
        logger.debug("await_with_cancellation: no token provided, awaiting directly")
        return await coro
    logger.debug("await_with_cancellation: starting with cancellation token")
    # Check if already cancelled -- fail fast without scheduling the awaitable.
    # NOTE(review): on this path *coro* is dropped un-awaited; if it is a bare
    # coroutine CPython will emit a "coroutine was never awaited" warning.
    if token.is_cancelled:
        logger.debug("await_with_cancellation: token already cancelled")
        if cancel_callback:
            logger.debug("await_with_cancellation: invoking cancel callback")
            await _invoke_cancel_callback(cancel_callback)
        raise asyncio.CancelledError("Operation cancelled by cancellation token")
    # ensure_future accepts coroutines, Futures, and Tasks alike.
    main_task = asyncio.ensure_future(coro)
    try:
        result = await race_against_token(main_task, token)
        logger.debug("await_with_cancellation: completed successfully")
        return result
    except asyncio.CancelledError:
        logger.debug("await_with_cancellation: cancelled")
        if cancel_callback:
            logger.debug("await_with_cancellation: invoking cancel callback")
            # shield() keeps the callback running even while this task is being
            # cancelled; suppress() swallows a re-delivered CancelledError.
            with contextlib.suppress(asyncio.CancelledError):
                await asyncio.shield(_invoke_cancel_callback(cancel_callback))
        raise
+13 -91
View File
@@ -2,7 +2,6 @@ import asyncio
import ctypes
import functools
import json
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, is_dataclass
@@ -30,7 +29,6 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
STEP_EVENT_TYPE_STARTED,
)
from hatchet_sdk.exceptions import (
CancellationReason,
IllegalTaskOutputError,
NonRetryableException,
TaskRunError,
@@ -41,7 +39,6 @@ from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
from hatchet_sdk.runnables.contextvars import (
ctx_action_key,
ctx_additional_metadata,
ctx_cancellation_token,
ctx_step_run_id,
ctx_task_retry_count,
ctx_worker_id,
@@ -63,7 +60,6 @@ from hatchet_sdk.worker.runner.utils.capture_logs import (
ContextVarToCopyDict,
ContextVarToCopyInt,
ContextVarToCopyStr,
ContextVarToCopyToken,
copy_context_vars,
)
@@ -255,7 +251,6 @@ class Runner:
ctx_action_key.set(action.key)
ctx_additional_metadata.set(action.additional_metadata)
ctx_task_retry_count.set(action.retry_count)
ctx_cancellation_token.set(ctx.cancellation_token)
async with task._unpack_dependencies_with_cleanup(ctx) as dependencies:
try:
@@ -303,12 +298,6 @@ class Runner:
value=action.retry_count,
)
),
ContextVarToCopy(
var=ContextVarToCopyToken(
name="ctx_cancellation_token",
value=ctx.cancellation_token,
)
),
],
self.thread_action_func,
ctx,
@@ -491,95 +480,28 @@ class Runner:
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
async def handle_cancel_action(self, action: Action) -> None:
key = action.key
start_time = time.monotonic()
logger.info(
f"Cancellation: received cancel action for {action.action_id}, "
f"reason={CancellationReason.WORKFLOW_CANCELLED.value}"
)
try:
# Trigger the cancellation token to signal the context to stop
# call cancel to signal the context to stop
if key in self.contexts:
ctx = self.contexts[key]
child_count = len(ctx.cancellation_token.child_run_ids)
logger.debug(
f"Cancellation: triggering token for {action.action_id}, "
f"reason={CancellationReason.WORKFLOW_CANCELLED.value}, "
f"{child_count} children registered"
)
ctx._set_cancellation_flag(CancellationReason.WORKFLOW_CANCELLED)
self.contexts[key]._set_cancellation_flag()
self.cancellations[key] = True
# Note: Child workflows are not cancelled here - they run independently
# and are managed by Hatchet's normal cancellation mechanisms
else:
logger.debug(f"Cancellation: no context found for {action.action_id}")
# Wait with supervision (using timedelta configs)
grace_period = self.config.cancellation_grace_period.total_seconds()
warning_threshold = (
self.config.cancellation_warning_threshold.total_seconds()
)
grace_period_ms = round(grace_period * 1000)
warning_threshold_ms = round(warning_threshold * 1000)
await asyncio.sleep(1)
# Wait until warning threshold
await asyncio.sleep(warning_threshold)
elapsed = time.monotonic() - start_time
elapsed_ms = round(elapsed * 1000)
if key in self.tasks:
self.tasks[key].cancel()
# Check if the task has not yet exited despite the cancellation signal.
task_still_running = key in self.tasks and not self.tasks[key].done()
# check if thread is still running, if so, print a warning
if key in self.threads:
thread = self.threads[key]
if self.config.enable_force_kill_sync_threads:
self.force_kill_thread(thread)
await asyncio.sleep(1)
if task_still_running:
logger.warning(
f"Cancellation: task {action.action_id} has not cancelled after "
f"{elapsed_ms}ms (warning threshold {warning_threshold_ms}ms). "
f"Consider checking for blocking operations. "
f"See https://docs.hatchet.run/home/cancellation"
f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
)
remaining = grace_period - elapsed
if remaining > 0:
await asyncio.sleep(remaining)
if key in self.tasks and not self.tasks[key].done():
logger.debug(
f"Cancellation: force-cancelling task {action.action_id} "
f"after grace period ({grace_period_ms}ms)"
)
self.tasks[key].cancel()
if key in self.threads:
thread = self.threads[key]
if self.config.enable_force_kill_sync_threads:
logger.debug(
f"Cancellation: force-killing thread for {action.action_id}"
)
self.force_kill_thread(thread)
await asyncio.sleep(1)
if thread.is_alive():
logger.warning(
f"Cancellation: thread {thread.ident} with key {key} is still running "
f"after cancellation. This could cause the thread pool to get blocked "
f"and prevent new tasks from running."
)
total_elapsed = time.monotonic() - start_time
total_elapsed_ms = round(total_elapsed * 1000)
if total_elapsed > grace_period:
logger.warning(
f"Cancellation: cancellation of {action.action_id} took {total_elapsed_ms}ms "
f"(exceeded grace period of {grace_period_ms}ms)"
)
else:
logger.debug(
f"Cancellation: task {action.action_id} eventually completed in {total_elapsed_ms}ms"
)
else:
logger.info(f"Cancellation: task {action.action_id} completed")
finally:
self.cleanup_run_id(key)
@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
import functools
import logging
@@ -7,15 +5,13 @@ from collections.abc import Awaitable, Callable
from io import StringIO
from typing import Literal, ParamSpec, TypeVar
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.logger import logger
from hatchet_sdk.runnables.contextvars import (
ctx_action_key,
ctx_additional_metadata,
ctx_cancellation_token,
ctx_step_run_id,
ctx_task_retry_count,
ctx_worker_id,
@@ -52,22 +48,10 @@ class ContextVarToCopyDict(BaseModel):
value: JSONSerializableMapping | None
class ContextVarToCopyToken(BaseModel):
"""Special type for copying CancellationToken to threads."""
model_config = ConfigDict(arbitrary_types_allowed=True)
name: Literal["ctx_cancellation_token"]
value: CancellationToken | None
class ContextVarToCopy(BaseModel):
var: (
ContextVarToCopyStr
| ContextVarToCopyDict
| ContextVarToCopyInt
| ContextVarToCopyToken
) = Field(discriminator="name")
var: ContextVarToCopyStr | ContextVarToCopyDict | ContextVarToCopyInt = Field(
discriminator="name"
)
def copy_context_vars(
@@ -89,8 +73,6 @@ def copy_context_vars(
ctx_worker_id.set(var.var.value)
elif var.var.name == "ctx_additional_metadata":
ctx_additional_metadata.set(var.var.value or {})
elif var.var.name == "ctx_cancellation_token":
ctx_cancellation_token.set(var.var.value)
else:
raise ValueError(f"Unknown context variable name: {var.var.name}")
+6 -95
View File
@@ -1,5 +1,3 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any
@@ -8,17 +6,9 @@ from hatchet_sdk.clients.listeners.run_event_listener import (
RunEventListenerClient,
)
from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.exceptions import (
CancellationReason,
CancelledError,
FailedTaskRunExceptionGroup,
TaskRunError,
)
from hatchet_sdk.logger import logger
from hatchet_sdk.utils.cancellation import await_with_cancellation
from hatchet_sdk.exceptions import FailedTaskRunExceptionGroup, TaskRunError
if TYPE_CHECKING:
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import AdminClient
@@ -28,7 +18,7 @@ class WorkflowRunRef:
workflow_run_id: str,
workflow_run_listener: PooledWorkflowRunListener,
workflow_run_event_listener: RunEventListenerClient,
admin_client: AdminClient,
admin_client: "AdminClient",
):
self.workflow_run_id = workflow_run_id
self.workflow_run_listener = workflow_run_listener
@@ -41,25 +31,7 @@ class WorkflowRunRef:
def stream(self) -> RunEventListener:
return self.workflow_run_event_listener.stream(self.workflow_run_id)
async def aio_result(
self, cancellation_token: CancellationToken | None = None
) -> dict[str, Any]:
"""
Asynchronously wait for the workflow run to complete and return the result.
:param cancellation_token: Optional cancellation token to abort the wait.
:return: A dictionary mapping task names to their outputs.
"""
logger.debug(
f"WorkflowRunRef.aio_result: waiting for {self.workflow_run_id}, "
f"token={cancellation_token is not None}"
)
if cancellation_token:
return await await_with_cancellation(
self.workflow_run_listener.aio_result(self.workflow_run_id),
cancellation_token,
)
async def aio_result(self) -> dict[str, Any]:
return await self.workflow_run_listener.aio_result(self.workflow_run_id)
def _safely_get_action_name(self, action_id: str | None) -> str | None:
@@ -71,42 +43,12 @@ class WorkflowRunRef:
except IndexError:
return None
def result(
self, cancellation_token: CancellationToken | None = None
) -> dict[str, Any]:
"""
Synchronously wait for the workflow run to complete and return the result.
This method polls the API for the workflow run status. If a cancellation token
is provided, the polling will be interrupted when cancellation is triggered.
:param cancellation_token: Optional cancellation token to abort the wait.
:return: A dictionary mapping task names to their outputs.
:raises CancelledError: If the cancellation token is triggered.
:raises FailedTaskRunExceptionGroup: If the workflow run fails.
:raises ValueError: If the workflow run is not found.
"""
def result(self) -> dict[str, Any]:
from hatchet_sdk.clients.admin import RunStatus
logger.debug(
f"WorkflowRunRef.result: waiting for {self.workflow_run_id}, "
f"token={cancellation_token is not None}"
)
retries = 0
while True:
# Check cancellation at start of each iteration
if cancellation_token and cancellation_token.is_cancelled:
logger.debug(
f"WorkflowRunRef.result: cancellation detected for {self.workflow_run_id}, "
f"reason={CancellationReason.PARENT_CANCELLED.value}"
)
raise CancelledError(
"Operation cancelled by cancellation token",
reason=CancellationReason.PARENT_CANCELLED,
)
try:
details = self.admin_client.get_details(self.workflow_run_id)
except Exception as e:
@@ -117,42 +59,14 @@ class WorkflowRunRef:
f"Workflow run {self.workflow_run_id} not found"
) from e
# Use interruptible sleep via token.wait()
if cancellation_token:
if cancellation_token.wait(timeout=1.0):
logger.debug(
f"WorkflowRunRef.result: cancellation during retry sleep for {self.workflow_run_id}, "
f"reason={CancellationReason.PARENT_CANCELLED.value}"
)
raise CancelledError(
"Operation cancelled by cancellation token",
reason=CancellationReason.PARENT_CANCELLED,
) from None
else:
time.sleep(1)
time.sleep(1)
continue
logger.debug(
f"WorkflowRunRef.result: {self.workflow_run_id} status={details.status}"
)
if (
details.status in [RunStatus.QUEUED, RunStatus.RUNNING]
or details.done is False
):
# Use interruptible sleep via token.wait()
if cancellation_token:
if cancellation_token.wait(timeout=1.0):
logger.debug(
f"WorkflowRunRef.result: cancellation during poll sleep for {self.workflow_run_id}, "
f"reason={CancellationReason.PARENT_CANCELLED.value}"
)
raise CancelledError(
"Operation cancelled by cancellation token",
reason=CancellationReason.PARENT_CANCELLED,
)
else:
time.sleep(1)
time.sleep(1)
continue
if details.status == RunStatus.FAILED:
@@ -166,9 +80,6 @@ class WorkflowRunRef:
)
if details.status == RunStatus.COMPLETED:
logger.debug(
f"WorkflowRunRef.result: {self.workflow_run_id} completed successfully"
)
return {
readable_id: run.output
for readable_id, run in details.task_runs.items()
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "1.25.1"
version = "1.25.2"
description = "This is the official Python SDK for Hatchet, a distributed, fault-tolerant task queue. The SDK allows you to easily integrate Hatchet's task scheduling and workflow orchestration capabilities into your Python applications."
authors = [
"Alexander Belanger <alexander@hatchet.run>",
-461
View File
@@ -1,461 +0,0 @@
"""Unit tests for CancellationToken and cancellation utilities."""
import asyncio
import threading
import time
import pytest
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.exceptions import CancellationReason, CancelledError
from hatchet_sdk.runnables.contextvars import ctx_cancellation_token
from hatchet_sdk.utils.cancellation import await_with_cancellation
# CancellationToken
def test_initial_state() -> None:
    """A freshly constructed token must not report cancellation."""
    tok = CancellationToken()
    assert tok.is_cancelled is False


def test_cancel_sets_flag() -> None:
    """Calling cancel() flips is_cancelled to True."""
    tok = CancellationToken()
    tok.cancel()
    assert tok.is_cancelled is True


def test_cancel_sets_reason() -> None:
    """The reason passed to cancel() is recorded on the token."""
    tok = CancellationToken()
    tok.cancel(CancellationReason.USER_REQUESTED)
    assert tok.reason == CancellationReason.USER_REQUESTED


def test_reason_is_none_before_cancel() -> None:
    """Until cancel() runs, the token carries no reason."""
    tok = CancellationToken()
    assert tok.reason is None


def test_cancel_idempotent() -> None:
    """Repeated cancel() calls are harmless after the first."""
    tok = CancellationToken()
    tok.cancel()
    tok.cancel()  # second call must not raise
    assert tok.is_cancelled is True


def test_cancel_idempotent_preserves_reason() -> None:
    """A second cancel() must not overwrite the first recorded reason."""
    tok = CancellationToken()
    tok.cancel(CancellationReason.USER_REQUESTED)
    tok.cancel(CancellationReason.TIMEOUT)  # ignored: token already cancelled
    assert tok.reason == CancellationReason.USER_REQUESTED
# Sync wait() behaviour.
# NOTE(review): these tests are wall-clock sensitive (0.1s/0.5s thresholds)
# and may flake on heavily loaded CI machines.


def test_sync_wait_returns_true_when_cancelled() -> None:
    """wait() should return True immediately if already cancelled."""
    token = CancellationToken()
    token.cancel()
    result = token.wait(timeout=0.1)
    assert result is True


def test_sync_wait_timeout_returns_false() -> None:
    """wait() should return False when timeout expires without cancellation."""
    token = CancellationToken()
    start = time.monotonic()
    result = token.wait(timeout=0.1)
    elapsed = time.monotonic() - start
    assert result is False
    assert elapsed >= 0.1  # the full timeout must elapse before returning


def test_sync_wait_interrupted_by_cancel() -> None:
    """wait() should return True when cancelled during wait."""
    token = CancellationToken()

    def cancel_after_delay() -> None:
        # Background thread cancels the token ~0.1s into the 1.0s wait.
        time.sleep(0.1)
        token.cancel()

    thread = threading.Thread(target=cancel_after_delay)
    thread.start()
    start = time.monotonic()
    result = token.wait(timeout=1.0)
    elapsed = time.monotonic() - start
    thread.join()
    assert result is True
    assert elapsed < 0.5  # Should be much faster than timeout
# Async wait, child registration, callbacks, and repr.


@pytest.mark.asyncio
async def test_aio_wait_returns_when_cancelled() -> None:
    """aio_wait() should return when cancelled."""
    token = CancellationToken()

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    # Fire-and-forget task cancels the token shortly after we start waiting.
    asyncio.create_task(cancel_after_delay())
    start = time.monotonic()
    await token.aio_wait()
    elapsed = time.monotonic() - start
    assert elapsed < 0.5  # Should be fast


def test_register_child() -> None:
    """register_child() should add run IDs to the list."""
    token = CancellationToken()
    token.register_child("run-1")
    token.register_child("run-2")
    # Registration order is preserved.
    assert token.child_run_ids == ["run-1", "run-2"]


def test_callback_invoked_on_cancel() -> None:
    """Callbacks should be invoked when cancel() is called."""
    token = CancellationToken()
    called = []

    def callback() -> None:
        called.append(True)

    token.add_callback(callback)
    token.cancel()
    assert called == [True]


def test_callback_invoked_immediately_if_already_cancelled() -> None:
    """Callbacks added after cancellation should be invoked immediately."""
    token = CancellationToken()
    token.cancel()
    called = []

    def callback() -> None:
        called.append(True)

    token.add_callback(callback)
    # No further cancel() needed: add_callback fires synchronously here.
    assert called == [True]


def test_multiple_callbacks() -> None:
    """Multiple callbacks should all be invoked."""
    token = CancellationToken()
    results: list[int] = []
    token.add_callback(lambda: results.append(1))
    token.add_callback(lambda: results.append(2))
    token.add_callback(lambda: results.append(3))
    token.cancel()
    # Callbacks run in registration order.
    assert results == [1, 2, 3]


def test_repr() -> None:
    """__repr__ should provide useful debugging info."""
    token = CancellationToken()
    token.register_child("run-1")
    repr_str = repr(token)
    assert "cancelled=False" in repr_str
    assert "children=1" in repr_str
# await_with_cancellation


@pytest.mark.asyncio
async def test_no_token_awaits_directly() -> None:
    """Without a token, coroutine should be awaited directly."""

    async def simple_coro() -> str:
        return "result"

    result = await await_with_cancellation(simple_coro(), None)
    assert result == "result"


@pytest.mark.asyncio
async def test_token_not_cancelled_returns_result() -> None:
    """With a non-cancelled token, should return coroutine result."""
    token = CancellationToken()

    async def simple_coro() -> str:
        await asyncio.sleep(0.01)
        return "result"

    result = await await_with_cancellation(simple_coro(), token)
    assert result == "result"


@pytest.mark.asyncio
async def test_already_cancelled_raises_immediately() -> None:
    """With an already-cancelled token, should raise immediately."""
    token = CancellationToken()
    token.cancel()

    async def simple_coro() -> str:
        await asyncio.sleep(10)  # Would block if actually awaited
        return "result"

    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(simple_coro(), token)


@pytest.mark.asyncio
async def test_cancellation_during_await_raises() -> None:
    """Should raise CancelledError when token is cancelled during await."""
    token = CancellationToken()

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    asyncio.create_task(cancel_after_delay())
    start = time.monotonic()
    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(slow_coro(), token)
    elapsed = time.monotonic() - start
    assert elapsed < 0.5  # Should be cancelled quickly
@pytest.mark.asyncio
async def test_cancel_callback_invoked() -> None:
    """Cancel callback should be invoked on cancellation."""
    token = CancellationToken()
    callback_called = []

    async def cancel_callback() -> None:
        callback_called.append(True)

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    asyncio.create_task(cancel_after_delay())
    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(
            slow_coro(), token, cancel_callback=cancel_callback
        )
    assert callback_called == [True]


# NOTE(review): despite its name, this test is identical to
# test_cancel_callback_invoked above (the callback here is still async);
# it presumably was meant to exercise a *sync* callback -- confirm intent.
@pytest.mark.asyncio
async def test_sync_cancel_callback_invoked() -> None:
    """Cancel callback should be invoked on cancellation."""
    token = CancellationToken()
    callback_called = []

    async def cancel_callback() -> None:
        callback_called.append(True)

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    asyncio.create_task(cancel_after_delay())
    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(
            slow_coro(), token, cancel_callback=cancel_callback
        )
    assert callback_called == [True]


@pytest.mark.asyncio
async def test_cancel_callback_invoked_on_external_task_cancel() -> None:
    """Cancel callback should be invoked if the awaiting task is cancelled externally."""
    token = CancellationToken()
    callback_called = asyncio.Event()

    async def cancel_callback() -> None:
        callback_called.set()

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    task = asyncio.create_task(
        await_with_cancellation(slow_coro(), token, cancel_callback=cancel_callback)
    )
    await asyncio.sleep(0.1)
    # Cancel the wrapping task itself (not the token).
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task
    # The shielded callback should still run to completion.
    await asyncio.wait_for(callback_called.wait(), timeout=1.0)


@pytest.mark.asyncio
async def test_cancel_callback_not_invoked_on_success() -> None:
    """Cancel callback should NOT be invoked when coroutine completes normally."""
    token = CancellationToken()
    callback_called = []

    async def cancel_callback() -> None:
        callback_called.append(True)

    async def fast_coro() -> str:
        await asyncio.sleep(0.01)
        return "result"

    result = await await_with_cancellation(
        fast_coro(), token, cancel_callback=cancel_callback
    )
    assert result == "result"
    assert callback_called == []
# CancellationReason
def test_all_reasons_exist() -> None:
"""All expected cancellation reasons should exist."""
assert CancellationReason.USER_REQUESTED.value == "user_requested"
assert CancellationReason.TIMEOUT.value == "timeout"
assert CancellationReason.PARENT_CANCELLED.value == "parent_cancelled"
assert CancellationReason.WORKFLOW_CANCELLED.value == "workflow_cancelled"
assert CancellationReason.TOKEN_CANCELLED.value == "token_cancelled"
def test_reasons_are_strings() -> None:
"""Cancellation reason values should be strings."""
for reason in CancellationReason:
assert isinstance(reason.value, str)
# CancelledError
def test_cancelled_error_is_base_exception() -> None:
"""CancelledError should be a BaseException (not Exception)."""
err = CancelledError("test message")
assert isinstance(err, BaseException)
assert not isinstance(err, Exception) # Should NOT be caught by except Exception
assert str(err) == "test message"
def test_cancelled_error_not_caught_by_except_exception() -> None:
    """CancelledError should NOT be caught by except Exception."""
    # Record which handler actually caught the raise.
    caught_by: str | None = None
    try:
        raise CancelledError("test")
    except Exception:
        caught_by = "exception"
    except CancelledError:
        caught_by = "cancelled_error"
    # The broad Exception handler must be skipped; only the specific one fires.
    assert caught_by == "cancelled_error"
def test_cancelled_error_with_reason() -> None:
    """CancelledError should accept and store a reason."""
    exc = CancelledError("test message", reason=CancellationReason.TIMEOUT)
    assert exc.reason is CancellationReason.TIMEOUT
def test_cancelled_error_reason_defaults_to_none() -> None:
    """CancelledError reason should default to None."""
    # No reason keyword supplied -> attribute must be None.
    assert CancelledError("test message").reason is None
def test_cancelled_error_message_property() -> None:
    """CancelledError should have a message property."""
    # The constructor argument is exposed via the .message accessor.
    assert CancelledError("test message").message == "test message"
def test_cancelled_error_default_message() -> None:
    """CancelledError should have a default message."""
    # Constructing with no arguments falls back to the stock message.
    assert CancelledError().message == "Operation cancelled"
def test_can_be_raised_and_caught() -> None:
    """CancelledError should be raisable and catchable."""
    with pytest.raises(CancelledError) as excinfo:
        raise CancelledError("Operation cancelled")
    # The raised message survives into the captured exception.
    assert "Operation cancelled" in str(excinfo.value)
def test_can_be_raised_with_reason() -> None:
    """CancelledError should be raisable with a reason."""
    with pytest.raises(CancelledError) as excinfo:
        raise CancelledError(
            "Parent was cancelled", reason=CancellationReason.PARENT_CANCELLED
        )
    # The reason attribute is preserved on the captured exception.
    assert excinfo.value.reason is CancellationReason.PARENT_CANCELLED
# --- ctx_cancellation_token context-variable propagation tests ---
def test_context_var_default_is_none() -> None:
    """ctx_cancellation_token should default to None."""
    # No test has set a value in this context, so the default applies.
    current = ctx_cancellation_token.get()
    assert current is None
def test_context_var_can_be_set_and_retrieved() -> None:
    """ctx_cancellation_token should be settable and retrievable."""
    token = CancellationToken()
    # ContextVar.set returns a Token capturing the previous state; using
    # reset(restore) restores whatever was set before this test ran, instead
    # of forcing the var to None and leaking state into other tests.
    restore = ctx_cancellation_token.set(token)
    try:
        assert ctx_cancellation_token.get() is token
    finally:
        ctx_cancellation_token.reset(restore)
@pytest.mark.asyncio
async def test_context_var_propagates_in_async() -> None:
    """ctx_cancellation_token should propagate in async context."""
    token = CancellationToken()
    # Capture the restore Token so cleanup puts back the prior value rather
    # than clobbering it with None (ContextVar.reset is the correct idiom).
    restore = ctx_cancellation_token.set(token)

    async def check_token() -> CancellationToken | None:
        # Awaited coroutines run in the same context, so they see the value.
        return ctx_cancellation_token.get()

    try:
        retrieved = await check_token()
        assert retrieved is token
    finally:
        ctx_cancellation_token.reset(restore)