hatchet/sdks/python/hatchet_sdk/worker/runner/runner.py
matt cebf3a6fa7 [Python] Fix: Raise on un-serializable task output (#2562)
* fix: raise on illegal output

* chore: version, changelog
2025-12-08 09:49:20 -05:00


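"""Runner for executing task runs dispatched to a Hatchet worker.

The Runner turns incoming actions into asyncio tasks (running synchronous
task functions on a thread pool), reports step lifecycle events back over
the worker's event queue, and handles cancellation and output serialization.
"""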
import asyncio
import ctypes
import functools
import json
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, is_dataclass
from enum import Enum
from multiprocessing import Queue
from textwrap import dedent
from threading import Thread, current_thread
from typing import Any, Literal, cast, overload

from pydantic import BaseModel

from hatchet_sdk.client import Client
from hatchet_sdk.clients.admin import AdminClient
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.clients.listeners.durable_event_listener import DurableEventListener
from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.context.context import Context, DurableContext
from hatchet_sdk.context.worker_context import WorkerContext
from hatchet_sdk.contracts.dispatcher_pb2 import (
STEP_EVENT_TYPE_COMPLETED,
STEP_EVENT_TYPE_FAILED,
STEP_EVENT_TYPE_STARTED,
)
from hatchet_sdk.exceptions import (
IllegalTaskOutputError,
NonRetryableException,
TaskRunError,
)
from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.logger import logger
from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
from hatchet_sdk.runnables.contextvars import (
ctx_action_key,
ctx_additional_metadata,
ctx_step_run_id,
ctx_worker_id,
ctx_workflow_run_id,
spawn_index_lock,
task_count,
workflow_spawn_indices,
)
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import R, TWorkflowInput
from hatchet_sdk.utils.serde import remove_null_unicode_character
from hatchet_sdk.utils.typing import DataclassInstance
from hatchet_sdk.worker.action_listener_process import ActionEvent
from hatchet_sdk.worker.runner.utils.capture_logs import (
AsyncLogSender,
ContextVarToCopy,
ContextVarToCopyDict,
ContextVarToCopyStr,
copy_context_vars,
)


class WorkerStatus(Enum):
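    """Health states a worker can be in."""
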
INITIALIZED = 1
STARTING = 2
HEALTHY = 3
UNHEALTHY = 4


class Runner:
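    """Executes actions for registered tasks and reports step events back to the main worker process."""
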
def __init__(
self,
event_queue: "Queue[ActionEvent]",
config: ClientConfig,
slots: int,
handle_kill: bool,
action_registry: dict[str, Task[TWorkflowInput, R]],
labels: dict[str, str | int] | None,
lifespan_context: Any | None,
log_sender: AsyncLogSender,
):
        # We store the config so we can dynamically create clients (e.g. the dispatcher client) later on.
self.config = config
self.slots = slots
self.tasks: dict[ActionKey, asyncio.Task[Any]] = {} # Store run ids and futures
self.contexts: dict[ActionKey, Context] = {} # Store run ids and contexts
self.action_registry = action_registry or {}
self.event_queue = event_queue
# The thread pool is used for synchronous functions which need to run concurrently
self.thread_pool = ThreadPoolExecutor(max_workers=slots)
self.threads: dict[ActionKey, Thread] = {} # Store run ids and threads
self.running_tasks = set[asyncio.Task[Exception | None]]()
self.killing = False
self.handle_kill = handle_kill
self.dispatcher_client = DispatcherClient(self.config)
self.workflow_run_event_listener = RunEventListenerClient(self.config)
self.workflow_listener = PooledWorkflowRunListener(self.config)
self.runs_client = RunsClient(
config=self.config,
workflow_run_event_listener=self.workflow_run_event_listener,
workflow_run_listener=self.workflow_listener,
)
self.admin_client = AdminClient(
self.config,
self.workflow_listener,
self.workflow_run_event_listener,
self.runs_client,
)
self.event_client = EventClient(self.config)
self.durable_event_listener = DurableEventListener(self.config)
self.worker_context = WorkerContext(
labels=labels or {}, client=Client(config=config).dispatcher
)
self.lifespan_context = lifespan_context
self.log_sender = log_sender
if self.config.enable_thread_pool_monitoring:
self.start_background_monitoring()

    def create_workflow_run_url(self, action: Action) -> str:
return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"

    def run(self, action: Action) -> None:
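        """Dispatch an incoming action to the start or cancel handler and track the resulting asyncio task."""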
if self.worker_context.id() is None:
self.worker_context._worker_id = action.worker_id
t: asyncio.Task[Exception | None] | None = None
match action.action_type:
case ActionType.START_STEP_RUN:
log = f"run: start step: {action.action_id}/{action.step_run_id}"
logger.info(log)
t = asyncio.create_task(self.handle_start_step_run(action))
case ActionType.CANCEL_STEP_RUN:
log = f"cancel: step run: {action.action_id}/{action.step_run_id}/{action.retry_count}"
logger.info(log)
t = asyncio.create_task(self.handle_cancel_action(action))
case _:
log = f"unknown action type: {action.action_type}"
logger.error(log)
if t is not None:
self.running_tasks.add(t)
t.add_done_callback(lambda task: self.running_tasks.discard(task))

    def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
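        """Build the done-callback attached to a step-run task.

        The callback cleans up run bookkeeping, then pushes a COMPLETED event
        with the serialized output, or a FAILED event with the serialized
        error, onto the event queue.
        """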
def inner_callback(task: asyncio.Task[Any]) -> None:
self.cleanup_run_id(action.key)
if task.cancelled():
return
try:
output = task.result()
except Exception as e:
should_not_retry = isinstance(e, NonRetryableException)
exc = TaskRunError.from_exception(e, action.step_run_id)
                # This exception came from application code, so we report the failure back to the Hatchet engine
self.event_queue.put(
ActionEvent(
action=action,
type=STEP_EVENT_TYPE_FAILED,
payload=exc.serialize(include_metadata=True),
should_not_retry=should_not_retry,
)
)
log_with_level = logger.info if should_not_retry else logger.exception
log_with_level(
f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize(include_metadata=False)}"
)
return
try:
output = self.serialize_output(output)
self.event_queue.put(
ActionEvent(
action=action,
type=STEP_EVENT_TYPE_COMPLETED,
payload=output,
should_not_retry=False,
)
)
except IllegalTaskOutputError as e:
exc = TaskRunError.from_exception(e, action.step_run_id)
self.event_queue.put(
ActionEvent(
action=action,
type=STEP_EVENT_TYPE_FAILED,
payload=exc.serialize(include_metadata=True),
should_not_retry=False,
)
)
logger.exception(
f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize(include_metadata=False)}"
)
return
logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
return inner_callback

    def thread_action_func(
self,
ctx: Context,
task: Task[TWorkflowInput, R],
action: Action,
dependencies: dict[str, Any],
) -> R:
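        """Entry point for sync tasks on the thread pool; records the current thread so cancellation can force-kill it."""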
if action.step_run_id:
self.threads[action.key] = current_thread()
return task.call(ctx, dependencies)

    # We wrap every action in an async function so sync and async tasks share a single execution path
async def async_wrapped_action_func(
self,
ctx: Context,
task: Task[TWorkflowInput, R],
action: Action,
) -> R:
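        """Execute a task with run-scoped contextvars set: async tasks are awaited directly, sync tasks run on the thread pool."""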
ctx_step_run_id.set(action.step_run_id)
ctx_workflow_run_id.set(action.workflow_run_id)
ctx_worker_id.set(action.worker_id)
ctx_action_key.set(action.key)
ctx_additional_metadata.set(action.additional_metadata)
dependencies = await task._unpack_dependencies(ctx)
try:
if task.is_async_function:
return await task.aio_call(ctx, dependencies)
pfunc = functools.partial(
                # explicitly copy contextvars into the worker thread: asyncio propagates
                # them to tasks automatically, but plain threads do not
copy_context_vars,
[
ContextVarToCopy(
var=ContextVarToCopyStr(
name="ctx_step_run_id",
value=action.step_run_id,
)
),
ContextVarToCopy(
var=ContextVarToCopyStr(
name="ctx_workflow_run_id",
value=action.workflow_run_id,
)
),
ContextVarToCopy(
var=ContextVarToCopyStr(
name="ctx_worker_id",
value=action.worker_id,
)
),
ContextVarToCopy(
var=ContextVarToCopyStr(
name="ctx_action_key",
value=action.key,
)
),
ContextVarToCopy(
var=ContextVarToCopyDict(
name="ctx_additional_metadata",
value=action.additional_metadata,
)
),
],
self.thread_action_func,
ctx,
task,
action,
dependencies,
)
            loop = asyncio.get_running_loop()
return await loop.run_in_executor(self.thread_pool, pfunc)
finally:
self.cleanup_run_id(action.key)

    async def log_thread_pool_status(self) -> None:
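        """Log a detailed snapshot of the thread pool for debugging.

        NOTE: this reads private ThreadPoolExecutor internals (_threads,
        _idle_semaphore, _work_queue), which may change across Python versions.
        """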
thread_pool_details = {
"max_workers": self.slots,
"total_threads": len(self.thread_pool._threads),
"idle_threads": self.thread_pool._idle_semaphore._value,
"active_threads": len(self.threads),
"pending_tasks": len(self.tasks),
"queue_size": self.thread_pool._work_queue.qsize(),
"threads_alive": sum(1 for t in self.thread_pool._threads if t.is_alive()),
"threads_daemon": sum(1 for t in self.thread_pool._threads if t.daemon),
}
logger.warning("thread pool detailed status %s", thread_pool_details)

    async def _start_monitoring(self) -> None:
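        """Log thread pool status once a minute and flag suspected zombie threads and finished-but-still-tracked tasks."""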
logger.debug("thread pool monitoring started")
try:
while True:
await self.log_thread_pool_status()
for key in self.threads:
if key not in self.tasks:
logger.debug(f"potential zombie thread found for key {key}")
for key, task in self.tasks.items():
if task.done() and key in self.threads:
logger.debug(
f"task is done but thread still exists for key {key}"
)
await asyncio.sleep(60)
except asyncio.CancelledError:
logger.warning("thread pool monitoring task cancelled")
except Exception as e:
logger.exception(f"error in thread pool monitoring: {e}")

    def start_background_monitoring(self) -> None:
loop = asyncio.get_event_loop()
self.monitoring_task = loop.create_task(self._start_monitoring())
logger.debug("started thread pool monitoring background task")

    def cleanup_run_id(self, key: ActionKey) -> None:
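        """Drop the task, thread, and context bookkeeping for the given run key, if present."""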
if key in self.tasks:
del self.tasks[key]
if key in self.threads:
del self.threads[key]
if key in self.contexts:
del self.contexts[key]

    @overload
def create_context(
self, action: Action, is_durable: Literal[True] = True
) -> DurableContext: ...

    @overload
def create_context(
self, action: Action, is_durable: Literal[False] = False
) -> Context: ...

    def create_context(
self, action: Action, is_durable: bool = True
) -> Context | DurableContext:
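        """Build a Context (or DurableContext when is_durable is True) wired to this runner's clients."""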
constructor = DurableContext if is_durable else Context
return constructor(
action=action,
dispatcher_client=self.dispatcher_client,
admin_client=self.admin_client,
event_client=self.event_client,
durable_event_listener=self.durable_event_listener,
worker=self.worker_context,
runs_client=self.runs_client,
lifespan_context=self.lifespan_context,
log_sender=self.log_sender,
)

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
async def handle_start_step_run(self, action: Action) -> Exception | None:
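        """Look up the task for an incoming action, emit STARTED, and await its execution.

        Returns the exception raised by the task, if any, so the OTel
        instrumentor wrapper can record it; results and failures are reported
        via the step-run done callback.
        """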
action_name = action.action_id
# Find the corresponding action function from the registry
action_func = self.action_registry.get(action_name)
if action_func:
context = self.create_context(
action,
True if action_func.is_durable else False, # noqa: SIM210
)
self.contexts[action.key] = context
self.event_queue.put(
ActionEvent(
action=action,
type=STEP_EVENT_TYPE_STARTED,
payload="",
should_not_retry=False,
)
)
            loop = asyncio.get_running_loop()
task = loop.create_task(
self.async_wrapped_action_func(context, action_func, action)
)
task.add_done_callback(self.step_run_callback(action))
self.tasks[action.key] = task
task_count.increment()
## FIXME: Handle cancelled exceptions and other special exceptions
## that we don't want to suppress here
try:
await task
except Exception as e:
## Used for the OTel instrumentor to capture exceptions
return e
## Once the step run completes, we need to remove the workflow spawn index
## so we don't leak memory
if action.key in workflow_spawn_indices:
async with spawn_index_lock:
workflow_spawn_indices.pop(action.key)
return None

    def force_kill_thread(self, thread: Thread) -> None:
"""Terminate a python threading.Thread."""
try:
if not thread.is_alive():
return
ident = cast(int, thread.ident)
logger.info(f"forcefully terminating thread {ident}")
exc = ctypes.py_object(SystemExit)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
if res == 0:
raise ValueError("Invalid thread ID")
if res != 1:
logger.error("PyThreadState_SetAsyncExc failed")
                # a second call with the exception set to NULL (0) is needed to clean up properly
                ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), 0)
                raise SystemError("PyThreadState_SetAsyncExc failed")
logger.info(f"successfully terminated thread {ident}")
# Immediately add a new thread to the thread pool, because we've actually killed a worker
# in the ThreadPoolExecutor
self.thread_pool.submit(lambda: None)
except Exception as e:
logger.exception(f"failed to terminate thread: {e}")

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
async def handle_cancel_action(self, action: Action) -> None:
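        """Cancel a running step.

        Sets the context's cancellation flag, cancels the asyncio task after a
        grace period, and (if configured) force-kills a lingering sync thread
        before warning that it may still be running.
        """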
key = action.key
try:
# call cancel to signal the context to stop
if key in self.contexts:
self.contexts[key]._set_cancellation_flag()
await asyncio.sleep(1)
if key in self.tasks:
self.tasks[key].cancel()
            # if a sync thread is still tracked for this run, optionally force-kill it, then warn that it may still be running
if key in self.threads:
thread = self.threads[key]
if self.config.enable_force_kill_sync_threads:
self.force_kill_thread(thread)
await asyncio.sleep(1)
logger.warning(
f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
)
finally:
self.cleanup_run_id(key)

    def serialize_output(self, output: Any) -> str:
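        """Serialize a task's output to a JSON string.

        Accepts a dict, a Pydantic BaseModel, or a dataclass; anything else
        raises IllegalTaskOutputError, as does output containing the escaped
        null character \\u0000, which Postgres JSON columns reject.
        """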
if not output:
return ""
if isinstance(output, BaseModel):
try:
output = output.model_dump(mode="json")
except Exception as e:
logger.exception("could not serialize pydantic model output")
raise IllegalTaskOutputError(
f"could not serialize Pydantic BaseModel output: {e}"
) from e
elif is_dataclass(output):
output = asdict(cast(DataclassInstance, output))
        if not isinstance(output, dict):
            raise IllegalTaskOutputError(
                f"Tasks must return either a dictionary, a Pydantic BaseModel, or a dataclass which can be serialized to a JSON object. Got object of type {type(output)} instead."
            )
try:
serialized_output = json.dumps(output, default=str)
except Exception as e:
logger.exception("could not serialize output")
raise IllegalTaskOutputError(
"Task output could not be serialized to JSON. Please ensure that all task outputs are JSON serializable."
) from e
if "\\u0000" in serialized_output:
raise IllegalTaskOutputError(
dedent(
f"""
Task outputs cannot contain the unicode null character \\u0000
Please see this Discord thread: https://discord.com/channels/1088927970518909068/1384324576166678710/1386714014565928992
Relevant Postgres documentation: https://www.postgresql.org/docs/current/datatype-json.html
Use `hatchet_sdk.{remove_null_unicode_character.__name__}` to sanitize your output if you'd like to remove the character.
"""
)
)
return serialized_output

    async def wait_for_tasks(self) -> None:
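        """Block until all tracked step-run tasks have finished, logging progress once per second."""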
running = len(self.tasks.keys())
while running > 0:
logger.info(f"waiting for {running} tasks to finish...")
await asyncio.sleep(1)
running = len(self.tasks.keys())