mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-04-23 02:34:48 -05:00
Revert: Cancellation token Python changes (#3061)
* revert: cancellation token changes * fix: changelog * fix: add note on yank
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet
|
||||
from hatchet_sdk import Context, EmptyModel, Hatchet
|
||||
|
||||
hatchet = Hatchet(debug=True)
|
||||
|
||||
@@ -40,25 +40,6 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
|
||||
|
||||
|
||||
# > Handling cancelled error
|
||||
@cancellation_workflow.task()
async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
    """Sleep for ten seconds and show how to react to a cancellation.

    The task is expected to be cancelled while sleeping, so the final
    ``return`` is only reached if no cancellation arrived.

    :param input: The (empty) workflow input.
    :param ctx: The Hatchet context for this task run.
    :return: An error payload if the task was *not* cancelled.
    :raises CancelledError: Re-raised after cleanup on a sync-path cancellation.
    :raises asyncio.CancelledError: Re-raised after cleanup on an async-path
        cancellation.
    """
    try:
        await asyncio.sleep(10)
    except (CancelledError, asyncio.CancelledError) as e:
        # Handle parent cancellation - i.e. perform cleanup, then re-raise.
        # The SDK documents its own CancelledError as the sync-path exception
        # and asyncio.CancelledError as the async-path one, so an async task
        # has to catch both for this handler to run at all.
        # asyncio.CancelledError carries no `reason`, hence the getattr.
        reason = getattr(e, "reason", None)
        print(f"Parent Task cancelled: {reason}")
        # Always re-raise so Hatchet can properly handle the cancellation
        raise
    except Exception as e:
        # This will NOT catch either cancellation error: both derive from
        # BaseException, not Exception.
        print(f"Other error: {e}")
        raise

    return {"error": "Task should have been cancelled"}
|
||||
|
||||
|
||||
|
||||
|
||||
def main() -> None:
|
||||
worker = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow])
|
||||
worker.start()
|
||||
|
||||
@@ -10,7 +10,7 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
|
||||
|
||||
@hatchet.durable_task()
|
||||
async def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
return {"result": "Hello, world!"}
|
||||
|
||||
|
||||
|
||||
@@ -22,38 +22,15 @@ When a task is canceled, Hatchet sends a cancellation signal to the task. The ta
|
||||
|
||||
/>
|
||||
|
||||
### CancelledError Exception
|
||||
|
||||
When a sync task is cancelled while waiting for a child workflow or during a cancellation-aware operation, a `CancelledError` exception is raised.
|
||||
|
||||
<Callout type="warning">
|
||||
**Important:** `CancelledError` inherits from `BaseException`, not
|
||||
`Exception`. This means it will **not** be caught by `except Exception:`
|
||||
handlers. This is intentional and mirrors the behavior of Python's
|
||||
`asyncio.CancelledError`.
|
||||
</Callout>
|
||||
|
||||
<Snippet src={snippets.python.cancellation.worker.handling_cancelled_error} />
|
||||
|
||||
### Cancellation Reasons
|
||||
|
||||
The `CancelledError` includes a `reason` attribute that indicates why the cancellation occurred:
|
||||
|
||||
| Reason | Description |
|
||||
| --------------------------------------- | --------------------------------------------------------------------- |
|
||||
| `CancellationReason.USER_REQUESTED` | The user explicitly requested cancellation via `ctx.cancel()` |
|
||||
| `CancellationReason.WORKFLOW_CANCELLED` | The workflow run was cancelled (e.g., via API or concurrency control) |
|
||||
| `CancellationReason.PARENT_CANCELLED` | The parent workflow was cancelled while waiting for a child |
|
||||
| `CancellationReason.TIMEOUT` | The operation timed out |
|
||||
| `CancellationReason.TOKEN_CANCELLED` | The cancellation token itself was cancelled (the default reason for `CancellationToken.cancel()`) |
|
||||
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab title="Typescript">
|
||||
<Snippet
|
||||
src={snippets.typescript.cancellations.workflow.declaring_a_task}
|
||||
|
||||
/>
|
||||
<Snippet
|
||||
src={snippets.typescript.cancellations.workflow.abort_signal}
|
||||
|
||||
/>
|
||||
|
||||
</Tabs.Tab>
|
||||
|
||||
@@ -5,12 +5,19 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [1.25.2] - 2026-02-19
|
||||
|
||||
### Fixed
|
||||
|
||||
- Reverts cancellation changes in 1.25.0 that introduced a regression
|
||||
|
||||
## [1.25.1] - 2026-02-17
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixes internal registration of durable slots
|
||||
|
||||
## [1.25.0] - 2026-02-17
|
||||
## [1.25.0] - 2026-02-17 [YANKED] (yanked on 2026-02-19)
|
||||
|
||||
### Added
|
||||
|
||||
@@ -47,28 +54,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Adds type-hinted `Task.output_validator` and `Task.output_validator_type` properties to support easier type-safety and match the patterns on `BaseWorkflow/Standalone`.
|
||||
- Adds parameterized unit tests documenting current retry behavior of the Python SDK’s tenacity retry predicate for REST and gRPC errors.
|
||||
|
||||
|
||||
## [1.23.2] - 2026-02-11
|
||||
|
||||
### Changed
|
||||
|
||||
- Improves error handling for REST transport-level failures by raising typed exceptions for timeouts, connection, TLS, and protocol errors while preserving existing diagnostics.
|
||||
|
||||
|
||||
## [1.23.1] - 2026-02-10
|
||||
|
||||
### Changed
|
||||
|
||||
- Fixes a bug introduced in v1.21.0 where the `BaseWorkflow.input_validator` class property became incorrectly typed. Now separate properties are available for the type adapter and the underlying type.
|
||||
|
||||
|
||||
## [1.23.0] - 2026-02-05
|
||||
|
||||
### Internal Only
|
||||
|
||||
- Updated gRPC/REST contract field names to snake_case for consistency across SDKs.
|
||||
|
||||
|
||||
## [1.22.16] - 2026-02-05
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet
|
||||
from hatchet_sdk import Context, EmptyModel, Hatchet
|
||||
|
||||
hatchet = Hatchet(debug=True)
|
||||
|
||||
@@ -42,26 +42,6 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
# !!
|
||||
|
||||
|
||||
# > Handling cancelled error
|
||||
@cancellation_workflow.task()
async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
    """Sleep for ten seconds and show how to react to a cancellation.

    The task is expected to be cancelled while sleeping, so the final
    ``return`` is only reached if no cancellation arrived.

    :param input: The (empty) workflow input.
    :param ctx: The Hatchet context for this task run.
    :return: An error payload if the task was *not* cancelled.
    :raises CancelledError: Re-raised after cleanup on a sync-path cancellation.
    :raises asyncio.CancelledError: Re-raised after cleanup on an async-path
        cancellation.
    """
    try:
        await asyncio.sleep(10)
    except (CancelledError, asyncio.CancelledError) as e:
        # Handle parent cancellation - i.e. perform cleanup, then re-raise.
        # The SDK documents its own CancelledError as the sync-path exception
        # and asyncio.CancelledError as the async-path one, so an async task
        # has to catch both for this handler to run at all.
        # asyncio.CancelledError carries no `reason`, hence the getattr.
        reason = getattr(e, "reason", None)
        print(f"Parent Task cancelled: {reason}")
        # Always re-raise so Hatchet can properly handle the cancellation
        raise
    except Exception as e:
        # This will NOT catch either cancellation error: both derive from
        # BaseException, not Exception.
        print(f"Other error: {e}")
        raise

    return {"error": "Task should have been cancelled"}
|
||||
|
||||
|
||||
# !!
|
||||
|
||||
|
||||
def main() -> None:
|
||||
worker = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow])
|
||||
worker.start()
|
||||
|
||||
@@ -10,7 +10,7 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
|
||||
|
||||
@hatchet.durable_task()
|
||||
async def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
return {"result": "Hello, world!"}
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
from hatchet_sdk.clients.admin import (
|
||||
RunStatus,
|
||||
ScheduleTriggerWorkflowOptions,
|
||||
@@ -156,8 +155,6 @@ from hatchet_sdk.contracts.workflows_pb2 import (
|
||||
WorkerLabelComparator,
|
||||
)
|
||||
from hatchet_sdk.exceptions import (
|
||||
CancellationReason,
|
||||
CancelledError,
|
||||
DedupeViolationError,
|
||||
FailedTaskRunExceptionGroup,
|
||||
NonRetryableException,
|
||||
@@ -197,9 +194,6 @@ __all__ = [
|
||||
"CELEvaluationResult",
|
||||
"CELFailure",
|
||||
"CELSuccess",
|
||||
"CancellationReason",
|
||||
"CancellationToken",
|
||||
"CancelledError",
|
||||
"ClientConfig",
|
||||
"ClientTLSConfig",
|
||||
"ConcurrencyExpression",
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
"""Cancellation token for coordinating cancellation across async and sync operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from hatchet_sdk.exceptions import CancellationReason
|
||||
from hatchet_sdk.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class CancellationToken:
    """
    A token that can be used to signal cancellation across async and sync operations.

    The token provides both asyncio and threading event primitives, allowing it to work
    in both async and sync code paths. Child workflow run IDs can be registered
    with the token so they can be cancelled when the parent is cancelled.

    ``cancel()`` may be called from any thread: the threading event is set
    directly, while the asyncio event is set on the event loop that is waiting
    on it, because asyncio primitives are not thread-safe.

    Example:
        ```python
        token = CancellationToken()

        # In async code
        await token.aio_wait()  # Blocks until cancelled

        # In sync code
        token.wait(timeout=1.0)  # Returns True if cancelled within timeout

        # Check if cancelled
        if token.is_cancelled:
            raise CancelledError("Operation was cancelled")

        # Trigger cancellation
        token.cancel()
        ```
    """

    def __init__(self) -> None:
        self._cancelled = False
        self._reason: CancellationReason | None = None
        # Created lazily so that constructing a token needs no event loop.
        self._async_event: asyncio.Event | None = None
        # The loop whose coroutines wait on `_async_event`; used to hand the
        # `set()` call over to that loop when cancelling from another thread.
        self._async_loop: asyncio.AbstractEventLoop | None = None
        self._sync_event = threading.Event()
        self._child_run_ids: list[str] = []
        self._callbacks: list[Callable[[], None]] = []
        # Guards the cancelled flag, the child list, the callback list and the
        # lazy creation of the asyncio event.
        self._lock = threading.Lock()

    def _get_async_event(self) -> asyncio.Event:
        """Lazily create the asyncio event to avoid requiring an event loop at init time."""
        with self._lock:
            if self._async_event is None:
                try:
                    self._async_loop = asyncio.get_running_loop()
                except RuntimeError:
                    # Called outside of a running loop; aio_wait() records the
                    # loop later, once a coroutine actually waits.
                    self._async_loop = None

                event = asyncio.Event()
                # If already cancelled, set the event
                if self._cancelled:
                    event.set()
                self._async_event = event

            return self._async_event

    def _signal_async_event(self) -> None:
        """Set the asyncio event in a way that is safe from any thread.

        Must be called with ``self._lock`` held. ``asyncio.Event.set()`` is not
        thread-safe: calling it from a foreign thread schedules the waiters'
        wake-up without waking the loop, so an idle loop could keep sleeping.
        When the caller is not running on the waiting loop, the call is handed
        over with ``call_soon_threadsafe`` instead.
        """
        event = self._async_event
        if event is None:
            return

        loop = self._async_loop
        if loop is None:
            # Nobody recorded a loop yet, so there is nothing to hand over to.
            event.set()
            return

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is loop:
            event.set()
            return

        try:
            loop.call_soon_threadsafe(event.set)
        except RuntimeError:
            # The loop has been closed; fall back to setting the flag directly
            # so later checks of the event still observe the cancellation.
            event.set()

    def cancel(
        self, reason: CancellationReason = CancellationReason.TOKEN_CANCELLED
    ) -> None:
        """
        Trigger cancellation.

        This will:
        - Set the cancelled flag and reason
        - Signal both async and sync events
        - Invoke all registered callbacks

        Calling it again after the token is cancelled is a no-op; the first
        reason is kept. Exceptions raised by callbacks are logged and swallowed.

        Args:
            reason: The reason for cancellation.
        """
        with self._lock:
            if self._cancelled:
                logger.debug(
                    f"CancellationToken: cancel() called but already cancelled, "
                    f"reason={self._reason.value if self._reason else 'none'}"
                )
                return

            logger.debug(
                f"CancellationToken: cancel() called, reason={reason.value}, "
                f"{len(self._child_run_ids)} children registered"
            )

            self._cancelled = True
            self._reason = reason

            # Signal both event types
            self._signal_async_event()
            self._sync_event.set()

            # Snapshot callbacks under the lock, invoke outside to avoid deadlocks
            callbacks = list(self._callbacks)

        for callback in callbacks:
            try:
                logger.debug(f"CancellationToken: invoking callback {callback}")
                callback()
            except Exception as e:  # noqa: PERF203
                logger.warning(f"CancellationToken: callback raised exception: {e}")

        logger.debug(f"CancellationToken: cancel() complete, reason={reason.value}")

    @property
    def is_cancelled(self) -> bool:
        """Check if cancellation has been triggered."""
        return self._cancelled

    @property
    def reason(self) -> CancellationReason | None:
        """Get the reason for cancellation, or None if not cancelled."""
        return self._reason

    async def aio_wait(self) -> None:
        """
        Await until cancelled (for use in asyncio).

        This will block until cancel() is called.
        """
        event = self._get_async_event()
        if self._async_loop is None:
            # The event was created outside of a loop; remember the loop that is
            # actually waiting so cancel() can wake it from other threads.
            self._async_loop = asyncio.get_running_loop()

        await event.wait()
        logger.debug(
            f"CancellationToken: async wait completed (cancelled), "
            f"reason={self._reason.value if self._reason else 'none'}"
        )

    def wait(self, timeout: float | None = None) -> bool:
        """
        Block until cancelled (for use in sync code).

        Args:
            timeout: Maximum time to wait in seconds. None means wait forever.

        Returns:
            True if the token was cancelled (event was set), False if timeout expired.
        """
        result = self._sync_event.wait(timeout)
        if result:
            logger.debug(
                f"CancellationToken: sync wait interrupted by cancellation, "
                f"reason={self._reason.value if self._reason else 'none'}"
            )
        return result

    def register_child(self, run_id: str) -> None:
        """
        Register a child workflow run ID with this token.

        When the parent is cancelled, these child run IDs can be used to cancel
        the child workflows as well.

        Args:
            run_id: The workflow run ID of the child workflow.
        """
        with self._lock:
            logger.debug(f"CancellationToken: registering child workflow {run_id}")
            self._child_run_ids.append(run_id)

    @property
    def child_run_ids(self) -> list[str]:
        """The registered child workflow run IDs.

        A snapshot is returned: it is taken under the lock so it cannot be seen
        mid-update, and mutating it does not affect the token.
        """
        with self._lock:
            return list(self._child_run_ids)

    def add_callback(self, callback: Callable[[], None]) -> None:
        """
        Register a callback to be invoked when cancellation is triggered.

        If the token is already cancelled, the callback will be invoked immediately
        (and is not stored). Exceptions raised by the callback are logged and swallowed.

        Args:
            callback: A callable that takes no arguments.
        """
        with self._lock:
            invoke_now = self._cancelled
            if not invoke_now:
                self._callbacks.append(callback)

        # Invoked outside the lock so the callback may use the token itself.
        if invoke_now:
            logger.debug(
                f"CancellationToken: invoking callback immediately (already cancelled): {callback}"
            )
            try:
                callback()
            except Exception as e:
                logger.warning(f"CancellationToken: callback raised exception: {e}")

    def __repr__(self) -> str:
        return (
            f"CancellationToken(cancelled={self._cancelled}, "
            f"children={len(self._child_run_ids)}, callbacks={len(self._callbacks)})"
        )
|
||||
@@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Generic, Literal, TypeVar
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
import grpc
|
||||
import grpc.aio
|
||||
@@ -16,10 +14,6 @@ from hatchet_sdk.clients.event_ts import (
|
||||
from hatchet_sdk.config import ClientConfig
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.metadata import get_metadata
|
||||
from hatchet_sdk.utils.cancellation import race_against_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
|
||||
DEFAULT_LISTENER_RETRY_INTERVAL = 3 # seconds
|
||||
DEFAULT_LISTENER_RETRY_COUNT = 5
|
||||
@@ -42,7 +36,7 @@ class Subscription(Generic[T]):
|
||||
self.id = id
|
||||
self.queue: asyncio.Queue[T | SentinelValue] = asyncio.Queue()
|
||||
|
||||
async def __aiter__(self) -> Subscription[T]:
|
||||
async def __aiter__(self) -> "Subscription[T]":
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> T | SentinelValue:
|
||||
@@ -205,17 +199,7 @@ class PooledListener(Generic[R, T, L], ABC):
|
||||
del self.from_subscriptions[subscription_id]
|
||||
del self.events[subscription_id]
|
||||
|
||||
async def subscribe(
|
||||
self, id: str, cancellation_token: CancellationToken | None = None
|
||||
) -> T:
|
||||
"""
|
||||
Subscribe to events for the given ID.
|
||||
|
||||
:param id: The ID to subscribe to (e.g., workflow run ID).
|
||||
:param cancellation_token: Optional cancellation token to abort the subscription wait.
|
||||
:return: The event received for this ID.
|
||||
:raises asyncio.CancelledError: If the cancellation token is triggered or if externally cancelled.
|
||||
"""
|
||||
async def subscribe(self, id: str) -> T:
|
||||
subscription_id: int | None = None
|
||||
|
||||
try:
|
||||
@@ -237,17 +221,8 @@ class PooledListener(Generic[R, T, L], ABC):
|
||||
if not self.listener_task or self.listener_task.done():
|
||||
self.listener_task = asyncio.create_task(self._init_producer())
|
||||
|
||||
logger.debug(
|
||||
f"PooledListener.subscribe: waiting for event on id={id}, "
|
||||
f"subscription_id={subscription_id}, token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
if cancellation_token:
|
||||
result_task = asyncio.create_task(self.events[subscription_id].get())
|
||||
return await race_against_token(result_task, cancellation_token)
|
||||
return await self.events[subscription_id].get()
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"PooledListener.subscribe: externally cancelled for id={id}")
|
||||
raise
|
||||
finally:
|
||||
if subscription_id:
|
||||
|
||||
@@ -52,7 +52,7 @@ class HealthcheckConfig(BaseSettings):
|
||||
if isinstance(value, timedelta):
|
||||
return value
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
if isinstance(value, int | float):
|
||||
return timedelta(seconds=float(value))
|
||||
|
||||
v = value.strip()
|
||||
@@ -135,37 +135,6 @@ class ClientConfig(BaseSettings):
|
||||
force_shutdown_on_shutdown_signal: bool = False
|
||||
tenacity: TenacityConfig = TenacityConfig()
|
||||
|
||||
# Cancellation configuration
|
||||
cancellation_grace_period: timedelta = Field(
|
||||
default=timedelta(milliseconds=1000),
|
||||
description="The maximum time to wait for a task to complete after cancellation is triggered before force-cancelling. Value is interpreted as seconds when provided as int/float.",
|
||||
)
|
||||
cancellation_warning_threshold: timedelta = Field(
|
||||
default=timedelta(milliseconds=300),
|
||||
description="If a task has not completed cancellation within this duration, a warning will be logged. Value is interpreted as seconds when provided as int/float.",
|
||||
)
|
||||
|
||||
@field_validator(
    "cancellation_grace_period", "cancellation_warning_threshold", mode="before"
)
@classmethod
def validate_cancellation_timedelta(
    cls, value: timedelta | int | float | str
) -> timedelta:
    """Normalize a cancellation duration setting to a ``timedelta``.

    ``timedelta`` values pass through untouched. Numbers are taken as a count
    of seconds. Strings are parsed as a number of seconds, optionally followed
    by a single trailing ``s``.
    """
    if isinstance(value, timedelta):
        return value

    if not isinstance(value, (int, float)):
        # String input: allow a small convenience suffix, but keep "seconds"
        # as the contract.
        value = value.strip().removesuffix("s").strip()

    return timedelta(seconds=float(value))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_token_and_tenant(self) -> "ClientConfig":
|
||||
if not self.token:
|
||||
|
||||
@@ -4,7 +4,6 @@ from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from warnings import warn
|
||||
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
from hatchet_sdk.clients.admin import AdminClient
|
||||
from hatchet_sdk.clients.dispatcher.dispatcher import ( # type: ignore[attr-defined]
|
||||
Action,
|
||||
@@ -22,10 +21,9 @@ from hatchet_sdk.conditions import (
|
||||
flatten_conditions,
|
||||
)
|
||||
from hatchet_sdk.context.worker_context import WorkerContext
|
||||
from hatchet_sdk.exceptions import CancellationReason, TaskRunError
|
||||
from hatchet_sdk.exceptions import TaskRunError
|
||||
from hatchet_sdk.features.runs import RunsClient
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.utils.cancellation import await_with_cancellation
|
||||
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel
|
||||
from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord
|
||||
@@ -58,7 +56,7 @@ class Context:
|
||||
self.action = action
|
||||
|
||||
self.step_run_id = action.step_run_id
|
||||
self.cancellation_token = CancellationToken()
|
||||
self.exit_flag = False
|
||||
self.dispatcher_client = dispatcher_client
|
||||
self.admin_client = admin_client
|
||||
self.event_client = event_client
|
||||
@@ -76,31 +74,6 @@ class Context:
|
||||
self._workflow_name = workflow_name
|
||||
self._task_name = task_name
|
||||
|
||||
@property
def exit_flag(self) -> bool:
    """
    Whether this context's cancellation token has been triggered.

    Kept for backwards compatibility; new code should read
    `cancellation_token.is_cancelled` instead.

    :return: True if the task has been cancelled, False otherwise.
    """
    token = self.cancellation_token
    return token.is_cancelled


@exit_flag.setter
def exit_flag(self, value: bool) -> None:
    """
    Backwards-compatible way of requesting cancellation.

    Assigning a truthy value cancels the token with the `USER_REQUESTED`
    reason; assigning a falsy value does nothing.

    :param value: True to trigger cancellation, False is a no-op.
    """
    if not value:
        return

    self.cancellation_token.cancel(CancellationReason.USER_REQUESTED)
|
||||
|
||||
def _increment_stream_index(self) -> int:
|
||||
index = self.stream_index
|
||||
self.stream_index += 1
|
||||
@@ -196,25 +169,8 @@ class Context:
|
||||
"""
|
||||
return self.action.workflow_run_id
|
||||
|
||||
def _set_cancellation_flag(
|
||||
self, reason: CancellationReason = CancellationReason.WORKFLOW_CANCELLED
|
||||
) -> None:
|
||||
"""
|
||||
Internal method to trigger cancellation.
|
||||
|
||||
This triggers the cancellation token, which will:
|
||||
- Signal all waiters (async and sync)
|
||||
- Set the exit_flag property to True
|
||||
- Allow child workflow cancellation
|
||||
|
||||
Args:
|
||||
reason: The reason for cancellation.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Context: setting cancellation flag for step_run_id={self.step_run_id}, "
|
||||
f"reason={reason.value}"
|
||||
)
|
||||
self.cancellation_token.cancel(reason)
|
||||
def _set_cancellation_flag(self) -> None:
|
||||
self.exit_flag = True
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""
|
||||
@@ -222,11 +178,9 @@ class Context:
|
||||
|
||||
:return: None
|
||||
"""
|
||||
logger.debug(
|
||||
f"Context: cancel() called for task_run_external_id={self.step_run_id}"
|
||||
)
|
||||
logger.debug("cancelling step...")
|
||||
self.runs_client.cancel(self.step_run_id)
|
||||
self._set_cancellation_flag(CancellationReason.USER_REQUESTED)
|
||||
self._set_cancellation_flag()
|
||||
|
||||
async def aio_cancel(self) -> None:
|
||||
"""
|
||||
@@ -234,11 +188,9 @@ class Context:
|
||||
|
||||
:return: None
|
||||
"""
|
||||
logger.debug(
|
||||
f"Context: aio_cancel() called for task_run_external_id={self.step_run_id}"
|
||||
)
|
||||
logger.debug("cancelling step...")
|
||||
await self.runs_client.aio_cancel(self.step_run_id)
|
||||
self._set_cancellation_flag(CancellationReason.USER_REQUESTED)
|
||||
self._set_cancellation_flag()
|
||||
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
@@ -530,11 +482,8 @@ class DurableContext(Context):
|
||||
"""
|
||||
Durably wait for either a sleep or an event.
|
||||
|
||||
This method respects the context's cancellation token. If the task is cancelled
|
||||
while waiting, an asyncio.CancelledError will be raised.
|
||||
|
||||
:param signal_key: The key to use for the durable event. This is used to identify the event in the Hatchet API.
|
||||
:param \\*conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition.
|
||||
:param *conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition.
|
||||
|
||||
:return: A dictionary containing the results of the wait.
|
||||
:raises ValueError: If the durable event listener is not available.
|
||||
@@ -544,10 +493,6 @@ class DurableContext(Context):
|
||||
|
||||
task_id = self.step_run_id
|
||||
|
||||
logger.debug(
|
||||
f"DurableContext.aio_wait_for: waiting for signal_key={signal_key}, task_id={task_id}"
|
||||
)
|
||||
|
||||
request = RegisterDurableEventRequest(
|
||||
task_id=task_id,
|
||||
signal_key=signal_key,
|
||||
@@ -557,29 +502,19 @@ class DurableContext(Context):
|
||||
|
||||
self.durable_event_listener.register_durable_event(request)
|
||||
|
||||
# Use await_with_cancellation to respect the cancellation token
|
||||
return await await_with_cancellation(
|
||||
self.durable_event_listener.result(task_id, signal_key),
|
||||
self.cancellation_token,
|
||||
return await self.durable_event_listener.result(
|
||||
task_id,
|
||||
signal_key,
|
||||
)
|
||||
|
||||
async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]:
|
||||
"""
|
||||
Lightweight wrapper for durable sleep. Allows for shorthand usage of `ctx.aio_wait_for` when specifying a sleep condition.
|
||||
|
||||
This method respects the context's cancellation token. If the task is cancelled
|
||||
while sleeping, an asyncio.CancelledError will be raised.
|
||||
|
||||
For more complicated conditions, use `ctx.aio_wait_for` directly.
|
||||
|
||||
:param duration: The duration to sleep for.
|
||||
:return: A dictionary containing the results of the wait.
|
||||
"""
|
||||
wait_index = self._increment_wait_index()
|
||||
|
||||
logger.debug(
|
||||
f"DurableContext.aio_sleep_for: sleeping for {duration}, wait_index={wait_index}"
|
||||
)
|
||||
wait_index = self._increment_wait_index()
|
||||
|
||||
return await self.aio_wait_for(
|
||||
f"sleep:{timedelta_to_expr(duration)}-{wait_index}",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
|
||||
@@ -171,54 +170,3 @@ class IllegalTaskOutputError(Exception):
|
||||
|
||||
class LifespanSetupError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CancellationReason(Enum):
    """Why an operation was cancelled.

    Each member's value is the lowercase form of its name. The string that
    follows each member documents that member.
    """

    USER_REQUESTED = "user_requested"
    """The user explicitly requested cancellation."""

    TIMEOUT = "timeout"
    """The operation timed out."""

    PARENT_CANCELLED = "parent_cancelled"
    """The parent workflow or task was cancelled."""

    WORKFLOW_CANCELLED = "workflow_cancelled"
    """The workflow run was cancelled."""

    TOKEN_CANCELLED = "token_cancelled"
    """The cancellation token was cancelled."""
|
||||
|
||||
|
||||
class CancelledError(BaseException):
    """
    Signals that an operation was cancelled via a CancellationToken.

    The class derives from BaseException rather than Exception, so an
    `except Exception:` handler does not intercept it - the same design as
    asyncio.CancelledError in Python 3.8+.

    Ways to catch it:
    - `except CancelledError:` (recommended)
    - `except BaseException:` (catches all exceptions)

    It is raised on sync code paths; async code paths use
    asyncio.CancelledError instead.

    :param message: Optional message describing the cancellation.
    :param reason: Optional enum indicating the reason for cancellation.
    """

    def __init__(
        self,
        message: str = "Operation cancelled",
        reason: CancellationReason | None = None,
    ) -> None:
        super().__init__(message)
        self.reason = reason

    @property
    def message(self) -> str:
        """The cancellation message, or the default text when no args are set."""
        if not self.args:
            return "Operation cancelled"
        return str(self.args[0])
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from collections import Counter
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from hatchet_sdk.runnables.action import ActionKey
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
|
||||
ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
|
||||
"ctx_workflow_run_id", default=None
|
||||
)
|
||||
@@ -26,9 +20,6 @@ ctx_additional_metadata: ContextVar[JSONSerializableMapping | None] = ContextVar
|
||||
ctx_task_retry_count: ContextVar[int | None] = ContextVar(
|
||||
"ctx_task_retry_count", default=0
|
||||
)
|
||||
ctx_cancellation_token: ContextVar[CancellationToken | None] = ContextVar(
|
||||
"ctx_cancellation_token", default=None
|
||||
)
|
||||
|
||||
workflow_spawn_indices = Counter[ActionKey]()
|
||||
spawn_index_lock = asyncio.Lock()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
@@ -39,11 +37,8 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import (
|
||||
)
|
||||
from hatchet_sdk.contracts.v1.workflows_pb2 import StickyStrategy as StickyStrategyProto
|
||||
from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
|
||||
from hatchet_sdk.exceptions import CancellationReason, CancelledError
|
||||
from hatchet_sdk.labels import DesiredWorkerLabel
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.rate_limit import RateLimit
|
||||
from hatchet_sdk.runnables.contextvars import ctx_cancellation_token
|
||||
from hatchet_sdk.runnables.task import Task
|
||||
from hatchet_sdk.runnables.types import (
|
||||
ConcurrencyExpression,
|
||||
@@ -57,7 +52,6 @@ from hatchet_sdk.runnables.types import (
|
||||
normalize_validator,
|
||||
)
|
||||
from hatchet_sdk.serde import HATCHET_PYDANTIC_SENTINEL
|
||||
from hatchet_sdk.utils.cancellation import await_with_cancellation
|
||||
from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
|
||||
from hatchet_sdk.utils.timedelta_to_expression import Duration
|
||||
from hatchet_sdk.utils.typing import CoroutineLike, JSONSerializableMapping
|
||||
@@ -65,7 +59,6 @@ from hatchet_sdk.workflow_run import WorkflowRunRef
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hatchet_sdk import Hatchet
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -95,7 +88,7 @@ class ComputedTaskParameters(BaseModel):
|
||||
task_defaults: TaskDefaults
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_params(self) -> ComputedTaskParameters:
|
||||
def validate_params(self) -> "ComputedTaskParameters":
|
||||
self.execution_timeout = fall_back_to_default(
|
||||
value=self.execution_timeout,
|
||||
param_default=timedelta(seconds=60),
|
||||
@@ -143,7 +136,7 @@ class TypedTriggerWorkflowRunConfig(BaseModel, Generic[TWorkflowInput]):
|
||||
|
||||
|
||||
class BaseWorkflow(Generic[TWorkflowInput]):
|
||||
def __init__(self, config: WorkflowConfig, client: Hatchet) -> None:
|
||||
def __init__(self, config: WorkflowConfig, client: "Hatchet") -> None:
|
||||
self.config = config
|
||||
self._default_tasks: list[Task[TWorkflowInput, Any]] = []
|
||||
self._durable_tasks: list[Task[TWorkflowInput, Any]] = []
|
||||
@@ -632,38 +625,6 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
and can be arranged into complex dependency patterns.
|
||||
"""
|
||||
|
||||
def _resolve_check_cancellation_token(self) -> CancellationToken | None:
    """Return the cancellation token held in the context variable.

    :return: The token read from ``ctx_cancellation_token``; whatever that
        variable holds is returned as-is when it is falsy or not cancelled.
    :raises CancelledError: If the token has already been cancelled.
    """
    token = ctx_cancellation_token.get()

    if not token or not token.is_cancelled:
        return token

    raise CancelledError(
        "Operation cancelled by cancellation token",
        reason=CancellationReason.TOKEN_CANCELLED,
    )
|
||||
|
||||
def _register_child_with_token(
|
||||
self,
|
||||
cancellation_token: CancellationToken | None,
|
||||
workflow_run_id: str,
|
||||
) -> None:
|
||||
if not cancellation_token:
|
||||
return
|
||||
|
||||
cancellation_token.register_child(workflow_run_id)
|
||||
|
||||
def _register_children_with_token(
|
||||
self,
|
||||
cancellation_token: CancellationToken | None,
|
||||
refs: list[WorkflowRunRef],
|
||||
) -> None:
|
||||
if not cancellation_token:
|
||||
return
|
||||
|
||||
for ref in refs:
|
||||
cancellation_token.register_child(ref.workflow_run_id)
|
||||
|
||||
def run_no_wait(
|
||||
self,
|
||||
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
||||
@@ -673,34 +634,17 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
Synchronously trigger a workflow run without waiting for it to complete.
|
||||
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
|
||||
|
||||
If a cancellation token is available via context, the child workflow will be registered
|
||||
with the token.
|
||||
|
||||
:param input: The input data for the workflow.
|
||||
:param options: Additional options for workflow execution.
|
||||
|
||||
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.run_no_wait: triggering {self.config.name}, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
ref = self.client._client.admin.run_workflow(
|
||||
return self.client._client.admin.run_workflow(
|
||||
workflow_name=self.config.name,
|
||||
input=self._serialize_input(input),
|
||||
options=self._create_options_with_combined_additional_meta(options),
|
||||
)
|
||||
|
||||
self._register_child_with_token(
|
||||
cancellation_token,
|
||||
ref.workflow_run_id,
|
||||
)
|
||||
|
||||
return ref
|
||||
|
||||
def run(
|
||||
self,
|
||||
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
||||
@@ -710,19 +654,12 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
Run the workflow synchronously and wait for it to complete.
|
||||
|
||||
This method triggers a workflow run, blocks until completion, and returns the final result.
|
||||
If a cancellation token is available via context, the wait can be interrupted.
|
||||
|
||||
:param input: The input data for the workflow, must match the workflow's input type.
|
||||
:param options: Additional options for workflow execution like metadata and parent workflow ID.
|
||||
|
||||
:returns: The result of the workflow execution as a dictionary.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.run: triggering {self.config.name}, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
ref = self.client._client.admin.run_workflow(
|
||||
workflow_name=self.config.name,
|
||||
@@ -730,14 +667,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
options=self._create_options_with_combined_additional_meta(options),
|
||||
)
|
||||
|
||||
self._register_child_with_token(
|
||||
cancellation_token,
|
||||
ref.workflow_run_id,
|
||||
)
|
||||
|
||||
logger.debug(f"Workflow.run: awaiting result for {ref.workflow_run_id}")
|
||||
|
||||
return ref.result(cancellation_token=cancellation_token)
|
||||
return ref.result()
|
||||
|
||||
async def aio_run_no_wait(
|
||||
self,
|
||||
@@ -748,34 +678,18 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
Asynchronously trigger a workflow run without waiting for it to complete.
|
||||
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
|
||||
|
||||
If a cancellation token is available via context, the child workflow will be registered
|
||||
with the token.
|
||||
|
||||
:param input: The input data for the workflow.
|
||||
:param options: Additional options for workflow execution.
|
||||
|
||||
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.aio_run_no_wait: triggering {self.config.name}, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
ref = await self.client._client.admin.aio_run_workflow(
|
||||
return await self.client._client.admin.aio_run_workflow(
|
||||
workflow_name=self.config.name,
|
||||
input=self._serialize_input(input),
|
||||
options=self._create_options_with_combined_additional_meta(options),
|
||||
)
|
||||
|
||||
self._register_child_with_token(
|
||||
cancellation_token,
|
||||
ref.workflow_run_id,
|
||||
)
|
||||
|
||||
return ref
|
||||
|
||||
async def aio_run(
|
||||
self,
|
||||
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
||||
@@ -785,47 +699,25 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
Run the workflow asynchronously and wait for it to complete.
|
||||
|
||||
This method triggers a workflow run, awaits until completion, and returns the final result.
|
||||
If a cancellation token is available via context, the wait can be interrupted.
|
||||
|
||||
:param input: The input data for the workflow, must match the workflow's input type.
|
||||
:param options: Additional options for workflow execution like metadata and parent workflow ID.
|
||||
|
||||
:returns: The result of the workflow execution as a dictionary.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.aio_run: triggering {self.config.name}, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
ref = await self.client._client.admin.aio_run_workflow(
|
||||
workflow_name=self.config.name,
|
||||
input=self._serialize_input(input),
|
||||
options=self._create_options_with_combined_additional_meta(options),
|
||||
)
|
||||
|
||||
self._register_child_with_token(
|
||||
cancellation_token,
|
||||
ref.workflow_run_id,
|
||||
)
|
||||
|
||||
logger.debug(f"Workflow.aio_run: awaiting result for {ref.workflow_run_id}")
|
||||
|
||||
return await await_with_cancellation(
|
||||
ref.aio_result(),
|
||||
cancellation_token,
|
||||
)
|
||||
return await ref.aio_result()
|
||||
|
||||
def _get_result(
|
||||
self,
|
||||
ref: WorkflowRunRef,
|
||||
return_exceptions: bool,
|
||||
self, ref: WorkflowRunRef, return_exceptions: bool
|
||||
) -> dict[str, Any] | BaseException:
|
||||
try:
|
||||
return ref.result(
|
||||
cancellation_token=self._resolve_check_cancellation_token()
|
||||
)
|
||||
return ref.result()
|
||||
except Exception as e:
|
||||
if return_exceptions:
|
||||
return e
|
||||
@@ -854,52 +746,15 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
Run a workflow in bulk and wait for all runs to complete.
|
||||
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
|
||||
|
||||
If a cancellation token is available via context, all child workflows will be registered
|
||||
with the token and the wait can be interrupted.
|
||||
|
||||
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
||||
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
|
||||
:returns: A list of results for each workflow run.
|
||||
:raises CancelledError: If the cancellation token is triggered (and return_exceptions is False).
|
||||
:raises Exception: If a workflow run fails (and return_exceptions is False).
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
refs = self.client._client.admin.run_workflows(
|
||||
workflows=workflows,
|
||||
)
|
||||
|
||||
self._register_children_with_token(
|
||||
cancellation_token,
|
||||
refs,
|
||||
)
|
||||
|
||||
# Pass cancellation_token through to each result() call
|
||||
# The cancellation check happens INSIDE result()'s polling loop
|
||||
results: list[dict[str, Any] | BaseException] = []
|
||||
for ref in refs:
|
||||
try:
|
||||
results.append(ref.result(cancellation_token=cancellation_token))
|
||||
except CancelledError: # noqa: PERF203
|
||||
logger.debug(
|
||||
f"Workflow.run_many: cancellation detected, stopping wait, "
|
||||
f"reason={CancellationReason.PARENT_CANCELLED.value}"
|
||||
)
|
||||
if return_exceptions:
|
||||
results.append(
|
||||
CancelledError(
|
||||
"Operation cancelled by cancellation token",
|
||||
reason=CancellationReason.PARENT_CANCELLED,
|
||||
)
|
||||
)
|
||||
break
|
||||
raise
|
||||
except Exception as e:
|
||||
if return_exceptions:
|
||||
results.append(e)
|
||||
else:
|
||||
raise
|
||||
return results
|
||||
return [self._get_result(ref, return_exceptions) for ref in refs]
|
||||
|
||||
@overload
|
||||
async def aio_run_many(
|
||||
@@ -924,34 +779,16 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
Run a workflow in bulk and wait for all runs to complete.
|
||||
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
|
||||
|
||||
If a cancellation token is available via context, all child workflows will be registered
|
||||
with the token and the wait can be interrupted.
|
||||
|
||||
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
||||
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
|
||||
:returns: A list of results for each workflow run.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.aio_run_many: triggering {len(workflows)} workflows, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
refs = await self.client._client.admin.aio_run_workflows(
|
||||
workflows=workflows,
|
||||
)
|
||||
|
||||
self._register_children_with_token(
|
||||
cancellation_token,
|
||||
refs,
|
||||
)
|
||||
|
||||
return await await_with_cancellation(
|
||||
asyncio.gather(
|
||||
*[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
|
||||
),
|
||||
cancellation_token,
|
||||
return await asyncio.gather(
|
||||
*[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
|
||||
)
|
||||
|
||||
def run_many_no_wait(
|
||||
@@ -963,30 +800,13 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
|
||||
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
|
||||
|
||||
If a cancellation token is available via context, all child workflows will be registered
|
||||
with the token.
|
||||
|
||||
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
||||
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.run_many_no_wait: triggering {len(workflows)} workflows, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
refs = self.client._client.admin.run_workflows(
|
||||
return self.client._client.admin.run_workflows(
|
||||
workflows=workflows,
|
||||
)
|
||||
|
||||
self._register_children_with_token(
|
||||
cancellation_token,
|
||||
refs,
|
||||
)
|
||||
|
||||
return refs
|
||||
|
||||
async def aio_run_many_no_wait(
|
||||
self,
|
||||
workflows: list[WorkflowRunTriggerConfig],
|
||||
@@ -996,31 +816,14 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
|
||||
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
|
||||
|
||||
If a cancellation token is available via context, all child workflows will be registered
|
||||
with the token.
|
||||
|
||||
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
||||
|
||||
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
|
||||
"""
|
||||
cancellation_token = self._resolve_check_cancellation_token()
|
||||
|
||||
logger.debug(
|
||||
f"Workflow.aio_run_many_no_wait: triggering {len(workflows)} workflows, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
refs = await self.client._client.admin.aio_run_workflows(
|
||||
return await self.client._client.admin.aio_run_workflows(
|
||||
workflows=workflows,
|
||||
)
|
||||
|
||||
self._register_children_with_token(
|
||||
cancellation_token,
|
||||
refs,
|
||||
)
|
||||
|
||||
return refs
|
||||
|
||||
def _parse_task_name(
|
||||
self,
|
||||
name: str | None,
|
||||
@@ -1365,7 +1168,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
|
||||
return inner
|
||||
|
||||
def add_task(self, task: Standalone[TWorkflowInput, Any]) -> None:
|
||||
def add_task(self, task: "Standalone[TWorkflowInput, Any]") -> None:
|
||||
"""
|
||||
Add a task to a workflow. Intended to be used with a previously existing task (a Standalone),
|
||||
such as one created with `@hatchet.task()`, which has been converted to a `Task` object using `to_task`.
|
||||
@@ -1404,7 +1207,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
|
||||
class TaskRunRef(Generic[TWorkflowInput, R]):
|
||||
def __init__(
|
||||
self,
|
||||
standalone: Standalone[TWorkflowInput, R],
|
||||
standalone: "Standalone[TWorkflowInput, R]",
|
||||
workflow_run_ref: WorkflowRunRef,
|
||||
):
|
||||
self._s = standalone
|
||||
@@ -1563,9 +1366,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
|
||||
) -> list[R]: ...
|
||||
|
||||
def run_many(
|
||||
self,
|
||||
workflows: list[WorkflowRunTriggerConfig],
|
||||
return_exceptions: bool = False,
|
||||
self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
|
||||
) -> list[R] | list[R | BaseException]:
|
||||
"""
|
||||
Run a workflow in bulk and wait for all runs to complete.
|
||||
@@ -1599,9 +1400,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
|
||||
) -> list[R]: ...
|
||||
|
||||
async def aio_run_many(
|
||||
self,
|
||||
workflows: list[WorkflowRunTriggerConfig],
|
||||
return_exceptions: bool = False,
|
||||
self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
|
||||
) -> list[R] | list[R | BaseException]:
|
||||
"""
|
||||
Run a workflow in bulk and wait for all runs to complete.
|
||||
@@ -1621,8 +1420,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
|
||||
]
|
||||
|
||||
def run_many_no_wait(
|
||||
self,
|
||||
workflows: list[WorkflowRunTriggerConfig],
|
||||
self, workflows: list[WorkflowRunTriggerConfig]
|
||||
) -> list[TaskRunRef[TWorkflowInput, R]]:
|
||||
"""
|
||||
Run a workflow in bulk without waiting for all runs to complete.
|
||||
@@ -1637,8 +1435,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
|
||||
return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
|
||||
|
||||
async def aio_run_many_no_wait(
|
||||
self,
|
||||
workflows: list[WorkflowRunTriggerConfig],
|
||||
self, workflows: list[WorkflowRunTriggerConfig]
|
||||
) -> list[TaskRunRef[TWorkflowInput, R]]:
|
||||
"""
|
||||
Run a workflow in bulk without waiting for all runs to complete.
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Utilities for cancellation-aware operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from hatchet_sdk.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def _invoke_cancel_callback(
|
||||
cancel_callback: Callable[[], Awaitable[None]] | None,
|
||||
) -> None:
|
||||
"""Invoke a cancel callback."""
|
||||
if not cancel_callback:
|
||||
return
|
||||
|
||||
await cancel_callback()
|
||||
|
||||
|
||||
async def race_against_token(
|
||||
main_task: asyncio.Task[T],
|
||||
token: CancellationToken,
|
||||
) -> T:
|
||||
"""
|
||||
Race an asyncio task against a cancellation token.
|
||||
|
||||
Waits for either the task to complete or the token to be cancelled. Cleans up
|
||||
whichever side loses the race.
|
||||
|
||||
Args:
|
||||
main_task: The asyncio task to race.
|
||||
token: The cancellation token to race against.
|
||||
|
||||
Returns:
|
||||
The result of the main task if it completes first.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the token fires before the task completes.
|
||||
"""
|
||||
cancel_task = asyncio.create_task(token.aio_wait())
|
||||
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
[main_task, cancel_task],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# Cancel pending tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
if cancel_task in done:
|
||||
raise asyncio.CancelledError("Operation cancelled by cancellation token")
|
||||
|
||||
return main_task.result()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Ensure both tasks are cleaned up on any cancellation (external or token)
|
||||
main_task.cancel()
|
||||
cancel_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await main_task
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await cancel_task
|
||||
raise
|
||||
|
||||
|
||||
async def await_with_cancellation(
|
||||
coro: Awaitable[T],
|
||||
token: CancellationToken | None,
|
||||
cancel_callback: Callable[[], Awaitable[None]] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Await an awaitable with cancellation support.
|
||||
|
||||
This function races the given awaitable against a cancellation token. If the
|
||||
token is cancelled before the awaitable completes, the awaitable is cancelled
|
||||
and an asyncio.CancelledError is raised.
|
||||
|
||||
Args:
|
||||
coro: The awaitable to await (coroutine, Future, or asyncio.Task).
|
||||
token: The cancellation token to check. If None, the coroutine is awaited directly.
|
||||
cancel_callback: An optional async callback to invoke when cancellation occurs
|
||||
(e.g., to cancel child workflows).
|
||||
|
||||
Returns:
|
||||
The result of the coroutine.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the token is cancelled before the coroutine completes.
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def cleanup() -> None:
|
||||
print("cleaning up...")
|
||||
|
||||
async def long_running_task():
|
||||
await asyncio.sleep(10)
|
||||
return "done"
|
||||
|
||||
token = CancellationToken()
|
||||
|
||||
# This will raise asyncio.CancelledError if token.cancel() is called
|
||||
result = await await_with_cancellation(
|
||||
long_running_task(),
|
||||
token,
|
||||
cancel_callback=cleanup,
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
if token is None:
|
||||
logger.debug("await_with_cancellation: no token provided, awaiting directly")
|
||||
return await coro
|
||||
|
||||
logger.debug("await_with_cancellation: starting with cancellation token")
|
||||
|
||||
# Check if already cancelled
|
||||
if token.is_cancelled:
|
||||
logger.debug("await_with_cancellation: token already cancelled")
|
||||
if cancel_callback:
|
||||
logger.debug("await_with_cancellation: invoking cancel callback")
|
||||
await _invoke_cancel_callback(cancel_callback)
|
||||
raise asyncio.CancelledError("Operation cancelled by cancellation token")
|
||||
|
||||
main_task = asyncio.ensure_future(coro)
|
||||
|
||||
try:
|
||||
result = await race_against_token(main_task, token)
|
||||
logger.debug("await_with_cancellation: completed successfully")
|
||||
return result
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("await_with_cancellation: cancelled")
|
||||
if cancel_callback:
|
||||
logger.debug("await_with_cancellation: invoking cancel callback")
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await asyncio.shield(_invoke_cancel_callback(cancel_callback))
|
||||
raise
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import ctypes
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict, is_dataclass
|
||||
@@ -30,7 +29,6 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
|
||||
STEP_EVENT_TYPE_STARTED,
|
||||
)
|
||||
from hatchet_sdk.exceptions import (
|
||||
CancellationReason,
|
||||
IllegalTaskOutputError,
|
||||
NonRetryableException,
|
||||
TaskRunError,
|
||||
@@ -41,7 +39,6 @@ from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
|
||||
from hatchet_sdk.runnables.contextvars import (
|
||||
ctx_action_key,
|
||||
ctx_additional_metadata,
|
||||
ctx_cancellation_token,
|
||||
ctx_step_run_id,
|
||||
ctx_task_retry_count,
|
||||
ctx_worker_id,
|
||||
@@ -63,7 +60,6 @@ from hatchet_sdk.worker.runner.utils.capture_logs import (
|
||||
ContextVarToCopyDict,
|
||||
ContextVarToCopyInt,
|
||||
ContextVarToCopyStr,
|
||||
ContextVarToCopyToken,
|
||||
copy_context_vars,
|
||||
)
|
||||
|
||||
@@ -255,7 +251,6 @@ class Runner:
|
||||
ctx_action_key.set(action.key)
|
||||
ctx_additional_metadata.set(action.additional_metadata)
|
||||
ctx_task_retry_count.set(action.retry_count)
|
||||
ctx_cancellation_token.set(ctx.cancellation_token)
|
||||
|
||||
async with task._unpack_dependencies_with_cleanup(ctx) as dependencies:
|
||||
try:
|
||||
@@ -303,12 +298,6 @@ class Runner:
|
||||
value=action.retry_count,
|
||||
)
|
||||
),
|
||||
ContextVarToCopy(
|
||||
var=ContextVarToCopyToken(
|
||||
name="ctx_cancellation_token",
|
||||
value=ctx.cancellation_token,
|
||||
)
|
||||
),
|
||||
],
|
||||
self.thread_action_func,
|
||||
ctx,
|
||||
@@ -491,95 +480,28 @@ class Runner:
|
||||
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
|
||||
async def handle_cancel_action(self, action: Action) -> None:
|
||||
key = action.key
|
||||
start_time = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
f"Cancellation: received cancel action for {action.action_id}, "
|
||||
f"reason={CancellationReason.WORKFLOW_CANCELLED.value}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Trigger the cancellation token to signal the context to stop
|
||||
# call cancel to signal the context to stop
|
||||
if key in self.contexts:
|
||||
ctx = self.contexts[key]
|
||||
child_count = len(ctx.cancellation_token.child_run_ids)
|
||||
logger.debug(
|
||||
f"Cancellation: triggering token for {action.action_id}, "
|
||||
f"reason={CancellationReason.WORKFLOW_CANCELLED.value}, "
|
||||
f"{child_count} children registered"
|
||||
)
|
||||
ctx._set_cancellation_flag(CancellationReason.WORKFLOW_CANCELLED)
|
||||
self.contexts[key]._set_cancellation_flag()
|
||||
self.cancellations[key] = True
|
||||
# Note: Child workflows are not cancelled here - they run independently
|
||||
# and are managed by Hatchet's normal cancellation mechanisms
|
||||
else:
|
||||
logger.debug(f"Cancellation: no context found for {action.action_id}")
|
||||
|
||||
# Wait with supervision (using timedelta configs)
|
||||
grace_period = self.config.cancellation_grace_period.total_seconds()
|
||||
warning_threshold = (
|
||||
self.config.cancellation_warning_threshold.total_seconds()
|
||||
)
|
||||
grace_period_ms = round(grace_period * 1000)
|
||||
warning_threshold_ms = round(warning_threshold * 1000)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Wait until warning threshold
|
||||
await asyncio.sleep(warning_threshold)
|
||||
elapsed = time.monotonic() - start_time
|
||||
elapsed_ms = round(elapsed * 1000)
|
||||
if key in self.tasks:
|
||||
self.tasks[key].cancel()
|
||||
|
||||
# Check if the task has not yet exited despite the cancellation signal.
|
||||
task_still_running = key in self.tasks and not self.tasks[key].done()
|
||||
# check if thread is still running, if so, print a warning
|
||||
if key in self.threads:
|
||||
thread = self.threads[key]
|
||||
|
||||
if self.config.enable_force_kill_sync_threads:
|
||||
self.force_kill_thread(thread)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if task_still_running:
|
||||
logger.warning(
|
||||
f"Cancellation: task {action.action_id} has not cancelled after "
|
||||
f"{elapsed_ms}ms (warning threshold {warning_threshold_ms}ms). "
|
||||
f"Consider checking for blocking operations. "
|
||||
f"See https://docs.hatchet.run/home/cancellation"
|
||||
f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
|
||||
)
|
||||
|
||||
remaining = grace_period - elapsed
|
||||
if remaining > 0:
|
||||
await asyncio.sleep(remaining)
|
||||
|
||||
if key in self.tasks and not self.tasks[key].done():
|
||||
logger.debug(
|
||||
f"Cancellation: force-cancelling task {action.action_id} "
|
||||
f"after grace period ({grace_period_ms}ms)"
|
||||
)
|
||||
self.tasks[key].cancel()
|
||||
|
||||
if key in self.threads:
|
||||
thread = self.threads[key]
|
||||
|
||||
if self.config.enable_force_kill_sync_threads:
|
||||
logger.debug(
|
||||
f"Cancellation: force-killing thread for {action.action_id}"
|
||||
)
|
||||
self.force_kill_thread(thread)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if thread.is_alive():
|
||||
logger.warning(
|
||||
f"Cancellation: thread {thread.ident} with key {key} is still running "
|
||||
f"after cancellation. This could cause the thread pool to get blocked "
|
||||
f"and prevent new tasks from running."
|
||||
)
|
||||
|
||||
total_elapsed = time.monotonic() - start_time
|
||||
total_elapsed_ms = round(total_elapsed * 1000)
|
||||
if total_elapsed > grace_period:
|
||||
logger.warning(
|
||||
f"Cancellation: cancellation of {action.action_id} took {total_elapsed_ms}ms "
|
||||
f"(exceeded grace period of {grace_period_ms}ms)"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Cancellation: task {action.action_id} eventually completed in {total_elapsed_ms}ms"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Cancellation: task {action.action_id} completed")
|
||||
finally:
|
||||
self.cleanup_run_id(key)
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
@@ -7,15 +5,13 @@ from collections.abc import Awaitable, Callable
|
||||
from io import StringIO
|
||||
from typing import Literal, ParamSpec, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
from hatchet_sdk.clients.events import EventClient
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.runnables.contextvars import (
|
||||
ctx_action_key,
|
||||
ctx_additional_metadata,
|
||||
ctx_cancellation_token,
|
||||
ctx_step_run_id,
|
||||
ctx_task_retry_count,
|
||||
ctx_worker_id,
|
||||
@@ -52,22 +48,10 @@ class ContextVarToCopyDict(BaseModel):
|
||||
value: JSONSerializableMapping | None
|
||||
|
||||
|
||||
class ContextVarToCopyToken(BaseModel):
|
||||
"""Special type for copying CancellationToken to threads."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: Literal["ctx_cancellation_token"]
|
||||
value: CancellationToken | None
|
||||
|
||||
|
||||
class ContextVarToCopy(BaseModel):
|
||||
var: (
|
||||
ContextVarToCopyStr
|
||||
| ContextVarToCopyDict
|
||||
| ContextVarToCopyInt
|
||||
| ContextVarToCopyToken
|
||||
) = Field(discriminator="name")
|
||||
var: ContextVarToCopyStr | ContextVarToCopyDict | ContextVarToCopyInt = Field(
|
||||
discriminator="name"
|
||||
)
|
||||
|
||||
|
||||
def copy_context_vars(
|
||||
@@ -89,8 +73,6 @@ def copy_context_vars(
|
||||
ctx_worker_id.set(var.var.value)
|
||||
elif var.var.name == "ctx_additional_metadata":
|
||||
ctx_additional_metadata.set(var.var.value or {})
|
||||
elif var.var.name == "ctx_cancellation_token":
|
||||
ctx_cancellation_token.set(var.var.value)
|
||||
else:
|
||||
raise ValueError(f"Unknown context variable name: {var.var.name}")
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -8,17 +6,9 @@ from hatchet_sdk.clients.listeners.run_event_listener import (
|
||||
RunEventListenerClient,
|
||||
)
|
||||
from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
|
||||
from hatchet_sdk.exceptions import (
|
||||
CancellationReason,
|
||||
CancelledError,
|
||||
FailedTaskRunExceptionGroup,
|
||||
TaskRunError,
|
||||
)
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.utils.cancellation import await_with_cancellation
|
||||
from hatchet_sdk.exceptions import FailedTaskRunExceptionGroup, TaskRunError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
from hatchet_sdk.clients.admin import AdminClient
|
||||
|
||||
|
||||
@@ -28,7 +18,7 @@ class WorkflowRunRef:
|
||||
workflow_run_id: str,
|
||||
workflow_run_listener: PooledWorkflowRunListener,
|
||||
workflow_run_event_listener: RunEventListenerClient,
|
||||
admin_client: AdminClient,
|
||||
admin_client: "AdminClient",
|
||||
):
|
||||
self.workflow_run_id = workflow_run_id
|
||||
self.workflow_run_listener = workflow_run_listener
|
||||
@@ -41,25 +31,7 @@ class WorkflowRunRef:
|
||||
def stream(self) -> RunEventListener:
|
||||
return self.workflow_run_event_listener.stream(self.workflow_run_id)
|
||||
|
||||
async def aio_result(
|
||||
self, cancellation_token: CancellationToken | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Asynchronously wait for the workflow run to complete and return the result.
|
||||
|
||||
:param cancellation_token: Optional cancellation token to abort the wait.
|
||||
:return: A dictionary mapping task names to their outputs.
|
||||
"""
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.aio_result: waiting for {self.workflow_run_id}, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
if cancellation_token:
|
||||
return await await_with_cancellation(
|
||||
self.workflow_run_listener.aio_result(self.workflow_run_id),
|
||||
cancellation_token,
|
||||
)
|
||||
async def aio_result(self) -> dict[str, Any]:
|
||||
return await self.workflow_run_listener.aio_result(self.workflow_run_id)
|
||||
|
||||
def _safely_get_action_name(self, action_id: str | None) -> str | None:
|
||||
@@ -71,42 +43,12 @@ class WorkflowRunRef:
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
def result(
|
||||
self, cancellation_token: CancellationToken | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Synchronously wait for the workflow run to complete and return the result.
|
||||
|
||||
This method polls the API for the workflow run status. If a cancellation token
|
||||
is provided, the polling will be interrupted when cancellation is triggered.
|
||||
|
||||
:param cancellation_token: Optional cancellation token to abort the wait.
|
||||
:return: A dictionary mapping task names to their outputs.
|
||||
:raises CancelledError: If the cancellation token is triggered.
|
||||
:raises FailedTaskRunExceptionGroup: If the workflow run fails.
|
||||
:raises ValueError: If the workflow run is not found.
|
||||
"""
|
||||
def result(self) -> dict[str, Any]:
|
||||
from hatchet_sdk.clients.admin import RunStatus
|
||||
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.result: waiting for {self.workflow_run_id}, "
|
||||
f"token={cancellation_token is not None}"
|
||||
)
|
||||
|
||||
retries = 0
|
||||
|
||||
while True:
|
||||
# Check cancellation at start of each iteration
|
||||
if cancellation_token and cancellation_token.is_cancelled:
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.result: cancellation detected for {self.workflow_run_id}, "
|
||||
f"reason={CancellationReason.PARENT_CANCELLED.value}"
|
||||
)
|
||||
raise CancelledError(
|
||||
"Operation cancelled by cancellation token",
|
||||
reason=CancellationReason.PARENT_CANCELLED,
|
||||
)
|
||||
|
||||
try:
|
||||
details = self.admin_client.get_details(self.workflow_run_id)
|
||||
except Exception as e:
|
||||
@@ -117,42 +59,14 @@ class WorkflowRunRef:
|
||||
f"Workflow run {self.workflow_run_id} not found"
|
||||
) from e
|
||||
|
||||
# Use interruptible sleep via token.wait()
|
||||
if cancellation_token:
|
||||
if cancellation_token.wait(timeout=1.0):
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.result: cancellation during retry sleep for {self.workflow_run_id}, "
|
||||
f"reason={CancellationReason.PARENT_CANCELLED.value}"
|
||||
)
|
||||
raise CancelledError(
|
||||
"Operation cancelled by cancellation token",
|
||||
reason=CancellationReason.PARENT_CANCELLED,
|
||||
) from None
|
||||
else:
|
||||
time.sleep(1)
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.result: {self.workflow_run_id} status={details.status}"
|
||||
)
|
||||
|
||||
if (
|
||||
details.status in [RunStatus.QUEUED, RunStatus.RUNNING]
|
||||
or details.done is False
|
||||
):
|
||||
# Use interruptible sleep via token.wait()
|
||||
if cancellation_token:
|
||||
if cancellation_token.wait(timeout=1.0):
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.result: cancellation during poll sleep for {self.workflow_run_id}, "
|
||||
f"reason={CancellationReason.PARENT_CANCELLED.value}"
|
||||
)
|
||||
raise CancelledError(
|
||||
"Operation cancelled by cancellation token",
|
||||
reason=CancellationReason.PARENT_CANCELLED,
|
||||
)
|
||||
else:
|
||||
time.sleep(1)
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
if details.status == RunStatus.FAILED:
|
||||
@@ -166,9 +80,6 @@ class WorkflowRunRef:
|
||||
)
|
||||
|
||||
if details.status == RunStatus.COMPLETED:
|
||||
logger.debug(
|
||||
f"WorkflowRunRef.result: {self.workflow_run_id} completed successfully"
|
||||
)
|
||||
return {
|
||||
readable_id: run.output
|
||||
for readable_id, run in details.task_runs.items()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "hatchet-sdk"
|
||||
version = "1.25.1"
|
||||
version = "1.25.2"
|
||||
description = "This is the official Python SDK for Hatchet, a distributed, fault-tolerant task queue. The SDK allows you to easily integrate Hatchet's task scheduling and workflow orchestration capabilities into your Python applications."
|
||||
authors = [
|
||||
"Alexander Belanger <alexander@hatchet.run>",
|
||||
|
||||
@@ -1,461 +0,0 @@
|
||||
"""Unit tests for CancellationToken and cancellation utilities."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from hatchet_sdk.cancellation import CancellationToken
|
||||
from hatchet_sdk.exceptions import CancellationReason, CancelledError
|
||||
from hatchet_sdk.runnables.contextvars import ctx_cancellation_token
|
||||
from hatchet_sdk.utils.cancellation import await_with_cancellation
|
||||
|
||||
# CancellationToken
|
||||
|
||||
|
||||
def test_initial_state() -> None:
    """A freshly constructed token reports that it is not cancelled."""
    fresh = CancellationToken()
    assert fresh.is_cancelled is False
|
||||
|
||||
|
||||
def test_cancel_sets_flag() -> None:
    """Calling cancel() flips is_cancelled to True."""
    tok = CancellationToken()
    tok.cancel()
    assert tok.is_cancelled is True
|
||||
|
||||
|
||||
def test_cancel_sets_reason() -> None:
    """The reason passed to cancel() is exposed through token.reason."""
    wanted = CancellationReason.USER_REQUESTED
    tok = CancellationToken()
    tok.cancel(wanted)
    assert tok.reason == wanted
|
||||
|
||||
|
||||
def test_reason_is_none_before_cancel() -> None:
    """An uncancelled token has no reason recorded."""
    fresh = CancellationToken()
    assert fresh.reason is None
|
||||
|
||||
|
||||
def test_cancel_idempotent() -> None:
    """Cancelling an already-cancelled token is harmless."""
    tok = CancellationToken()
    tok.cancel()
    tok.cancel()  # the repeat call must not raise
    assert tok.is_cancelled is True
|
||||
|
||||
|
||||
def test_cancel_idempotent_preserves_reason() -> None:
    """Only the first cancel() call determines the stored reason."""
    tok = CancellationToken()
    first_reason = CancellationReason.USER_REQUESTED
    tok.cancel(first_reason)
    tok.cancel(CancellationReason.TIMEOUT)  # later reasons are expected to be dropped
    assert tok.reason == first_reason
|
||||
|
||||
|
||||
def test_sync_wait_returns_true_when_cancelled() -> None:
    """wait() on a token that is already cancelled reports True."""
    tok = CancellationToken()
    tok.cancel()
    was_cancelled = tok.wait(timeout=0.1)
    assert was_cancelled is True
|
||||
|
||||
|
||||
def test_sync_wait_timeout_returns_false() -> None:
    """Without a cancellation, wait() runs out its timeout and reports False."""
    tok = CancellationToken()
    began = time.monotonic()
    outcome = tok.wait(timeout=0.1)
    duration = time.monotonic() - began
    assert outcome is False
    assert duration >= 0.1
|
||||
|
||||
|
||||
def test_sync_wait_interrupted_by_cancel() -> None:
    """A cancel() issued from another thread wakes a blocked wait() early."""
    tok = CancellationToken()

    def delayed_cancel() -> None:
        time.sleep(0.1)
        tok.cancel()

    canceller = threading.Thread(target=delayed_cancel)
    canceller.start()

    began = time.monotonic()
    outcome = tok.wait(timeout=1.0)
    duration = time.monotonic() - began

    canceller.join()

    assert outcome is True
    assert duration < 0.5  # well under the 1.0s timeout, so the cancel woke us
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aio_wait_returns_when_cancelled() -> None:
    """aio_wait() should return promptly once the token is cancelled."""
    token = CancellationToken()

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it runs.
    canceller = asyncio.create_task(cancel_after_delay())

    start = time.monotonic()
    # Bound the wait so a regression fails the test instead of hanging the suite.
    await asyncio.wait_for(token.aio_wait(), timeout=1.0)
    elapsed = time.monotonic() - start

    # Surface any exception raised inside the helper task.
    await canceller

    assert elapsed < 0.5  # Should be fast
|
||||
|
||||
|
||||
def test_register_child() -> None:
    """Registered child run IDs appear in child_run_ids in call order."""
    tok = CancellationToken()
    for run_id in ("run-1", "run-2"):
        tok.register_child(run_id)

    assert tok.child_run_ids == ["run-1", "run-2"]
|
||||
|
||||
|
||||
def test_callback_invoked_on_cancel() -> None:
    """A callback registered before cancel() fires exactly once on cancel()."""
    tok = CancellationToken()
    fired = []

    def on_cancel() -> None:
        fired.append(True)

    tok.add_callback(on_cancel)
    tok.cancel()

    assert fired == [True]
|
||||
|
||||
|
||||
def test_callback_invoked_immediately_if_already_cancelled() -> None:
    """Registering a callback on a cancelled token runs it right away."""
    tok = CancellationToken()
    tok.cancel()

    fired = []

    def on_cancel() -> None:
        fired.append(True)

    tok.add_callback(on_cancel)

    assert fired == [True]
|
||||
|
||||
|
||||
def test_multiple_callbacks() -> None:
    """Every registered callback runs on cancel(), in registration order."""
    tok = CancellationToken()
    seen: list[int] = []

    for marker in (1, 2, 3):
        # Bind the current marker as a default so each callback keeps its own value.
        tok.add_callback(lambda marker=marker: seen.append(marker))

    tok.cancel()

    assert seen == [1, 2, 3]
|
||||
|
||||
|
||||
def test_repr() -> None:
    """repr() of a token shows its cancelled flag and child count."""
    tok = CancellationToken()
    tok.register_child("run-1")

    rendered = repr(tok)
    assert "cancelled=False" in rendered
    assert "children=1" in rendered
|
||||
|
||||
|
||||
# await_with_cancellation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_token_awaits_directly() -> None:
    """Passing None as the token still yields the coroutine's return value."""

    async def produce() -> str:
        return "result"

    outcome = await await_with_cancellation(produce(), None)
    assert outcome == "result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_not_cancelled_returns_result() -> None:
    """A live (never-cancelled) token lets the coroutine's result through."""
    tok = CancellationToken()

    async def produce() -> str:
        await asyncio.sleep(0.01)
        return "result"

    outcome = await await_with_cancellation(produce(), tok)
    assert outcome == "result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_already_cancelled_raises_immediately() -> None:
    """A token cancelled up front makes the helper raise asyncio.CancelledError."""
    tok = CancellationToken()
    tok.cancel()

    async def never_finishes_in_time() -> str:
        await asyncio.sleep(10)  # long enough to stall the test if it were awaited
        return "result"

    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(never_finishes_in_time(), tok)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancellation_during_await_raises() -> None:
    """Should raise CancelledError when token is cancelled during await."""
    token = CancellationToken()

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it runs.
    canceller = asyncio.create_task(cancel_after_delay())

    start = time.monotonic()
    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(slow_coro(), token)
    elapsed = time.monotonic() - start

    # Surface any exception raised inside the helper task.
    await canceller

    assert elapsed < 0.5  # Should be cancelled quickly
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_callback_invoked() -> None:
    """Cancel callback should be invoked on cancellation."""
    token = CancellationToken()
    callback_called = []

    async def cancel_callback() -> None:
        callback_called.append(True)

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it runs.
    canceller = asyncio.create_task(cancel_after_delay())

    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(
            slow_coro(), token, cancel_callback=cancel_callback
        )

    # Surface any exception raised inside the helper task.
    await canceller

    assert callback_called == [True]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sync_cancel_callback_invoked() -> None:
    """Cancel callback should be invoked on cancellation."""
    # NOTE(review): the name suggests a synchronous callback, but the callback
    # below is `async def`, which makes this body the same scenario as
    # test_cancel_callback_invoked. Confirm whether a plain `def` callback was
    # intended and whether await_with_cancellation accepts one.
    token = CancellationToken()
    callback_called = []

    async def cancel_callback() -> None:
        callback_called.append(True)

    async def slow_coro() -> str:
        await asyncio.sleep(10)
        return "result"

    async def cancel_after_delay() -> None:
        await asyncio.sleep(0.1)
        token.cancel()

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it runs.
    canceller = asyncio.create_task(cancel_after_delay())

    with pytest.raises(asyncio.CancelledError):
        await await_with_cancellation(
            slow_coro(), token, cancel_callback=cancel_callback
        )

    # Surface any exception raised inside the helper task.
    await canceller

    assert callback_called == [True]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_callback_invoked_on_external_task_cancel() -> None:
    """Cancelling the awaiting task from outside still triggers the cancel callback."""
    tok = CancellationToken()
    callback_ran = asyncio.Event()

    async def on_cancel() -> None:
        callback_ran.set()

    async def long_running() -> str:
        await asyncio.sleep(10)
        return "result"

    wrapped = await_with_cancellation(long_running(), tok, cancel_callback=on_cancel)
    waiter = asyncio.create_task(wrapped)

    # Give the wrapped coroutine time to start before cancelling its task.
    await asyncio.sleep(0.1)
    waiter.cancel()

    with pytest.raises(asyncio.CancelledError):
        await waiter

    # Raises on timeout if the callback never signalled within a second.
    await asyncio.wait_for(callback_ran.wait(), timeout=1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_callback_not_invoked_on_success() -> None:
    """A coroutine that finishes normally must leave the cancel callback untouched."""
    tok = CancellationToken()
    fired = []

    async def on_cancel() -> None:
        fired.append(True)

    async def quick() -> str:
        await asyncio.sleep(0.01)
        return "result"

    outcome = await await_with_cancellation(quick(), tok, cancel_callback=on_cancel)

    assert outcome == "result"
    assert fired == []
|
||||
|
||||
|
||||
# CancellationReason
|
||||
|
||||
|
||||
def test_all_reasons_exist() -> None:
    """Each expected CancellationReason member exists with its expected value."""
    expected_pairs = (
        (CancellationReason.USER_REQUESTED, "user_requested"),
        (CancellationReason.TIMEOUT, "timeout"),
        (CancellationReason.PARENT_CANCELLED, "parent_cancelled"),
        (CancellationReason.WORKFLOW_CANCELLED, "workflow_cancelled"),
        (CancellationReason.TOKEN_CANCELLED, "token_cancelled"),
    )
    for member, raw_value in expected_pairs:
        assert member.value == raw_value
|
||||
|
||||
|
||||
def test_reasons_are_strings() -> None:
    """Every CancellationReason member carries a str value."""
    assert all(isinstance(member.value, str) for member in CancellationReason)
|
||||
|
||||
|
||||
# CancelledError
|
||||
|
||||
|
||||
def test_cancelled_error_is_base_exception() -> None:
    """CancelledError derives from BaseException but not from Exception."""
    error = CancelledError("test message")
    assert isinstance(error, BaseException)
    # Not an Exception subclass, so a broad `except Exception` will not swallow it.
    assert not isinstance(error, Exception)
    assert str(error) == "test message"
|
||||
|
||||
|
||||
def test_cancelled_error_not_caught_by_except_exception() -> None:
    """A raised CancelledError skips `except Exception` and reaches its own handler."""
    hit_generic_handler = False
    hit_cancelled_handler = False

    try:
        raise CancelledError("test")
    except Exception:
        hit_generic_handler = True
    except CancelledError:
        hit_cancelled_handler = True

    assert not hit_generic_handler
    assert hit_cancelled_handler
|
||||
|
||||
|
||||
def test_cancelled_error_with_reason() -> None:
    """The reason keyword given to CancelledError is readable from .reason."""
    error = CancelledError("test message", reason=CancellationReason.TIMEOUT)
    assert error.reason == CancellationReason.TIMEOUT
|
||||
|
||||
|
||||
def test_cancelled_error_reason_defaults_to_none() -> None:
    """Omitting the reason leaves CancelledError.reason as None."""
    error = CancelledError("test message")
    assert error.reason is None
|
||||
|
||||
|
||||
def test_cancelled_error_message_property() -> None:
    """CancelledError.message returns the text it was constructed with."""
    error = CancelledError("test message")
    assert error.message == "test message"
|
||||
|
||||
|
||||
def test_cancelled_error_default_message() -> None:
    """A CancelledError built with no arguments uses the stock message."""
    error = CancelledError()
    assert error.message == "Operation cancelled"
|
||||
|
||||
|
||||
def test_can_be_raised_and_caught() -> None:
    """Raising CancelledError propagates it with its message intact."""
    with pytest.raises(CancelledError) as caught:
        raise CancelledError("Operation cancelled")

    assert "Operation cancelled" in str(caught.value)
|
||||
|
||||
|
||||
def test_can_be_raised_with_reason() -> None:
    """A raised CancelledError keeps the reason it was constructed with."""
    parent_cancelled = CancellationReason.PARENT_CANCELLED
    with pytest.raises(CancelledError) as caught:
        raise CancelledError("Parent was cancelled", reason=parent_cancelled)

    assert caught.value.reason == parent_cancelled
|
||||
|
||||
|
||||
# Context var propagation
|
||||
|
||||
|
||||
def test_context_var_default_is_none() -> None:
    """Reading ctx_cancellation_token without setting it yields None."""
    current = ctx_cancellation_token.get()
    assert current is None
|
||||
|
||||
|
||||
def test_context_var_can_be_set_and_retrieved() -> None:
    """ctx_cancellation_token should be settable and retrievable."""
    token = CancellationToken()
    # Keep the handle returned by set() so cleanup restores whatever value was
    # there before, instead of forcing the variable to None.
    # NOTE(review): assumes ctx_cancellation_token is a contextvars.ContextVar.
    restore_handle = ctx_cancellation_token.set(token)
    try:
        assert ctx_cancellation_token.get() is token
    finally:
        ctx_cancellation_token.reset(restore_handle)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_var_propagates_in_async() -> None:
    """ctx_cancellation_token should propagate in async context."""
    token = CancellationToken()
    # Keep the handle returned by set() so cleanup restores whatever value was
    # there before, instead of forcing the variable to None.
    # NOTE(review): assumes ctx_cancellation_token is a contextvars.ContextVar.
    restore_handle = ctx_cancellation_token.set(token)

    async def check_token() -> CancellationToken | None:
        return ctx_cancellation_token.get()

    try:
        retrieved = await check_token()
        assert retrieved is token
    finally:
        ctx_cancellation_token.reset(restore_handle)
|
||||
Reference in New Issue
Block a user