Fix: Spawn index incorrect on retry (#1752)
* fix: child index on retry not being reset correctly
* chore: ver
* chore: changelog
* feat: test
@@ -5,6 +5,12 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [1.10.2] - 2025-05-19
+
+### Changed
+
+- Fixing an issue with the spawn index being set at the `workflow_run_id` level and not the `(workflow_run_id, retry_count)` level, causing children to be spawned multiple times on retry.
+
 ## [1.10.1] - 2025-05-16
 
 ### Added
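To see why the old keying was wrong, consider a minimal sketch (illustrative names, not the SDK's API): a counter keyed by the run id alone survives the task's retry, so the retry resumes spawning at a stale index, while a key that includes the retry count starts each attempt back at zero.

    from collections import Counter

    old_indices = Counter[str]()  # keyed by workflow_run_id (the bug)
    new_indices = Counter[str]()  # keyed by "run_id/retry_count" (the fix)

    def next_spawn_index(run_id: str, retry_count: int) -> tuple[int, int]:
        action_key = f"{run_id}/{retry_count}"
        old, new = old_indices[run_id], new_indices[action_key]
        old_indices[run_id] += 1
        new_indices[action_key] += 1
        return old, new

    assert next_spawn_index("run-1", retry_count=0) == (0, 0)  # first attempt: both agree
    assert next_spawn_index("run-1", retry_count=1) == (1, 0)  # retry: the old index is stale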
@@ -20,6 +20,7 @@ from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.metadata import get_metadata
 from hatchet_sdk.rate_limit import RateLimitDuration
 from hatchet_sdk.runnables.contextvars import (
+    ctx_action_key,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
@@ -281,11 +282,12 @@ class AdminClient:
         workflow_run_id = ctx_workflow_run_id.get()
         step_run_id = ctx_step_run_id.get()
         worker_id = ctx_worker_id.get()
-        spawn_index = workflow_spawn_indices[workflow_run_id] if workflow_run_id else 0
+        action_key = ctx_action_key.get()
+        spawn_index = workflow_spawn_indices[action_key] if action_key else 0
 
         ## Increment the spawn_index for the parent workflow
-        if workflow_run_id:
-            workflow_spawn_indices[workflow_run_id] += 1
+        if action_key:
+            workflow_spawn_indices[action_key] += 1
 
         desired_worker_id = (
             (options.desired_worker_id or worker_id) if options.sticky else None
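For context, the spawn index is what lets the engine deduplicate children: a child is cached by its parent run and index, so re-requesting the same index returns the already-spawned child. A hedged sketch of that contract (not the real engine API):

    child_cache: dict[tuple[str, int], str] = {}

    def spawn_child(parent_run_id: str, child_index: int) -> str:
        key = (parent_run_id, child_index)
        # A repeated (parent, index) pair reuses the cached child run.
        return child_cache.setdefault(key, f"child-{parent_run_id}-{child_index}")

    first = spawn_child("parent-1", 0)   # original attempt spawns the child
    retry = spawn_child("parent-1", 0)   # retry re-requests index 0 after the fix
    assert first == retry                # one child run, not two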
@@ -1,13 +1,11 @@
 import asyncio
 import json
 import time
-from dataclasses import field
-from enum import Enum
-from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
+from typing import TYPE_CHECKING, AsyncGenerator, cast
 
 import grpc
 import grpc.aio
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from hatchet_sdk.clients.event_ts import (
     ThreadSafeEvent,
@@ -30,8 +28,8 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
 from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
 from hatchet_sdk.logger import logger
 from hatchet_sdk.metadata import get_metadata
+from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType
 from hatchet_sdk.utils.backoff import exp_backoff_sleep
-from hatchet_sdk.utils.opentelemetry import OTelAttribute
 from hatchet_sdk.utils.proto_enums import convert_proto_enum_to_python
 from hatchet_sdk.utils.typing import JSONSerializableMapping
 
@@ -67,120 +65,6 @@ class GetActionListenerRequest(BaseModel):
         return self
 
 
-class ActionPayload(BaseModel):
-    model_config = ConfigDict(extra="allow")
-
-    input: JSONSerializableMapping = Field(default_factory=dict)
-    parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict)
-    overrides: JSONSerializableMapping = Field(default_factory=dict)
-    user_data: JSONSerializableMapping = Field(default_factory=dict)
-    step_run_errors: dict[str, str] = Field(default_factory=dict)
-    triggered_by: str | None = None
-    triggers: JSONSerializableMapping = Field(default_factory=dict)
-    filter_payload: JSONSerializableMapping = Field(default_factory=dict)
-
-    @field_validator(
-        "input",
-        "parents",
-        "overrides",
-        "user_data",
-        "step_run_errors",
-        "filter_payload",
-        mode="before",
-    )
-    @classmethod
-    def validate_fields(cls, v: Any) -> Any:
-        return v or {}
-
-    @model_validator(mode="after")
-    def validate_filter_payload(self) -> "ActionPayload":
-        self.filter_payload = self.triggers.get("filter_payload", {})
-
-        return self
-
-
-class ActionType(str, Enum):
-    START_STEP_RUN = "START_STEP_RUN"
-    CANCEL_STEP_RUN = "CANCEL_STEP_RUN"
-    START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
-
-
-ActionKey = str
-
-
-class Action(BaseModel):
-    worker_id: str
-    tenant_id: str
-    workflow_run_id: str
-    workflow_id: str | None = None
-    workflow_version_id: str | None = None
-    get_group_key_run_id: str
-    job_id: str
-    job_name: str
-    job_run_id: str
-    step_id: str
-    step_run_id: str
-    action_id: str
-    action_type: ActionType
-    retry_count: int
-    action_payload: ActionPayload
-    additional_metadata: JSONSerializableMapping = field(default_factory=dict)
-
-    child_workflow_index: int | None = None
-    child_workflow_key: str | None = None
-    parent_workflow_run_id: str | None = None
-
-    priority: int | None = None
-
-    def _dump_payload_to_str(self) -> str:
-        try:
-            return json.dumps(self.action_payload.model_dump(), default=str)
-        except Exception:
-            return str(self.action_payload)
-
-    def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]:
-        try:
-            payload_str = json.dumps(self.action_payload.model_dump(), default=str)
-        except Exception:
-            payload_str = str(self.action_payload)
-
-        attrs: dict[OTelAttribute, str | int | None] = {
-            OTelAttribute.TENANT_ID: self.tenant_id,
-            OTelAttribute.WORKER_ID: self.worker_id,
-            OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id,
-            OTelAttribute.STEP_ID: self.step_id,
-            OTelAttribute.STEP_RUN_ID: self.step_run_id,
-            OTelAttribute.RETRY_COUNT: self.retry_count,
-            OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id,
-            OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index,
-            OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key,
-            OTelAttribute.ACTION_PAYLOAD: payload_str,
-            OTelAttribute.WORKFLOW_NAME: self.job_name,
-            OTelAttribute.ACTION_NAME: self.action_id,
-            OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id,
-            OTelAttribute.WORKFLOW_ID: self.workflow_id,
-            OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id,
-        }
-
-        return {
-            f"hatchet.{k.value}": v
-            for k, v in attrs.items()
-            if v and k not in config.otel.excluded_attributes
-        }
-
-    @property
-    def key(self) -> ActionKey:
-        """
-        This key is used to uniquely identify a single step run by its id + retry count.
-        It's used when storing references to a task, a context, etc. in a dictionary so that
-        we can look up those items in the dictionary by a unique key.
-        """
-        if self.action_type == ActionType.START_GET_GROUP_KEY:
-            return f"{self.get_group_key_run_id}/{self.retry_count}"
-        else:
-            return f"{self.step_run_id}/{self.retry_count}"
-
-
 def parse_additional_metadata(additional_metadata: str) -> JSONSerializableMapping:
     try:
         return cast(
@@ -4,7 +4,6 @@ import grpc.aio
 from google.protobuf.timestamp_pb2 import Timestamp
 
 from hatchet_sdk.clients.dispatcher.action_listener import (
-    Action,
     ActionListener,
     GetActionListenerRequest,
 )
@@ -29,6 +28,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
 )
 from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
 from hatchet_sdk.metadata import get_metadata
+from hatchet_sdk.runnables.action import Action
 
 DEFAULT_REGISTER_TIMEOUT = 30
 
@@ -33,13 +33,13 @@ from hatchet_sdk.clients.admin import (
     TriggerWorkflowOptions,
     WorkflowRunTriggerConfig,
 )
-from hatchet_sdk.clients.dispatcher.action_listener import Action
 from hatchet_sdk.clients.events import (
     BulkPushEventWithMetadata,
     EventClient,
     PushEventOptions,
 )
 from hatchet_sdk.contracts.events_pb2 import Event
+from hatchet_sdk.runnables.action import Action
 from hatchet_sdk.worker.runner.runner import Runner
 from hatchet_sdk.workflow_run import WorkflowRunRef
 
@@ -0,0 +1,125 @@
+import json
+from dataclasses import field
+from enum import Enum
+from typing import TYPE_CHECKING, Any
+
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+
+from hatchet_sdk.utils.opentelemetry import OTelAttribute
+from hatchet_sdk.utils.typing import JSONSerializableMapping
+
+if TYPE_CHECKING:
+    from hatchet_sdk.config import ClientConfig
+
+ActionKey = str
+
+
+class ActionPayload(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    input: JSONSerializableMapping = Field(default_factory=dict)
+    parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict)
+    overrides: JSONSerializableMapping = Field(default_factory=dict)
+    user_data: JSONSerializableMapping = Field(default_factory=dict)
+    step_run_errors: dict[str, str] = Field(default_factory=dict)
+    triggered_by: str | None = None
+    triggers: JSONSerializableMapping = Field(default_factory=dict)
+    filter_payload: JSONSerializableMapping = Field(default_factory=dict)
+
+    @field_validator(
+        "input",
+        "parents",
+        "overrides",
+        "user_data",
+        "step_run_errors",
+        "filter_payload",
+        mode="before",
+    )
+    @classmethod
+    def validate_fields(cls, v: Any) -> Any:
+        return v or {}
+
+    @model_validator(mode="after")
+    def validate_filter_payload(self) -> "ActionPayload":
+        self.filter_payload = self.triggers.get("filter_payload", {})
+
+        return self
+
+
+class ActionType(str, Enum):
+    START_STEP_RUN = "START_STEP_RUN"
+    CANCEL_STEP_RUN = "CANCEL_STEP_RUN"
+    START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
+
+
+class Action(BaseModel):
+    worker_id: str
+    tenant_id: str
+    workflow_run_id: str
+    workflow_id: str | None = None
+    workflow_version_id: str | None = None
+    get_group_key_run_id: str
+    job_id: str
+    job_name: str
+    job_run_id: str
+    step_id: str
+    step_run_id: str
+    action_id: str
+    action_type: ActionType
+    retry_count: int
+    action_payload: ActionPayload
+    additional_metadata: JSONSerializableMapping = field(default_factory=dict)
+
+    child_workflow_index: int | None = None
+    child_workflow_key: str | None = None
+    parent_workflow_run_id: str | None = None
+
+    priority: int | None = None
+
+    def _dump_payload_to_str(self) -> str:
+        try:
+            return json.dumps(self.action_payload.model_dump(), default=str)
+        except Exception:
+            return str(self.action_payload)
+
+    def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]:
+        try:
+            payload_str = json.dumps(self.action_payload.model_dump(), default=str)
+        except Exception:
+            payload_str = str(self.action_payload)
+
+        attrs: dict[OTelAttribute, str | int | None] = {
+            OTelAttribute.TENANT_ID: self.tenant_id,
+            OTelAttribute.WORKER_ID: self.worker_id,
+            OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id,
+            OTelAttribute.STEP_ID: self.step_id,
+            OTelAttribute.STEP_RUN_ID: self.step_run_id,
+            OTelAttribute.RETRY_COUNT: self.retry_count,
+            OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id,
+            OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index,
+            OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key,
+            OTelAttribute.ACTION_PAYLOAD: payload_str,
+            OTelAttribute.WORKFLOW_NAME: self.job_name,
+            OTelAttribute.ACTION_NAME: self.action_id,
+            OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id,
+            OTelAttribute.WORKFLOW_ID: self.workflow_id,
+            OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id,
+        }
+
+        return {
+            f"hatchet.{k.value}": v
+            for k, v in attrs.items()
+            if v and k not in config.otel.excluded_attributes
+        }
+
+    @property
+    def key(self) -> ActionKey:
+        """
+        This key is used to uniquely identify a single step run by its id + retry count.
+        It's used when storing references to a task, a context, etc. in a dictionary so that
+        we can look up those items in the dictionary by a unique key.
+        """
+        if self.action_type == ActionType.START_GET_GROUP_KEY:
+            return f"{self.get_group_key_run_id}/{self.retry_count}"
+        else:
+            return f"{self.step_run_id}/{self.retry_count}"
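The key property above is what the rest of the diff refers to as the ActionKey. A minimal construction with placeholder values shows its shape, one distinct key per attempt, since retry_count is part of it:

    from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType

    action = Action(
        worker_id="w1",
        tenant_id="t1",
        workflow_run_id="wr1",
        get_group_key_run_id="",
        job_id="j1",
        job_name="my-job",
        job_run_id="jr1",
        step_id="s1",
        step_run_id="sr1",
        action_id="ns:my-task",
        action_type=ActionType.START_STEP_RUN,
        retry_count=1,
        action_payload=ActionPayload(),
    )
    assert action.key == "sr1/1"  # step run id + retry count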
@@ -2,11 +2,16 @@ import asyncio
 from collections import Counter
 from contextvars import ContextVar
 
+from hatchet_sdk.runnables.action import ActionKey
+
 ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
     "ctx_workflow_run_id", default=None
 )
+ctx_action_key: ContextVar[ActionKey | None] = ContextVar(
+    "ctx_action_key", default=None
+)
 ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
 ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
 
-workflow_spawn_indices = Counter[str]()
+workflow_spawn_indices = Counter[ActionKey]()
 spawn_index_lock = asyncio.Lock()
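Why a ContextVar rather than a plain global: each concurrently running task sees only the key it set, so code deep inside a task (such as the AdminClient above) can recover the calling attempt's key without threading it through every call. A small self-contained sketch of that isolation:

    import asyncio
    from contextvars import ContextVar

    ctx_key: ContextVar[str | None] = ContextVar("ctx_key", default=None)

    async def run_step(key: str) -> str:
        ctx_key.set(key)          # set by the worker before running the task
        await asyncio.sleep(0)    # other tasks may run in between
        return ctx_key.get() or "?"

    async def main() -> None:
        # asyncio gives each Task its own context, so keys never cross-contaminate.
        keys = await asyncio.gather(run_step("sr1/0"), run_step("sr2/0"))
        assert keys == ["sr1/0", "sr2/0"]

    asyncio.run(main())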
@@ -10,9 +10,7 @@ import grpc
 
 from hatchet_sdk.client import Client
 from hatchet_sdk.clients.dispatcher.action_listener import (
-    Action,
     ActionListener,
-    ActionType,
     GetActionListenerRequest,
 )
 from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
@@ -23,7 +21,9 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
     STEP_EVENT_TYPE_STARTED,
 )
 from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action, ActionType
 from hatchet_sdk.runnables.contextvars import (
+    ctx_action_key,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
@@ -230,6 +230,7 @@ class WorkerActionListenerProcess:
         ctx_step_run_id.set(action.step_run_id)
         ctx_workflow_run_id.set(action.workflow_run_id)
         ctx_worker_id.set(action.worker_id)
+        ctx_action_key.set(action.key)
 
         # Process the action here
         match action.action_type:
@@ -4,9 +4,9 @@ from multiprocessing import Queue
 from typing import Any, Literal, TypeVar
 
 from hatchet_sdk.client import Client
-from hatchet_sdk.clients.dispatcher.action_listener import Action
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.worker.action_listener_process import ActionEvent
 from hatchet_sdk.worker.runner.runner import Runner
@@ -14,7 +14,6 @@ from pydantic import BaseModel
 
 from hatchet_sdk.client import Client
 from hatchet_sdk.clients.admin import AdminClient
-from hatchet_sdk.clients.dispatcher.action_listener import Action, ActionKey, ActionType
 from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.clients.listeners.durable_event_listener import DurableEventListener
@@ -34,7 +33,9 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
 from hatchet_sdk.exceptions import NonRetryableException
 from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
 from hatchet_sdk.runnables.contextvars import (
+    ctx_action_key,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
@@ -244,6 +245,7 @@ class Runner:
         ctx_step_run_id.set(action.step_run_id)
         ctx_workflow_run_id.set(action.workflow_run_id)
         ctx_worker_id.set(action.worker_id)
+        ctx_action_key.set(action.key)
 
         try:
             if task.is_async_function:
@@ -388,9 +390,9 @@ class Runner:
 
         ## Once the step run completes, we need to remove the workflow spawn index
         ## so we don't leak memory
-        if action.workflow_run_id in workflow_spawn_indices:
+        if action.key in workflow_spawn_indices:
             async with spawn_index_lock:
-                workflow_spawn_indices.pop(action.workflow_run_id)
+                workflow_spawn_indices.pop(action.key)
 
     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
     async def handle_start_group_key_run(self, action: Action) -> Exception | None:
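Since counter entries are now created per attempt rather than per run, cleanup must also happen per attempt: every action.key that incremented the counter is popped when that attempt finishes. A toy version of the invariant (illustrative only, omitting the lock):

    from collections import Counter

    workflow_spawn_indices = Counter[str]()

    def run_attempt(action_key: str, children: int) -> None:
        for _ in range(children):
            workflow_spawn_indices[action_key] += 1  # one increment per spawned child
        # On completion, the attempt reclaims its own entry so the Counter
        # cannot grow without bound across retries.
        workflow_spawn_indices.pop(action_key, None)

    run_attempt("sr1/0", children=3)
    run_attempt("sr1/1", children=3)
    assert len(workflow_spawn_indices) == 0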
@@ -21,10 +21,10 @@ from prometheus_client import Gauge, generate_latest
 from pydantic import BaseModel
 
 from hatchet_sdk.client import Client
-from hatchet_sdk.clients.dispatcher.action_listener import Action
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.contracts.v1.workflows_pb2 import CreateWorkflowVersionRequest
 from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.runnables.workflow import BaseWorkflow
 from hatchet_sdk.worker.action_listener_process import (
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hatchet-sdk"
-version = "1.10.1"
+version = "1.10.2"
 description = ""
 authors = ["Alexander Belanger <alexander@hatchet.run>"]
 readme = "README.md"
@@ -0,0 +1,42 @@
+from subprocess import Popen
+from typing import Any
+from uuid import uuid4
+
+import pytest
+
+from hatchet_sdk import Hatchet, TriggerWorkflowOptions
+from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
+from tests.child_spawn_cache_on_retry.worker import (
+    spawn_cache_on_retry_child,
+    spawn_cache_on_retry_parent,
+)
+
+
+@pytest.mark.parametrize(
+    "on_demand_worker",
+    [(["poetry", "run", "python", "tests/worker.py", "--slots", "5"], 8005)],
+    indirect=True,
+)
+@pytest.mark.asyncio(loop_scope="session")
+async def test_spawn_caching_on_retry(
+    hatchet: Hatchet, on_demand_worker: Popen[Any]
+) -> None:
+    test_run_id = str(uuid4())
+    try:
+        await spawn_cache_on_retry_parent.aio_run(
+            options=TriggerWorkflowOptions(
+                additional_metadata={"test_run_id": test_run_id}
+            )
+        )
+    except Exception as e:
+        assert "Task exceeded timeout of" in str(e)
+
+    runs = await spawn_cache_on_retry_child.aio_list_runs(
+        additional_metadata={"test_run_id": test_run_id}
+    )
+
+    assert len(runs) == 1
+
+    run = runs[0]
+
+    assert run.status == V1TaskStatus.COMPLETED
@@ -0,0 +1,30 @@
+import asyncio
+from datetime import timedelta
+
+from hatchet_sdk import Context, EmptyModel, Hatchet, TriggerWorkflowOptions
+
+PARENT_EXECUTION_TIMEOUT_SECONDS = 5
+PARENT_RETRIES = 2
+
+
+hatchet = Hatchet(debug=True)
+
+
+@hatchet.task(
+    execution_timeout=timedelta(seconds=PARENT_EXECUTION_TIMEOUT_SECONDS),
+    retries=PARENT_RETRIES,
+)
+async def spawn_cache_on_retry_parent(input: EmptyModel, ctx: Context) -> None:
+    await spawn_cache_on_retry_child.aio_run_no_wait(
+        options=TriggerWorkflowOptions(
+            additional_metadata=ctx.additional_metadata or {}
+        )
+    )
+
+    for _ in range(60):
+        await asyncio.sleep(1)
+
+
+@hatchet.task()
+async def spawn_cache_on_retry_child(input: EmptyModel, ctx: Context) -> None:
+    await asyncio.sleep(10)
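The new test is wired so the fix is observable end to end (timings per the worker file above): the parent spawns one child, then sleeps past its 5-second execution timeout, so Hatchet retries it twice; with per-attempt spawn indices, each attempt re-requests child index 0 and hits the spawn cache. A hedged simulation of that expectation:

    child_cache: dict[tuple[str, int], str] = {}

    for retry_count in range(3):              # original attempt + PARENT_RETRIES = 2
        spawn_indices: dict[str, int] = {}    # reset per attempt under the fix
        action_key = f"parent-step/{retry_count}"
        index = spawn_indices.get(action_key, 0)           # always 0 at attempt start
        child_cache.setdefault(("parent-run", index), "child-run")
        spawn_indices[action_key] = index + 1

    assert len(child_cache) == 1  # matches the test's `assert len(runs) == 1`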
@@ -2,6 +2,10 @@ import argparse
 from typing import cast
 
 from hatchet_sdk import Hatchet
+from tests.child_spawn_cache_on_retry.worker import (
+    spawn_cache_on_retry_child,
+    spawn_cache_on_retry_parent,
+)
 from tests.correct_failure_on_timeout_with_multi_concurrency.workflow import (
     multiple_concurrent_cancellations_test_workflow,
 )
@@ -13,7 +17,11 @@ def main(slots: int) -> None:
     worker = hatchet.worker(
         "e2e-test-worker-2",
         slots=slots,
-        workflows=[multiple_concurrent_cancellations_test_workflow],
+        workflows=[
+            multiple_concurrent_cancellations_test_workflow,
+            spawn_cache_on_retry_parent,
+            spawn_cache_on_retry_child,
+        ],
     )
 
     worker.start()