mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-05-03 07:59:31 -05:00
3cc4dbe38f
* chore: publish new version * chore: dummy change for CI to run
1510 lines
59 KiB
Python
1510 lines
59 KiB
Python
import asyncio
|
|
from collections.abc import Callable
|
|
from datetime import datetime, timedelta
|
|
from functools import cached_property
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Concatenate,
|
|
Generic,
|
|
Literal,
|
|
ParamSpec,
|
|
TypeVar,
|
|
cast,
|
|
get_type_hints,
|
|
overload,
|
|
)
|
|
|
|
from google.protobuf import timestamp_pb2
|
|
from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter, model_validator
|
|
|
|
from hatchet_sdk.clients.admin import (
|
|
ScheduleTriggerWorkflowOptions,
|
|
TriggerWorkflowOptions,
|
|
WorkflowRunTriggerConfig,
|
|
)
|
|
from hatchet_sdk.clients.listeners.run_event_listener import RunEventListener
|
|
from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows
|
|
from hatchet_sdk.clients.rest.models.v1_filter import V1Filter
|
|
from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
|
|
from hatchet_sdk.clients.rest.models.v1_task_summary import V1TaskSummary
|
|
from hatchet_sdk.conditions import Condition, OrGroup
|
|
from hatchet_sdk.context.context import Context, DurableContext
|
|
from hatchet_sdk.contracts.v1.workflows_pb2 import (
|
|
CreateWorkflowVersionRequest,
|
|
DesiredWorkerLabels,
|
|
)
|
|
from hatchet_sdk.contracts.v1.workflows_pb2 import StickyStrategy as StickyStrategyProto
|
|
from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
|
|
from hatchet_sdk.labels import DesiredWorkerLabel
|
|
from hatchet_sdk.rate_limit import RateLimit
|
|
from hatchet_sdk.runnables.task import Task
|
|
from hatchet_sdk.runnables.types import (
|
|
ConcurrencyExpression,
|
|
EmptyModel,
|
|
R,
|
|
StepType,
|
|
TaskDefaults,
|
|
TaskPayloadForInternalUse,
|
|
TWorkflowInput,
|
|
WorkflowConfig,
|
|
normalize_validator,
|
|
)
|
|
from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
|
|
from hatchet_sdk.utils.timedelta_to_expression import Duration
|
|
from hatchet_sdk.utils.typing import CoroutineLike, JSONSerializableMapping
|
|
from hatchet_sdk.workflow_run import WorkflowRunRef
|
|
|
|
if TYPE_CHECKING:
|
|
from hatchet_sdk import Hatchet
|
|
|
|
|
|
T = TypeVar("T")  # generic value type used by fall_back_to_default
P = ParamSpec("P")  # captures the wrapped function's extra parameters in task decorators
|
|
|
|
|
|
def fall_back_to_default(value: T, param_default: T, fallback_value: T | None) -> T:
|
|
## If the value is not the param default, it's set
|
|
if value != param_default:
|
|
return value
|
|
|
|
## Otherwise, it's unset, so return the fallback value if it's set
|
|
if fallback_value is not None:
|
|
return fallback_value
|
|
|
|
## Otherwise return the param value
|
|
return value
|
|
|
|
|
|
class ComputedTaskParameters(BaseModel):
    """
    Per-task parameters resolved against the workflow-level ``TaskDefaults``.

    Any field still sitting at its decorator-parameter default is treated as
    unset and replaced by the corresponding value from ``task_defaults``.
    """

    schedule_timeout: Duration
    execution_timeout: Duration
    retries: int
    backoff_factor: float | None
    backoff_max_seconds: int | None

    task_defaults: TaskDefaults

    @model_validator(mode="after")
    def validate_params(self) -> "ComputedTaskParameters":
        """Apply the task defaults to every field that was left unset."""
        defaults = self.task_defaults

        # (field name, decorator-parameter default, workflow-level fallback)
        resolution_table = [
            ("execution_timeout", timedelta(seconds=60), defaults.execution_timeout),
            ("schedule_timeout", timedelta(minutes=5), defaults.schedule_timeout),
            ("backoff_factor", None, defaults.backoff_factor),
            ("backoff_max_seconds", None, defaults.backoff_max_seconds),
            ("retries", 0, defaults.retries),
        ]

        for field_name, param_default, fallback in resolution_table:
            resolved = fall_back_to_default(
                value=getattr(self, field_name),
                param_default=param_default,
                fallback_value=fallback,
            )
            setattr(self, field_name, resolved)

        return self
|
|
|
|
|
|
def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels:
    """Convert an SDK-level desired worker label into its protobuf form.

    The proto keeps string and integer values in separate fields, so exactly
    one of ``strValue`` / ``intValue`` is populated based on the value's type.
    """
    label_value = d.value
    is_int_value = isinstance(label_value, int)

    return DesiredWorkerLabels(
        strValue=None if is_int_value else label_value,
        intValue=label_value if is_int_value else None,
        required=d.required,
        weight=d.weight,
        comparator=d.comparator,  # type: ignore[arg-type]
    )
|
|
|
|
|
|
class TypedTriggerWorkflowRunConfig(BaseModel, Generic[TWorkflowInput]):
    """A single typed workflow-run trigger: the run's input plus its trigger options."""

    ## `arbitrary_types_allowed` lets the model carry input types pydantic
    ## cannot natively validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ## Validation is skipped so the caller's already-typed input object passes
    ## through unchanged.
    input: SkipValidation[TWorkflowInput]
    options: TriggerWorkflowOptions
|
|
|
|
|
|
class BaseWorkflow(Generic[TWorkflowInput]):
    """
    Base class for Hatchet workflows, parameterized by the workflow's input type.

    Holds the workflow configuration, the client handle, and the registered
    tasks, and exposes trigger-independent operations: proto serialization,
    run listing, filter creation, scheduling, cron management, and deletion.
    """

    def __init__(self, config: WorkflowConfig, client: "Hatchet") -> None:
        self.config = config
        ## Task registries; populated by the registration decorators
        ## (e.g. `Workflow.task`, which appends to `_default_tasks`).
        self._default_tasks: list[Task[TWorkflowInput, Any]] = []
        self._durable_tasks: list[Task[TWorkflowInput, Any]] = []
        self._on_failure_task: Task[TWorkflowInput, Any] | None = None
        self._on_success_task: Task[TWorkflowInput, Any] | None = None
        self.client = client

    @property
    def service_name(self) -> str:
        """The namespaced, lowercased service name used to qualify action names."""
        return self.client.config.apply_namespace(self.config.name.lower())

    def _create_action_name(self, step: Task[TWorkflowInput, Any]) -> str:
        """Build the fully-qualified action name (`<service>:<task>`) for a task."""
        return self.service_name + ":" + step.name

    def _is_leaf_task(self, task: Task[TWorkflowInput, Any]) -> bool:
        """Return True if no other task in this workflow lists `task` as a parent."""
        return not any(task in t.parents for t in self.tasks if task != t)

    def to_proto(self) -> CreateWorkflowVersionRequest:
        """Serialize this workflow (tasks, triggers, concurrency, hooks) into its proto request."""
        namespace = self.client.config.namespace
        service_name = self.service_name

        name = self.name
        event_triggers = [
            self.client.config.apply_namespace(event, namespace)
            for event in self.config.on_events
        ]

        ## The on-success hook runs after every leaf (childless) default task,
        ## so wire the current leaves up as its parents before serializing.
        ## NOTE: this mutates the task in place on each call.
        if self._on_success_task:
            self._on_success_task.parents = [
                task
                for task in self.tasks
                if task.type == StepType.DEFAULT and self._is_leaf_task(task)
            ]

        on_success_task = (
            t.to_proto(service_name) if (t := self._on_success_task) else None
        )

        tasks = [
            task.to_proto(service_name)
            for task in self.tasks
            if task.type == StepType.DEFAULT
        ]

        if on_success_task:
            tasks += [on_success_task]

        on_failure_task = (
            t.to_proto(service_name) if (t := self._on_failure_task) else None
        )

        ## `concurrency` may be a list of expressions, a single expression,
        ## a plain int (treated as a constant limit), or unset — exactly one
        ## of `_concurrency` / `_concurrency_arr` is populated.
        if isinstance(self.config.concurrency, list):
            _concurrency_arr = [c.to_proto() for c in self.config.concurrency]
            _concurrency = None
        elif isinstance(self.config.concurrency, ConcurrencyExpression):
            _concurrency_arr = []
            _concurrency = self.config.concurrency.to_proto()
        elif isinstance(self.config.concurrency, int):
            _concurrency_arr = []
            _concurrency = ConcurrencyExpression.from_int(
                self.config.concurrency
            ).to_proto()
        else:
            _concurrency = None
            _concurrency_arr = []

        return CreateWorkflowVersionRequest(
            name=name,
            description=self.config.description,
            version=self.config.version,
            event_triggers=event_triggers,
            cron_triggers=self.config.on_crons,
            tasks=tasks,
            ## TODO: Fix this
            cron_input=None,
            on_failure_task=on_failure_task,
            sticky=convert_python_enum_to_proto(
                self.config.sticky, StickyStrategyProto
            ),  # type: ignore[arg-type]
            concurrency=_concurrency,
            concurrency_arr=_concurrency_arr,
            default_priority=self.config.default_priority,
            default_filters=[f.to_proto() for f in self.config.default_filters],
        )

    def _get_workflow_input(self, ctx: Context) -> TWorkflowInput:
        """Validate the raw input carried by the context and return it as `TWorkflowInput`."""
        return cast(
            TWorkflowInput,
            self.config.input_validator.validate_python(ctx.workflow_input),
        )

    @property
    def input_validator(self) -> type[TWorkflowInput]:
        """The validator used for this workflow's input type."""
        return cast(type[TWorkflowInput], self.config.input_validator)

    @property
    def tasks(self) -> list[Task[TWorkflowInput, Any]]:
        """All registered tasks: default and durable tasks plus any on-failure/on-success hooks."""
        tasks = self._default_tasks + self._durable_tasks

        if self._on_failure_task:
            tasks += [self._on_failure_task]

        if self._on_success_task:
            tasks += [self._on_success_task]

        return tasks

    @property
    def name(self) -> str:
        """
        The (namespaced) name of the workflow.
        """
        ## NOTE(review): this concatenates the namespace directly, whereas
        ## `service_name` goes through `apply_namespace` — confirm the
        ## difference is intentional.
        return self.client.config.namespace + self.config.name

    def create_bulk_run_item(
        self,
        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
        key: str | None = None,
        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
    ) -> WorkflowRunTriggerConfig:
        """
        Create a bulk run item for the workflow. This is intended to be used in conjunction with the various `run_many` methods.

        :param input: The input data for the workflow.
        :param key: The key for the workflow run. This is used to identify the run in the bulk operation and for deduplication.
        :param options: Additional options for the workflow run.

        :returns: A `WorkflowRunTriggerConfig` object that can be used to trigger the workflow run, which you then pass into the `run_many` methods.
        """
        return WorkflowRunTriggerConfig(
            workflow_name=self.config.name,
            input=self._serialize_input(input),
            options=options,
            key=key,
        )

    def _serialize_input(self, input: TWorkflowInput | None) -> JSONSerializableMapping:
        """Dump the input model to a JSON-serializable mapping; falsy input becomes an empty mapping."""
        if not input:
            return {}

        return cast(
            JSONSerializableMapping,
            self.config.input_validator.dump_python(input, mode="json"),  # type: ignore[arg-type]
        )

    @cached_property
    def id(self) -> str:
        """
        Get the ID of the workflow.

        :raises ValueError: If no workflow ID is found for the workflow name.
        :returns: The ID of the workflow.
        """
        ## Cached on the instance (`cached_property`), so the REST lookup
        ## only happens once per workflow object.
        workflows = self.client.workflows.list(workflow_name=self.name)

        if not workflows.rows:
            raise ValueError(f"No id found for {self.name}")

        ## The listing may return near-matches; require an exact name match.
        for workflow in workflows.rows:
            if workflow.name == self.name:
                return workflow.metadata.id

        raise ValueError(f"No id found for {self.name}")

    def list_runs(
        self,
        since: datetime | None = None,
        until: datetime | None = None,
        limit: int = 100,
        offset: int | None = None,
        statuses: list[V1TaskStatus] | None = None,
        additional_metadata: dict[str, str] | None = None,
        worker_id: str | None = None,
        parent_task_external_id: str | None = None,
        only_tasks: bool = False,
        triggering_event_external_id: str | None = None,
    ) -> list[V1TaskSummary]:
        """
        List runs of the workflow.

        :param since: The start time for the runs to be listed.
        :param until: The end time for the runs to be listed.
        :param limit: The maximum number of runs to be listed.
        :param offset: The offset for pagination.
        :param statuses: The statuses of the runs to be listed.
        :param additional_metadata: Additional metadata for filtering the runs.
        :param worker_id: The ID of the worker that ran the tasks.
        :param parent_task_external_id: The external ID of the parent task.
        :param only_tasks: Whether to list only task runs.
        :param triggering_event_external_id: The event id that triggered the task run.

        :returns: A list of `V1TaskSummary` objects representing the runs of the workflow.
        """
        return self.client.runs.list_with_pagination(
            workflow_ids=[self.id],
            since=since,
            only_tasks=only_tasks,
            offset=offset,
            limit=limit,
            statuses=statuses,
            until=until,
            additional_metadata=additional_metadata,
            worker_id=worker_id,
            parent_task_external_id=parent_task_external_id,
            triggering_event_external_id=triggering_event_external_id,
        )

    async def aio_list_runs(
        self,
        since: datetime | None = None,
        until: datetime | None = None,
        limit: int = 100,
        offset: int | None = None,
        statuses: list[V1TaskStatus] | None = None,
        additional_metadata: dict[str, str] | None = None,
        worker_id: str | None = None,
        parent_task_external_id: str | None = None,
        only_tasks: bool = False,
        triggering_event_external_id: str | None = None,
    ) -> list[V1TaskSummary]:
        """
        List runs of the workflow (async variant of `list_runs`).

        :param since: The start time for the runs to be listed.
        :param until: The end time for the runs to be listed.
        :param limit: The maximum number of runs to be listed.
        :param offset: The offset for pagination.
        :param statuses: The statuses of the runs to be listed.
        :param additional_metadata: Additional metadata for filtering the runs.
        :param worker_id: The ID of the worker that ran the tasks.
        :param parent_task_external_id: The external ID of the parent task.
        :param only_tasks: Whether to list only task runs.
        :param triggering_event_external_id: The event id that triggered the task run.

        :returns: A list of `V1TaskSummary` objects representing the runs of the workflow.
        """
        return await self.client.runs.aio_list_with_pagination(
            workflow_ids=[self.id],
            since=since,
            only_tasks=only_tasks,
            offset=offset,
            limit=limit,
            statuses=statuses,
            until=until,
            additional_metadata=additional_metadata,
            worker_id=worker_id,
            parent_task_external_id=parent_task_external_id,
            triggering_event_external_id=triggering_event_external_id,
        )

    def create_filter(
        self,
        expression: str,
        scope: str,
        payload: JSONSerializableMapping | None = None,
    ) -> V1Filter:
        """
        Create a new filter.

        :param expression: The expression to evaluate for the filter.
        :param scope: The scope for the filter.
        :param payload: The payload to send with the filter.

        :return: The created filter.
        """
        return self.client.filters.create(
            workflow_id=self.id,
            expression=expression,
            scope=scope,
            payload=payload,
        )

    async def aio_create_filter(
        self,
        expression: str,
        scope: str,
        payload: JSONSerializableMapping | None = None,
    ) -> V1Filter:
        """
        Create a new filter (async variant of `create_filter`).

        :param expression: The expression to evaluate for the filter.
        :param scope: The scope for the filter.
        :param payload: The payload to send with the filter.

        :return: The created filter.
        """
        return await self.client.filters.aio_create(
            workflow_id=self.id,
            expression=expression,
            scope=scope,
            payload=payload,
        )

    def schedule(
        self,
        run_at: datetime,
        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
        options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
    ) -> WorkflowVersion:
        """
        Schedule a workflow to run at a specific time.

        :param run_at: The time at which to schedule the workflow.
        :param input: The input data for the workflow.
        :param options: Additional options for workflow execution.
        :returns: A `WorkflowVersion` object representing the scheduled workflow.
        """
        return self.client._client.admin.schedule_workflow(
            name=self.config.name,
            schedules=cast(list[datetime | timestamp_pb2.Timestamp], [run_at]),
            input=self._serialize_input(input),
            options=options,
        )

    async def aio_schedule(
        self,
        run_at: datetime,
        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
        options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
    ) -> WorkflowVersion:
        """
        Schedule a workflow to run at a specific time (async variant of `schedule`).

        :param run_at: The time at which to schedule the workflow.
        :param input: The input data for the workflow.
        :param options: Additional options for workflow execution.
        :returns: A `WorkflowVersion` object representing the scheduled workflow.
        """
        return await self.client._client.admin.aio_schedule_workflow(
            name=self.config.name,
            schedules=cast(list[datetime | timestamp_pb2.Timestamp], [run_at]),
            input=self._serialize_input(input),
            options=options,
        )

    def create_cron(
        self,
        cron_name: str,
        expression: str,
        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
        additional_metadata: JSONSerializableMapping | None = None,
        priority: int | None = None,
    ) -> CronWorkflows:
        """
        Create a cron job for the workflow.

        :param cron_name: The name of the cron job.
        :param expression: The cron expression that defines the schedule for the cron job.
        :param input: The input data for the workflow.
        :param additional_metadata: Additional metadata for the cron job.
        :param priority: The priority of the cron job. Must be between 1 and 3, inclusive.

        :returns: A `CronWorkflows` object representing the created cron job.
        """
        return self.client.cron.create(
            workflow_name=self.config.name,
            cron_name=cron_name,
            expression=expression,
            input=self._serialize_input(input),
            additional_metadata=additional_metadata or {},
            priority=priority,
        )

    async def aio_create_cron(
        self,
        cron_name: str,
        expression: str,
        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
        additional_metadata: JSONSerializableMapping | None = None,
        priority: int | None = None,
    ) -> CronWorkflows:
        """
        Create a cron job for the workflow (async variant of `create_cron`).

        :param cron_name: The name of the cron job.
        :param expression: The cron expression that defines the schedule for the cron job.
        :param input: The input data for the workflow.
        :param additional_metadata: Additional metadata for the cron job.
        :param priority: The priority of the cron job. Must be between 1 and 3, inclusive.

        :returns: A `CronWorkflows` object representing the created cron job.
        """
        return await self.client.cron.aio_create(
            workflow_name=self.config.name,
            cron_name=cron_name,
            expression=expression,
            input=self._serialize_input(input),
            additional_metadata=additional_metadata or {},
            priority=priority,
        )

    def delete(self) -> None:
        """
        Permanently delete the workflow.

        **DANGEROUS: This will delete a workflow and all of its data**
        """
        self.client.workflows.delete(self.id)

    async def aio_delete(self) -> None:
        """
        Permanently delete the workflow (async variant of `delete`).

        **DANGEROUS: This will delete a workflow and all of its data**
        """
        await self.client.workflows.aio_delete(self.id)
|
|
|
|
|
|
class Workflow(BaseWorkflow[TWorkflowInput]):
|
|
"""
|
|
A Hatchet workflow, which allows you to define tasks to be run and perform actions on the workflow.
|
|
|
|
Workflows in Hatchet represent coordinated units of work that can be triggered, scheduled, or run on a cron schedule.
|
|
Each workflow can contain multiple tasks that can be arranged in dependencies (DAGs), have customized retry behavior,
|
|
timeouts, concurrency controls, and more.
|
|
|
|
Example:
|
|
```python
|
|
from pydantic import BaseModel
|
|
from hatchet_sdk import Hatchet
|
|
|
|
class MyInput(BaseModel):
|
|
name: str
|
|
|
|
hatchet = Hatchet()
|
|
workflow = hatchet.workflow("my-workflow", input_type=MyInput)
|
|
|
|
@workflow.task()
|
|
def greet(input, ctx):
|
|
return f"Hello, {input.name}!"
|
|
|
|
# Run the workflow
|
|
result = workflow.run(MyInput(name="World"))
|
|
```
|
|
|
|
Workflows support various execution patterns including:
|
|
- One-time execution with `run()` or `aio_run()`
|
|
- Scheduled execution with `schedule()`
|
|
- Cron-based recurring execution with `create_cron()`
|
|
- Bulk operations with `run_many()`
|
|
|
|
Tasks within workflows can be defined with `@workflow.task()` or `@workflow.durable_task()` decorators
|
|
and can be arranged into complex dependency patterns.
|
|
"""
|
|
|
|
def run_no_wait(
|
|
self,
|
|
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
|
options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
|
|
) -> WorkflowRunRef:
|
|
"""
|
|
Synchronously trigger a workflow run without waiting for it to complete.
|
|
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
|
|
|
|
:param input: The input data for the workflow.
|
|
:param options: Additional options for workflow execution.
|
|
|
|
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
|
|
"""
|
|
return self.client._client.admin.run_workflow(
|
|
workflow_name=self.config.name,
|
|
input=self._serialize_input(input),
|
|
options=options,
|
|
)
|
|
|
|
def run(
|
|
self,
|
|
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
|
options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Run the workflow synchronously and wait for it to complete.
|
|
|
|
This method triggers a workflow run, blocks until completion, and returns the final result.
|
|
|
|
:param input: The input data for the workflow, must match the workflow's input type.
|
|
:param options: Additional options for workflow execution like metadata and parent workflow ID.
|
|
|
|
:returns: The result of the workflow execution as a dictionary.
|
|
"""
|
|
|
|
ref = self.client._client.admin.run_workflow(
|
|
workflow_name=self.config.name,
|
|
input=self._serialize_input(input),
|
|
options=options,
|
|
)
|
|
|
|
return ref.result()
|
|
|
|
async def aio_run_no_wait(
|
|
self,
|
|
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
|
options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
|
|
) -> WorkflowRunRef:
|
|
"""
|
|
Asynchronously trigger a workflow run without waiting for it to complete.
|
|
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
|
|
|
|
:param input: The input data for the workflow.
|
|
:param options: Additional options for workflow execution.
|
|
|
|
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
|
|
"""
|
|
|
|
return await self.client._client.admin.aio_run_workflow(
|
|
workflow_name=self.config.name,
|
|
input=self._serialize_input(input),
|
|
options=options,
|
|
)
|
|
|
|
async def aio_run(
|
|
self,
|
|
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
|
|
options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Run the workflow asynchronously and wait for it to complete.
|
|
|
|
This method triggers a workflow run, awaits until completion, and returns the final result.
|
|
|
|
:param input: The input data for the workflow, must match the workflow's input type.
|
|
:param options: Additional options for workflow execution like metadata and parent workflow ID.
|
|
|
|
:returns: The result of the workflow execution as a dictionary.
|
|
"""
|
|
ref = await self.client._client.admin.aio_run_workflow(
|
|
workflow_name=self.config.name,
|
|
input=self._serialize_input(input),
|
|
options=options,
|
|
)
|
|
|
|
return await ref.aio_result()
|
|
|
|
def _get_result(
|
|
self, ref: WorkflowRunRef, return_exceptions: bool
|
|
) -> dict[str, Any] | BaseException:
|
|
try:
|
|
return ref.result()
|
|
except Exception as e:
|
|
if return_exceptions:
|
|
return e
|
|
raise e
|
|
|
|
@overload
|
|
def run_many(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
return_exceptions: Literal[True],
|
|
) -> list[dict[str, Any] | BaseException]: ...
|
|
|
|
@overload
|
|
def run_many(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
return_exceptions: Literal[False] = False,
|
|
) -> list[dict[str, Any]]: ...
|
|
|
|
def run_many(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
return_exceptions: bool = False,
|
|
) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
|
|
"""
|
|
Run a workflow in bulk and wait for all runs to complete.
|
|
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
|
|
|
|
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
|
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
|
|
:returns: A list of results for each workflow run.
|
|
"""
|
|
refs = self.client._client.admin.run_workflows(
|
|
workflows=workflows,
|
|
)
|
|
|
|
return [self._get_result(ref, return_exceptions) for ref in refs]
|
|
|
|
@overload
|
|
async def aio_run_many(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
return_exceptions: Literal[True],
|
|
) -> list[dict[str, Any] | BaseException]: ...
|
|
|
|
@overload
|
|
async def aio_run_many(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
return_exceptions: Literal[False] = False,
|
|
) -> list[dict[str, Any]]: ...
|
|
|
|
async def aio_run_many(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
return_exceptions: bool = False,
|
|
) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
|
|
"""
|
|
Run a workflow in bulk and wait for all runs to complete.
|
|
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
|
|
|
|
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
|
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
|
|
:returns: A list of results for each workflow run.
|
|
"""
|
|
refs = await self.client._client.admin.aio_run_workflows(
|
|
workflows=workflows,
|
|
)
|
|
|
|
return await asyncio.gather(
|
|
*[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
|
|
)
|
|
|
|
def run_many_no_wait(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
) -> list[WorkflowRunRef]:
|
|
"""
|
|
Run a workflow in bulk without waiting for all runs to complete.
|
|
|
|
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
|
|
|
|
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
|
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
|
|
"""
|
|
return self.client._client.admin.run_workflows(
|
|
workflows=workflows,
|
|
)
|
|
|
|
async def aio_run_many_no_wait(
|
|
self,
|
|
workflows: list[WorkflowRunTriggerConfig],
|
|
) -> list[WorkflowRunRef]:
|
|
"""
|
|
Run a workflow in bulk without waiting for all runs to complete.
|
|
|
|
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
|
|
|
|
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
|
|
|
|
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
|
|
"""
|
|
return await self.client._client.admin.aio_run_workflows(
|
|
workflows=workflows,
|
|
)
|
|
|
|
def _parse_task_name(
|
|
self,
|
|
name: str | None,
|
|
func: Callable[..., Any],
|
|
) -> str:
|
|
non_null_name = name or func.__name__
|
|
return non_null_name.lower()
|
|
|
|
    def task(
        self,
        name: str | None = None,
        schedule_timeout: Duration = timedelta(minutes=5),
        execution_timeout: Duration = timedelta(seconds=60),
        parents: list[Task[TWorkflowInput, Any]] | None = None,
        retries: int = 0,
        rate_limits: list[RateLimit] | None = None,
        desired_worker_labels: dict[str, DesiredWorkerLabel] | None = None,
        backoff_factor: float | None = None,
        backoff_max_seconds: int | None = None,
        concurrency: int | list[ConcurrencyExpression] | None = None,
        wait_for: list[Condition | OrGroup] | None = None,
        skip_if: list[Condition | OrGroup] | None = None,
        cancel_if: list[Condition | OrGroup] | None = None,
    ) -> Callable[
        [Callable[Concatenate[TWorkflowInput, Context, P], R | CoroutineLike[R]]],
        Task[TWorkflowInput, R],
    ]:
        """
        A decorator to transform a function into a Hatchet task that runs as part of a workflow.

        :param name: The name of the task. If not specified, defaults to the name of the function being wrapped by the `task` decorator.

        :param schedule_timeout: The maximum time to wait for the task to be scheduled. The run will be canceled if the task does not begin within this time.

        :param execution_timeout: The maximum time to wait for the task to complete. The run will be canceled if the task does not complete within this time.

        :param parents: A list of tasks that are parents of the task. Note: Parents must be defined before their children.

        :param retries: The number of times to retry the task before failing.

        :param rate_limits: A list of rate limit configurations for the task.

        :param desired_worker_labels: A dictionary of desired worker labels that determine to which worker the task should be assigned. See documentation and examples on affinity and worker labels for more details.

        :param backoff_factor: The backoff factor for controlling exponential backoff in retries.

        :param backoff_max_seconds: The maximum number of seconds to allow retries with exponential backoff to continue.

        :param concurrency: A list of concurrency expressions for the task. If an integer is provided, it is treated as a constant concurrency limit with a `GROUP_ROUND_ROBIN` strategy, which means that only `N` runs of the task may execute at any given time.

        :param wait_for: A list of conditions that must be met before the task can run.

        :param skip_if: A list of conditions that, if met, will cause the task to be skipped.

        :param cancel_if: A list of conditions that, if met, will cause the task to be canceled.

        :returns: A decorator which creates a `Task` object.
        """

        ## Resolve per-task parameters against the workflow-level
        ## `task_defaults`: values still at their parameter defaults are
        ## considered unset and fall back (see ComputedTaskParameters).
        computed_params = ComputedTaskParameters(
            schedule_timeout=schedule_timeout,
            execution_timeout=execution_timeout,
            retries=retries,
            backoff_factor=backoff_factor,
            backoff_max_seconds=backoff_max_seconds,
            task_defaults=self.config.task_defaults,
        )

        def inner(
            func: Callable[
                Concatenate[TWorkflowInput, Context, P], R | CoroutineLike[R]
            ],
        ) -> Task[TWorkflowInput, R]:
            ## Wrap the function in a Task and register it on this workflow.
            task = Task(
                _fn=func,
                is_durable=False,
                workflow=self,
                type=StepType.DEFAULT,
                name=self._parse_task_name(name, func),
                execution_timeout=computed_params.execution_timeout,
                schedule_timeout=computed_params.schedule_timeout,
                parents=parents,
                retries=computed_params.retries,
                rate_limits=[r.to_proto() for r in rate_limits or []],
                desired_worker_labels={
                    key: transform_desired_worker_label(d)
                    for key, d in (desired_worker_labels or {}).items()
                },
                backoff_factor=computed_params.backoff_factor,
                backoff_max_seconds=computed_params.backoff_max_seconds,
                concurrency=concurrency,
                wait_for=wait_for,
                skip_if=skip_if,
                cancel_if=cancel_if,
            )

            self._default_tasks.append(task)

            return task

        return inner
|
|
|
|
def durable_task(
    self,
    name: str | None = None,
    schedule_timeout: Duration = timedelta(minutes=5),
    execution_timeout: Duration = timedelta(seconds=60),
    parents: list[Task[TWorkflowInput, Any]] | None = None,
    retries: int = 0,
    rate_limits: list[RateLimit] | None = None,
    desired_worker_labels: dict[str, DesiredWorkerLabel] | None = None,
    backoff_factor: float | None = None,
    backoff_max_seconds: int | None = None,
    concurrency: int | list[ConcurrencyExpression] | None = None,
    wait_for: list[Condition | OrGroup] | None = None,
    skip_if: list[Condition | OrGroup] | None = None,
    cancel_if: list[Condition | OrGroup] | None = None,
) -> Callable[
    [
        Callable[
            Concatenate[TWorkflowInput, DurableContext, P], R | CoroutineLike[R]
        ]
    ],
    Task[TWorkflowInput, R],
]:
    """
    A decorator to transform a function into a durable Hatchet task that runs as part of a workflow.

    **IMPORTANT:** This decorator creates a _durable_ task, which works using Hatchet's durable
    execution capabilities. This is an advanced feature of Hatchet. See the Hatchet docs for more
    information on durable execution to decide if this is right for you.

    :param name: The name of the task. Defaults to the name of the wrapped function if not given.
    :param schedule_timeout: The maximum time to wait for the task to be scheduled. The run is canceled if the task does not begin within this window.
    :param execution_timeout: The maximum time to wait for the task to complete. The run is canceled if the task does not finish within this window.
    :param parents: Tasks that are parents of this task. Parents must be defined before their children.
    :param retries: The number of times to retry the task before failing.
    :param rate_limits: Rate limit configurations for the task.
    :param desired_worker_labels: Desired worker labels controlling which worker the task is assigned to. See the affinity / worker-label docs for details.
    :param backoff_factor: The backoff factor controlling exponential backoff between retries.
    :param backoff_max_seconds: The maximum number of seconds exponential-backoff retries may continue.
    :param concurrency: Concurrency expressions for the task. An integer is treated as a constant limit with a `GROUP_ROUND_ROBIN` strategy (only `N` runs may execute at once).
    :param wait_for: Conditions that must be met before the task can run.
    :param skip_if: Conditions that, if met, cause the task to be skipped.
    :param cancel_if: Conditions that, if met, cause the task to be canceled.

    :returns: A decorator which creates a `Task` object.
    """
    # Merge the per-task overrides with the workflow-level task defaults.
    params = ComputedTaskParameters(
        schedule_timeout=schedule_timeout,
        execution_timeout=execution_timeout,
        retries=retries,
        backoff_factor=backoff_factor,
        backoff_max_seconds=backoff_max_seconds,
        task_defaults=self.config.task_defaults,
    )

    def inner(
        func: Callable[
            Concatenate[TWorkflowInput, DurableContext, P], R | CoroutineLike[R]
        ],
    ) -> Task[TWorkflowInput, R]:
        labels = desired_worker_labels or {}

        task = Task(
            _fn=func,
            is_durable=True,
            workflow=self,
            type=StepType.DEFAULT,
            name=self._parse_task_name(name, func),
            execution_timeout=params.execution_timeout,
            schedule_timeout=params.schedule_timeout,
            parents=parents,
            retries=params.retries,
            rate_limits=[limit.to_proto() for limit in rate_limits or []],
            desired_worker_labels={
                key: transform_desired_worker_label(label)
                for key, label in labels.items()
            },
            backoff_factor=params.backoff_factor,
            backoff_max_seconds=params.backoff_max_seconds,
            concurrency=concurrency,
            wait_for=wait_for,
            skip_if=skip_if,
            cancel_if=cancel_if,
        )

        # Register the durable task on this workflow.
        self._durable_tasks.append(task)

        return task

    return inner
|
|
|
|
def on_failure_task(
    self,
    name: str | None = None,
    schedule_timeout: Duration = timedelta(minutes=5),
    execution_timeout: Duration = timedelta(seconds=60),
    retries: int = 0,
    rate_limits: list[RateLimit] | None = None,
    backoff_factor: float | None = None,
    backoff_max_seconds: int | None = None,
    concurrency: int | list[ConcurrencyExpression] | None = None,
) -> Callable[
    [Callable[Concatenate[TWorkflowInput, Context, P], R | CoroutineLike[R]]],
    Task[TWorkflowInput, R],
]:
    """
    A decorator to transform a function into a Hatchet on-failure task that runs as the
    last step in a workflow that had at least one task fail.

    :param name: The name of the on-failure task. Defaults to the name of the wrapped function if not given.
    :param schedule_timeout: The maximum time to wait for the task to be scheduled. The run is canceled if the task does not begin within this window.
    :param execution_timeout: The maximum time to wait for the task to complete. The run is canceled if the task does not finish within this window.
    :param retries: The number of times to retry the on-failure task before failing.
    :param rate_limits: Rate limit configurations for the on-failure task.
    :param backoff_factor: The backoff factor controlling exponential backoff between retries.
    :param backoff_max_seconds: The maximum number of seconds exponential-backoff retries may continue.
    :param concurrency: Concurrency expressions for the on-failure task. An integer is treated as a constant limit with a `GROUP_ROUND_ROBIN` strategy (only `N` runs may execute at once).

    :returns: A decorator which creates a `Task` object.
    """

    def inner(
        func: Callable[
            Concatenate[TWorkflowInput, Context, P], R | CoroutineLike[R]
        ],
    ) -> Task[TWorkflowInput, R]:
        task = Task(
            is_durable=False,
            _fn=func,
            workflow=self,
            type=StepType.ON_FAILURE,
            # Suffix distinguishes the on-failure step from regular steps.
            name=self._parse_task_name(name, func) + "-on-failure",
            execution_timeout=execution_timeout,
            schedule_timeout=schedule_timeout,
            retries=retries,
            rate_limits=[limit.to_proto() for limit in rate_limits or []],
            backoff_factor=backoff_factor,
            backoff_max_seconds=backoff_max_seconds,
            concurrency=concurrency,
            desired_worker_labels=None,
            parents=None,
            wait_for=None,
            skip_if=None,
            cancel_if=None,
        )

        # A workflow may only declare a single on-failure hook.
        if self._on_failure_task:
            raise ValueError("Only one on-failure task is allowed")

        self._on_failure_task = task

        return task

    return inner
|
|
|
|
def on_success_task(
    self,
    name: str | None = None,
    schedule_timeout: Duration = timedelta(minutes=5),
    execution_timeout: Duration = timedelta(seconds=60),
    retries: int = 0,
    rate_limits: list[RateLimit] | None = None,
    backoff_factor: float | None = None,
    backoff_max_seconds: int | None = None,
    concurrency: int | list[ConcurrencyExpression] | None = None,
) -> Callable[
    [Callable[Concatenate[TWorkflowInput, Context, P], R | CoroutineLike[R]]],
    Task[TWorkflowInput, R],
]:
    """
    A decorator to transform a function into a Hatchet on-success task that runs as the
    last step in a workflow that had all upstream tasks succeed.

    :param name: The name of the on-success task. Defaults to the name of the wrapped function if not given.
    :param schedule_timeout: The maximum time to wait for the task to be scheduled. The run is canceled if the task does not begin within this window.
    :param execution_timeout: The maximum time to wait for the task to complete. The run is canceled if the task does not finish within this window.
    :param retries: The number of times to retry the on-success task before failing.
    :param rate_limits: Rate limit configurations for the on-success task.
    :param backoff_factor: The backoff factor controlling exponential backoff between retries.
    :param backoff_max_seconds: The maximum number of seconds exponential-backoff retries may continue.
    :param concurrency: Concurrency expressions for the on-success task. An integer is treated as a constant limit with a `GROUP_ROUND_ROBIN` strategy (only `N` runs may execute at once).

    :returns: A decorator which creates a `Task` object.
    """

    def inner(
        func: Callable[
            Concatenate[TWorkflowInput, Context, P], R | CoroutineLike[R]
        ],
    ) -> Task[TWorkflowInput, R]:
        task = Task(
            is_durable=False,
            _fn=func,
            workflow=self,
            type=StepType.ON_SUCCESS,
            # Suffix distinguishes the on-success step from regular steps.
            name=self._parse_task_name(name, func) + "-on-success",
            execution_timeout=execution_timeout,
            schedule_timeout=schedule_timeout,
            retries=retries,
            rate_limits=[limit.to_proto() for limit in rate_limits or []],
            backoff_factor=backoff_factor,
            backoff_max_seconds=backoff_max_seconds,
            concurrency=concurrency,
            parents=None,
            desired_worker_labels=None,
            wait_for=None,
            skip_if=None,
            cancel_if=None,
        )

        # A workflow may only declare a single on-success hook.
        if self._on_success_task:
            raise ValueError("Only one on-success task is allowed")

        self._on_success_task = task

        return task

    return inner
|
|
|
|
def add_task(self, task: "Standalone[TWorkflowInput, Any]") -> None:
    """
    Add a task to a workflow. Intended to be used with a previously existing task
    (a `Standalone`), such as one created with `@hatchet.task()`, which has been
    converted to a `Task` object using `to_task`.

    For example:

    ```python
    @hatchet.task()
    def my_task(input, ctx) -> None:
        pass

    wf = hatchet.workflow()

    wf.add_task(my_task.to_task())
    ```
    """
    underlying = task._task
    step_type = underlying.type

    if step_type == StepType.DEFAULT:
        self._default_tasks.append(underlying)
    elif step_type == StepType.ON_FAILURE:
        # A workflow may only declare a single on-failure hook.
        if self._on_failure_task:
            raise ValueError("Only one on-failure task is allowed")

        self._on_failure_task = underlying
    elif step_type == StepType.ON_SUCCESS:
        # A workflow may only declare a single on-success hook.
        if self._on_success_task:
            raise ValueError("Only one on-success task is allowed")

        self._on_success_task = underlying
    else:
        raise ValueError("Invalid task type")
|
|
|
|
|
|
class TaskRunRef(Generic[TWorkflowInput, R]):
    """A handle to a single run of a standalone task, used to await or stream its result."""

    def __init__(
        self,
        standalone: "Standalone[TWorkflowInput, R]",
        workflow_run_ref: WorkflowRunRef,
    ):
        self._s = standalone
        self._wrr = workflow_run_ref

        # Expose the underlying run id directly for convenience.
        self.workflow_run_id = workflow_run_ref.workflow_run_id

    def __str__(self) -> str:
        return self.workflow_run_id

    async def aio_result(self) -> R:
        """Await the run's completion and return the task's extracted output."""
        listener = self._wrr.workflow_run_listener
        raw = await listener.aio_result(self._wrr.workflow_run_id)

        return self._s._extract_result(raw)

    def result(self) -> R:
        """Block until the run completes and return the task's extracted output."""
        return self._s._extract_result(self._wrr.result())

    def stream(self) -> RunEventListener:
        """Return an event listener that streams events for this run."""
        return self._wrr.stream()
|
|
|
|
|
|
class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
|
|
def __init__(
    self, workflow: Workflow[TWorkflowInput], task: Task[TWorkflowInput, R]
) -> None:
    """Wrap a single `Task` and its owning `Workflow` as a standalone runnable unit."""
    super().__init__(config=workflow.config, client=workflow.client)

    ## NOTE: This is a hack to assign the task back to the base workflow,
    ## since the decorator to mutate the tasks is not being called.
    self._default_tasks = [task]

    self._workflow = workflow
    self._task = task

    # Build a validator from the task function's return annotation so that
    # raw engine output can be parsed back into the declared return type.
    annotated_return = get_type_hints(self._task.fn).get("return")

    self._output_validator: TypeAdapter[TaskPayloadForInternalUse] = TypeAdapter(
        normalize_validator(annotated_return)
    )

    self.config = self._workflow.config
|
|
|
|
@overload
def _extract_result(self, result: dict[str, Any]) -> R: ...

@overload
def _extract_result(self, result: BaseException) -> BaseException: ...

def _extract_result(
    self, result: dict[str, Any] | BaseException
) -> R | BaseException:
    """Extract this task's output from a raw workflow-run result, passing exceptions through unchanged."""
    if isinstance(result, BaseException):
        return result

    ## if a task is cancelled, we can get `None` back here
    ## this is a bit of an edge case since both `None` and an empty dict
    ## would cause Pydantic validation errors, but if you were expecting a `dict`
    ## return, then the empty dict would not error and would work correctly
    payload = result.get(self._task.name) or {}

    return cast(R, self._output_validator.validate_python(payload))
|
|
|
|
def run(
    self,
    input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
    options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
) -> R:
    """
    Run the workflow synchronously and wait for it to complete.

    Triggers a workflow run, blocks until completion, and returns the extracted result.

    :param input: The input data for the workflow.
    :param options: Additional options for workflow execution.

    :returns: The extracted result of the workflow execution.
    """
    raw = self._workflow.run(input, options)

    return self._extract_result(raw)
|
|
|
|
async def aio_run(
    self,
    input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
    options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
) -> R:
    """
    Run the workflow asynchronously and wait for it to complete.

    Triggers a workflow run, awaits completion, and returns the extracted result.

    :param input: The input data for the workflow, must match the workflow's input type.
    :param options: Additional options for workflow execution like metadata and parent workflow ID.

    :returns: The extracted result of the workflow execution.
    """
    return self._extract_result(await self._workflow.aio_run(input, options))
|
|
|
|
def run_no_wait(
    self,
    input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
    options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
) -> TaskRunRef[TWorkflowInput, R]:
    """
    Trigger a workflow run without waiting for it to complete.

    Triggers a workflow run and immediately returns a reference to the run without
    blocking while the workflow executes.

    :param input: The input data for the workflow, must match the workflow's input type.
    :param options: Additional options for workflow execution like metadata and parent workflow ID.

    :returns: A `TaskRunRef` object representing the reference to the workflow run.
    """
    return TaskRunRef[TWorkflowInput, R](
        self, self._workflow.run_no_wait(input, options)
    )
|
|
|
|
async def aio_run_no_wait(
    self,
    input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
    options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
) -> TaskRunRef[TWorkflowInput, R]:
    """
    Asynchronously trigger a workflow run without waiting for it to complete.

    Useful for starting a workflow run and immediately returning a reference to the
    run without blocking while the workflow executes.

    :param input: The input data for the workflow.
    :param options: Additional options for workflow execution.

    :returns: A `TaskRunRef` object representing the reference to the workflow run.
    """
    pending = await self._workflow.aio_run_no_wait(input, options)

    return TaskRunRef[TWorkflowInput, R](self, pending)
|
|
|
|
@overload
def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[True],
) -> list[R | BaseException]: ...

@overload
def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[False] = False,
) -> list[R]: ...

def run_many(
    self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
) -> list[R] | list[R | BaseException]:
    """
    Run a workflow in bulk and wait for all runs to complete.

    Triggers multiple workflow runs, blocks until all of them complete, and returns the final results.

    :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
    :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
    :returns: A list of results for each workflow run.
    """
    raw_results = self._workflow.run_many(
        workflows,
        ## hack: typing needs literal
        True if return_exceptions else False,  # noqa: SIM210
    )

    return [self._extract_result(raw) for raw in raw_results]
|
|
|
|
@overload
async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[True],
) -> list[R | BaseException]: ...

@overload
async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[False] = False,
) -> list[R]: ...

async def aio_run_many(
    self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
) -> list[R] | list[R | BaseException]:
    """
    Run a workflow in bulk and wait for all runs to complete.

    Triggers multiple workflow runs, awaits until all of them complete, and returns the final results.

    :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
    :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
    :returns: A list of results for each workflow run.
    """
    raw_results = await self._workflow.aio_run_many(
        workflows,
        ## hack: typing needs literal
        True if return_exceptions else False,  # noqa: SIM210
    )

    return [self._extract_result(raw) for raw in raw_results]
|
|
|
|
def run_many_no_wait(
    self, workflows: list[WorkflowRunTriggerConfig]
) -> list[TaskRunRef[TWorkflowInput, R]]:
    """
    Run a workflow in bulk without waiting for all runs to complete.

    Triggers multiple workflow runs and immediately returns a list of references to
    the runs without blocking while the workflows execute.

    :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
    :returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
    """
    return [
        TaskRunRef[TWorkflowInput, R](self, ref)
        for ref in self._workflow.run_many_no_wait(workflows)
    ]
|
|
|
|
async def aio_run_many_no_wait(
    self, workflows: list[WorkflowRunTriggerConfig]
) -> list[TaskRunRef[TWorkflowInput, R]]:
    """
    Run a workflow in bulk without waiting for all runs to complete.

    Triggers multiple workflow runs and immediately returns a list of references to
    the runs without blocking while the workflows execute.

    :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.

    :returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
    """
    pending = await self._workflow.aio_run_many_no_wait(workflows)

    return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in pending]
|
|
|
|
def mock_run(
    self,
    input: TWorkflowInput | None = None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan: Any = None,
    dependencies: dict[str, Any] | None = None,
) -> R:
    """
    Mimic the execution of a task. This method is intended to be used to unit test
    tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
    tasks and `aio_mock_run` for async tasks.

    :param input: The input to the task.
    :param additional_metadata: Additional metadata to attach to the task.
    :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
    :param retry_count: The number of times the task has been retried.
    :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
    :param dependencies: Dependencies to be injected into the task. This is useful for tasks that have dependencies defined using `Depends`. **IMPORTANT**: You must pass the dependencies _directly_, **not** the `Depends` objects themselves. For example, if you have a task that has a dependency `config: Annotated[str, Depends(get_config)]`, you should pass `dependencies={"config": "config_value"}` to `aio_mock_run`.

    :return: The output of the task.
    """
    # Delegate straight to the wrapped task's mock runner.
    call_kwargs = dict(
        input=input,
        additional_metadata=additional_metadata,
        parent_outputs=parent_outputs,
        retry_count=retry_count,
        lifespan=lifespan,
        dependencies=dependencies,
    )

    return self._task.mock_run(**call_kwargs)
|
|
|
|
async def aio_mock_run(
    self,
    input: TWorkflowInput | None = None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan: Any = None,
    dependencies: dict[str, Any] | None = None,
) -> R:
    """
    Mimic the execution of a task. This method is intended to be used to unit test
    tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
    tasks and `aio_mock_run` for async tasks.

    :param input: The input to the task.
    :param additional_metadata: Additional metadata to attach to the task.
    :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
    :param retry_count: The number of times the task has been retried.
    :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
    :param dependencies: Dependencies to be injected into the task. This is useful for tasks that have dependencies defined using `Depends`. **IMPORTANT**: You must pass the dependencies _directly_, **not** the `Depends` objects themselves. For example, if you have a task that has a dependency `config: Annotated[str, Depends(get_config)]`, you should pass `dependencies={"config": "config_value"}` to `aio_mock_run`.

    :return: The output of the task.
    """
    # Delegate straight to the wrapped task's async mock runner.
    call_kwargs = dict(
        input=input,
        additional_metadata=additional_metadata,
        parent_outputs=parent_outputs,
        retry_count=retry_count,
        lifespan=lifespan,
        dependencies=dependencies,
    )

    return await self._task.aio_mock_run(**call_kwargs)
|
|
|
|
@property
def is_async_function(self) -> bool:
    """
    Check if the task is an async function.

    :returns: True if the task is an async function, False otherwise.
    """
    # Mirror the wrapped task's own flag.
    return self._task.is_async_function
|
|
|
|
def get_run_ref(self, run_id: str) -> TaskRunRef[TWorkflowInput, R]:
    """
    Get a reference to a task run by its run ID.

    :param run_id: The ID of the run to get the reference for.
    :returns: A `TaskRunRef` object representing the reference to the task run.
    """
    runs_client = self._workflow.client._client.runs

    return TaskRunRef[TWorkflowInput, R](self, runs_client.get_run_ref(run_id))
|
|
|
|
async def aio_get_result(self, run_id: str) -> R:
    """
    Get the result of a task run by its run ID.

    :param run_id: The ID of the run to get the result for.
    :returns: The result of the task run.
    """
    return await self.get_run_ref(run_id).aio_result()
|
|
|
|
def get_result(self, run_id: str) -> R:
    """
    Get the result of a task run by its run ID.

    :param run_id: The ID of the run to get the result for.
    :returns: The result of the task run.
    """
    return self.get_run_ref(run_id).result()
|