Fix: Spawn index incorrect on retry (#1752)

* fix: child index on retry not being reset correctly

* chore: ver

* chore: changelog

* feat: test
Matt Kaye
2025-05-19 15:38:52 -04:00
committed by GitHub
parent 80a4757a5d
commit e4b1d36de6
15 changed files with 239 additions and 134 deletions
+6
@@ -5,6 +5,12 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [1.10.2] - 2025-05-19
+### Changed
+- Fixed an issue where the spawn index was tracked at the `workflow_run_id` level rather than the `(workflow_run_id, retry_count)` level, which caused child workflows to be spawned multiple times on retry.
## [1.10.1] - 2025-05-16
### Added
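A minimal sketch (Python, not SDK code) of the failure mode the changelog entry describes. The dedupe model — children identified by a (parent run id, spawn index) pair — is an assumption for illustration:

from collections import Counter

children: set[tuple[str, int]] = set()
spawn_indices: Counter[str] = Counter()

def spawn_child(counter_key: str, parent_run_id: str) -> None:
    # Hand out the current index for this key, then advance it.
    index = spawn_indices[counter_key]
    spawn_indices[counter_key] += 1
    # Assumed dedupe model: repeated (parent, index) pairs are no-ops.
    children.add((parent_run_id, index))

parent = "wf-run-abc"

# Old keying: the counter, keyed by the parent run id alone, survives the
# retry, so attempt 2 requests index 1 and a duplicate child appears.
for retry_count in (0, 1):
    spawn_child(parent, parent)
assert len(children) == 2  # the bug

# New keying: the counter restarts per (run, retry), so the retry re-requests
# index 0 and the pair is deduplicated.
children.clear()
spawn_indices.clear()
for retry_count in (0, 1):
    spawn_child(f"{parent}/{retry_count}", parent)
assert len(children) == 1  # the fix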
+5 -3
@@ -20,6 +20,7 @@ from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.metadata import get_metadata
from hatchet_sdk.rate_limit import RateLimitDuration
from hatchet_sdk.runnables.contextvars import (
+ctx_action_key,
ctx_step_run_id,
ctx_worker_id,
ctx_workflow_run_id,
@@ -281,11 +282,12 @@ class AdminClient:
workflow_run_id = ctx_workflow_run_id.get()
step_run_id = ctx_step_run_id.get()
worker_id = ctx_worker_id.get()
-spawn_index = workflow_spawn_indices[workflow_run_id] if workflow_run_id else 0
+action_key = ctx_action_key.get()
+spawn_index = workflow_spawn_indices[action_key] if action_key else 0
## Increment the spawn_index for the parent workflow
-if workflow_run_id:
-workflow_spawn_indices[workflow_run_id] += 1
+if action_key:
+workflow_spawn_indices[action_key] += 1
desired_worker_id = (
(options.desired_worker_id or worker_id) if options.sticky else None
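To see how the action key reaches this spawn path without a new parameter, here is a rough, self-contained sketch of the ContextVar flow. The variable names mirror the SDK's contextvars module, but the task body and runner stand-ins are illustrative only:

import asyncio
from collections import Counter
from contextvars import ContextVar

ctx_action_key: ContextVar[str | None] = ContextVar("ctx_action_key", default=None)
workflow_spawn_indices = Counter[str]()

def next_spawn_index() -> int:
    # Mirrors the AdminClient logic above: read the key the runner set,
    # hand out the current index, then advance it.
    action_key = ctx_action_key.get()
    if action_key is None:
        return 0
    index = workflow_spawn_indices[action_key]
    workflow_spawn_indices[action_key] += 1
    return index

async def task_body() -> None:
    # Deep inside user code: no key is passed explicitly.
    assert (next_spawn_index(), next_spawn_index()) == (0, 1)

async def run_attempt(step_run_id: str, retry_count: int) -> None:
    # Stand-in for the runner, which sets the key before invoking the task.
    ctx_action_key.set(f"{step_run_id}/{retry_count}")
    await task_body()

asyncio.run(run_attempt("step-run-xyz", 0))
asyncio.run(run_attempt("step-run-xyz", 1))  # fresh key, so the index restarts at 0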
@@ -1,13 +1,11 @@
import asyncio
import json
import time
-from dataclasses import field
-from enum import Enum
-from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
+from typing import TYPE_CHECKING, AsyncGenerator, cast
import grpc
import grpc.aio
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator
from hatchet_sdk.clients.event_ts import (
ThreadSafeEvent,
@@ -30,8 +28,8 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
from hatchet_sdk.logger import logger
from hatchet_sdk.metadata import get_metadata
+from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType
from hatchet_sdk.utils.backoff import exp_backoff_sleep
-from hatchet_sdk.utils.opentelemetry import OTelAttribute
from hatchet_sdk.utils.proto_enums import convert_proto_enum_to_python
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -67,120 +65,6 @@ class GetActionListenerRequest(BaseModel):
return self
class ActionPayload(BaseModel):
model_config = ConfigDict(extra="allow")
input: JSONSerializableMapping = Field(default_factory=dict)
parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict)
overrides: JSONSerializableMapping = Field(default_factory=dict)
user_data: JSONSerializableMapping = Field(default_factory=dict)
step_run_errors: dict[str, str] = Field(default_factory=dict)
triggered_by: str | None = None
triggers: JSONSerializableMapping = Field(default_factory=dict)
filter_payload: JSONSerializableMapping = Field(default_factory=dict)
@field_validator(
"input",
"parents",
"overrides",
"user_data",
"step_run_errors",
"filter_payload",
mode="before",
)
@classmethod
def validate_fields(cls, v: Any) -> Any:
return v or {}
@model_validator(mode="after")
def validate_filter_payload(self) -> "ActionPayload":
self.filter_payload = self.triggers.get("filter_payload", {})
return self
class ActionType(str, Enum):
START_STEP_RUN = "START_STEP_RUN"
CANCEL_STEP_RUN = "CANCEL_STEP_RUN"
START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
ActionKey = str
class Action(BaseModel):
worker_id: str
tenant_id: str
workflow_run_id: str
workflow_id: str | None = None
workflow_version_id: str | None = None
get_group_key_run_id: str
job_id: str
job_name: str
job_run_id: str
step_id: str
step_run_id: str
action_id: str
action_type: ActionType
retry_count: int
action_payload: ActionPayload
additional_metadata: JSONSerializableMapping = field(default_factory=dict)
child_workflow_index: int | None = None
child_workflow_key: str | None = None
parent_workflow_run_id: str | None = None
priority: int | None = None
def _dump_payload_to_str(self) -> str:
try:
return json.dumps(self.action_payload.model_dump(), default=str)
except Exception:
return str(self.action_payload)
def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]:
try:
payload_str = json.dumps(self.action_payload.model_dump(), default=str)
except Exception:
payload_str = str(self.action_payload)
attrs: dict[OTelAttribute, str | int | None] = {
OTelAttribute.TENANT_ID: self.tenant_id,
OTelAttribute.WORKER_ID: self.worker_id,
OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id,
OTelAttribute.STEP_ID: self.step_id,
OTelAttribute.STEP_RUN_ID: self.step_run_id,
OTelAttribute.RETRY_COUNT: self.retry_count,
OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id,
OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index,
OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key,
OTelAttribute.ACTION_PAYLOAD: payload_str,
OTelAttribute.WORKFLOW_NAME: self.job_name,
OTelAttribute.ACTION_NAME: self.action_id,
OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id,
OTelAttribute.WORKFLOW_ID: self.workflow_id,
OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id,
}
return {
f"hatchet.{k.value}": v
for k, v in attrs.items()
if v and k not in config.otel.excluded_attributes
}
@property
def key(self) -> ActionKey:
"""
This key is used to uniquely identify a single step run by its id + retry count.
It's used when storing references to a task, a context, etc. in a dictionary so that
we can look up those items in the dictionary by a unique key.
"""
if self.action_type == ActionType.START_GET_GROUP_KEY:
return f"{self.get_group_key_run_id}/{self.retry_count}"
else:
return f"{self.step_run_id}/{self.retry_count}"
def parse_additional_metadata(additional_metadata: str) -> JSONSerializableMapping:
try:
return cast(
@@ -4,7 +4,6 @@ import grpc.aio
from google.protobuf.timestamp_pb2 import Timestamp
from hatchet_sdk.clients.dispatcher.action_listener import (
-Action,
ActionListener,
GetActionListenerRequest,
)
@@ -29,6 +28,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
)
from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
from hatchet_sdk.metadata import get_metadata
+from hatchet_sdk.runnables.action import Action
DEFAULT_REGISTER_TIMEOUT = 30
@@ -33,13 +33,13 @@ from hatchet_sdk.clients.admin import (
TriggerWorkflowOptions,
WorkflowRunTriggerConfig,
)
-from hatchet_sdk.clients.dispatcher.action_listener import Action
from hatchet_sdk.clients.events import (
BulkPushEventWithMetadata,
EventClient,
PushEventOptions,
)
from hatchet_sdk.contracts.events_pb2 import Event
+from hatchet_sdk.runnables.action import Action
from hatchet_sdk.worker.runner.runner import Runner
from hatchet_sdk.workflow_run import WorkflowRunRef
+125
@@ -0,0 +1,125 @@
import json
from dataclasses import field
from enum import Enum
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from hatchet_sdk.utils.opentelemetry import OTelAttribute
from hatchet_sdk.utils.typing import JSONSerializableMapping
if TYPE_CHECKING:
from hatchet_sdk.config import ClientConfig
ActionKey = str
class ActionPayload(BaseModel):
model_config = ConfigDict(extra="allow")
input: JSONSerializableMapping = Field(default_factory=dict)
parents: dict[str, JSONSerializableMapping] = Field(default_factory=dict)
overrides: JSONSerializableMapping = Field(default_factory=dict)
user_data: JSONSerializableMapping = Field(default_factory=dict)
step_run_errors: dict[str, str] = Field(default_factory=dict)
triggered_by: str | None = None
triggers: JSONSerializableMapping = Field(default_factory=dict)
filter_payload: JSONSerializableMapping = Field(default_factory=dict)
@field_validator(
"input",
"parents",
"overrides",
"user_data",
"step_run_errors",
"filter_payload",
mode="before",
)
@classmethod
def validate_fields(cls, v: Any) -> Any:
return v or {}
@model_validator(mode="after")
def validate_filter_payload(self) -> "ActionPayload":
self.filter_payload = self.triggers.get("filter_payload", {})
return self
class ActionType(str, Enum):
START_STEP_RUN = "START_STEP_RUN"
CANCEL_STEP_RUN = "CANCEL_STEP_RUN"
START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
class Action(BaseModel):
worker_id: str
tenant_id: str
workflow_run_id: str
workflow_id: str | None = None
workflow_version_id: str | None = None
get_group_key_run_id: str
job_id: str
job_name: str
job_run_id: str
step_id: str
step_run_id: str
action_id: str
action_type: ActionType
retry_count: int
action_payload: ActionPayload
additional_metadata: JSONSerializableMapping = field(default_factory=dict)
child_workflow_index: int | None = None
child_workflow_key: str | None = None
parent_workflow_run_id: str | None = None
priority: int | None = None
def _dump_payload_to_str(self) -> str:
try:
return json.dumps(self.action_payload.model_dump(), default=str)
except Exception:
return str(self.action_payload)
def get_otel_attributes(self, config: "ClientConfig") -> dict[str, str | int]:
try:
payload_str = json.dumps(self.action_payload.model_dump(), default=str)
except Exception:
payload_str = str(self.action_payload)
attrs: dict[OTelAttribute, str | int | None] = {
OTelAttribute.TENANT_ID: self.tenant_id,
OTelAttribute.WORKER_ID: self.worker_id,
OTelAttribute.WORKFLOW_RUN_ID: self.workflow_run_id,
OTelAttribute.STEP_ID: self.step_id,
OTelAttribute.STEP_RUN_ID: self.step_run_id,
OTelAttribute.RETRY_COUNT: self.retry_count,
OTelAttribute.PARENT_WORKFLOW_RUN_ID: self.parent_workflow_run_id,
OTelAttribute.CHILD_WORKFLOW_INDEX: self.child_workflow_index,
OTelAttribute.CHILD_WORKFLOW_KEY: self.child_workflow_key,
OTelAttribute.ACTION_PAYLOAD: payload_str,
OTelAttribute.WORKFLOW_NAME: self.job_name,
OTelAttribute.ACTION_NAME: self.action_id,
OTelAttribute.GET_GROUP_KEY_RUN_ID: self.get_group_key_run_id,
OTelAttribute.WORKFLOW_ID: self.workflow_id,
OTelAttribute.WORKFLOW_VERSION_ID: self.workflow_version_id,
}
return {
f"hatchet.{k.value}": v
for k, v in attrs.items()
if v and k not in config.otel.excluded_attributes
}
@property
def key(self) -> ActionKey:
"""
This key is used to uniquely identify a single step run by its id + retry count.
It's used when storing references to a task, a context, etc. in a dictionary so that
we can look up those items in the dictionary by a unique key.
"""
if self.action_type == ActionType.START_GET_GROUP_KEY:
return f"{self.get_group_key_run_id}/{self.retry_count}"
else:
return f"{self.step_run_id}/{self.retry_count}"
@@ -2,11 +2,16 @@ import asyncio
from collections import Counter
from contextvars import ContextVar
+from hatchet_sdk.runnables.action import ActionKey
ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
"ctx_workflow_run_id", default=None
)
+ctx_action_key: ContextVar[ActionKey | None] = ContextVar(
+"ctx_action_key", default=None
+)
ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
-workflow_spawn_indices = Counter[str]()
+workflow_spawn_indices = Counter[ActionKey]()
spawn_index_lock = asyncio.Lock()
@@ -10,9 +10,7 @@ import grpc
from hatchet_sdk.client import Client
from hatchet_sdk.clients.dispatcher.action_listener import (
-Action,
ActionListener,
-ActionType,
GetActionListenerRequest,
)
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
@@ -23,7 +21,9 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
STEP_EVENT_TYPE_STARTED,
)
from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action, ActionType
from hatchet_sdk.runnables.contextvars import (
+ctx_action_key,
ctx_step_run_id,
ctx_worker_id,
ctx_workflow_run_id,
@@ -230,6 +230,7 @@ class WorkerActionListenerProcess:
ctx_step_run_id.set(action.step_run_id)
ctx_workflow_run_id.set(action.workflow_run_id)
ctx_worker_id.set(action.worker_id)
+ctx_action_key.set(action.key)
# Process the action here
match action.action_type:
@@ -4,9 +4,9 @@ from multiprocessing import Queue
from typing import Any, Literal, TypeVar
from hatchet_sdk.client import Client
-from hatchet_sdk.clients.dispatcher.action_listener import Action
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.worker.action_listener_process import ActionEvent
from hatchet_sdk.worker.runner.runner import Runner
@@ -14,7 +14,6 @@ from pydantic import BaseModel
from hatchet_sdk.client import Client
from hatchet_sdk.clients.admin import AdminClient
-from hatchet_sdk.clients.dispatcher.action_listener import Action, ActionKey, ActionType
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.clients.listeners.durable_event_listener import DurableEventListener
@@ -34,7 +33,9 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
from hatchet_sdk.exceptions import NonRetryableException
from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
from hatchet_sdk.runnables.contextvars import (
+ctx_action_key,
ctx_step_run_id,
ctx_worker_id,
ctx_workflow_run_id,
@@ -244,6 +245,7 @@ class Runner:
ctx_step_run_id.set(action.step_run_id)
ctx_workflow_run_id.set(action.workflow_run_id)
ctx_worker_id.set(action.worker_id)
+ctx_action_key.set(action.key)
try:
if task.is_async_function:
@@ -388,9 +390,9 @@ class Runner:
## Once the step run completes, we need to remove the workflow spawn index
## so we don't leak memory
-if action.workflow_run_id in workflow_spawn_indices:
+if action.key in workflow_spawn_indices:
async with spawn_index_lock:
-workflow_spawn_indices.pop(action.workflow_run_id)
+workflow_spawn_indices.pop(action.key)
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
async def handle_start_group_key_run(self, action: Action) -> Exception | None:
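A small sketch of this cleanup path (simplified; a single worker event loop is assumed): popping the finished attempt's key under the lock keeps the module-level counter from growing across attempts.

import asyncio
from collections import Counter

workflow_spawn_indices = Counter[str]()
spawn_index_lock = asyncio.Lock()

async def on_action_finished(action_key: str) -> None:
    # Mirrors the runner change: the counter is keyed by the action key,
    # so the finished attempt's entry can be dropped exactly.
    if action_key in workflow_spawn_indices:
        async with spawn_index_lock:
            workflow_spawn_indices.pop(action_key)

async def main() -> None:
    workflow_spawn_indices["sr-123/0"] += 3  # the attempt spawned three children
    await on_action_finished("sr-123/0")
    assert "sr-123/0" not in workflow_spawn_indices

asyncio.run(main())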
+1 -1
@@ -21,10 +21,10 @@ from prometheus_client import Gauge, generate_latest
from pydantic import BaseModel
from hatchet_sdk.client import Client
-from hatchet_sdk.clients.dispatcher.action_listener import Action
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.contracts.v1.workflows_pb2 import CreateWorkflowVersionRequest
from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.action import Action
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.workflow import BaseWorkflow
from hatchet_sdk.worker.action_listener_process import (
+1 -1
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "1.10.1"
version = "1.10.2"
description = ""
authors = ["Alexander Belanger <alexander@hatchet.run>"]
readme = "README.md"
@@ -0,0 +1,42 @@
from subprocess import Popen
from typing import Any
from uuid import uuid4
import pytest
from hatchet_sdk import Hatchet, TriggerWorkflowOptions
from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
from tests.child_spawn_cache_on_retry.worker import (
spawn_cache_on_retry_child,
spawn_cache_on_retry_parent,
)
@pytest.mark.parametrize(
"on_demand_worker",
[(["poetry", "run", "python", "tests/worker.py", "--slots", "5"], 8005)],
indirect=True,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_spawn_caching_on_retry(
hatchet: Hatchet, on_demand_worker: Popen[Any]
) -> None:
test_run_id = str(uuid4())
try:
await spawn_cache_on_retry_parent.aio_run(
options=TriggerWorkflowOptions(
additional_metadata={"test_run_id": test_run_id}
)
)
except Exception as e:
assert "Task exceeded timeout of" in str(e)
runs = await spawn_cache_on_retry_child.aio_list_runs(
additional_metadata={"test_run_id": test_run_id}
)
assert len(runs) == 1
run = runs[0]
assert run.status == V1TaskStatus.COMPLETED
@@ -0,0 +1,30 @@
import asyncio
from datetime import timedelta
from hatchet_sdk import Context, EmptyModel, Hatchet, TriggerWorkflowOptions
PARENT_EXECUTION_TIMEOUT_SECONDS = 5
PARENT_RETRIES = 2
hatchet = Hatchet(debug=True)
@hatchet.task(
execution_timeout=timedelta(seconds=PARENT_EXECUTION_TIMEOUT_SECONDS),
retries=PARENT_RETRIES,
)
async def spawn_cache_on_retry_parent(input: EmptyModel, ctx: Context) -> None:
await spawn_cache_on_retry_child.aio_run_no_wait(
options=TriggerWorkflowOptions(
additional_metadata=ctx.additional_metadata or {}
)
)
for _ in range(60):
await asyncio.sleep(1)
@hatchet.task()
async def spawn_cache_on_retry_child(input: EmptyModel, ctx: Context) -> None:
await asyncio.sleep(10)
+9 -1
@@ -2,6 +2,10 @@ import argparse
from typing import cast
from hatchet_sdk import Hatchet
+from tests.child_spawn_cache_on_retry.worker import (
+spawn_cache_on_retry_child,
+spawn_cache_on_retry_parent,
+)
from tests.correct_failure_on_timeout_with_multi_concurrency.workflow import (
multiple_concurrent_cancellations_test_workflow,
)
@@ -13,7 +17,11 @@ def main(slots: int) -> None:
worker = hatchet.worker(
"e2e-test-worker-2",
slots=slots,
-workflows=[multiple_concurrent_cancellations_test_workflow],
+workflows=[
+multiple_concurrent_cancellations_test_workflow,
+spawn_cache_on_retry_parent,
+spawn_cache_on_retry_child,
+],
)
worker.start()