import asyncio
import ctypes
import functools
import json
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, is_dataclass
from enum import Enum
from multiprocessing import Queue
from textwrap import dedent
from threading import Thread, current_thread
from typing import Any, Literal, cast, overload

from pydantic import BaseModel

from hatchet_sdk.client import Client
from hatchet_sdk.clients.admin import AdminClient
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.clients.listeners.durable_event_listener import DurableEventListener
from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.context.context import Context, DurableContext
from hatchet_sdk.context.worker_context import WorkerContext
from hatchet_sdk.contracts.dispatcher_pb2 import (
    STEP_EVENT_TYPE_COMPLETED,
    STEP_EVENT_TYPE_FAILED,
    STEP_EVENT_TYPE_STARTED,
)
from hatchet_sdk.exceptions import (
    IllegalTaskOutputError,
    NonRetryableException,
    TaskRunError,
)
from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.logger import logger
from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
from hatchet_sdk.runnables.contextvars import (
    ctx_action_key,
    ctx_additional_metadata,
    ctx_step_run_id,
    ctx_worker_id,
    ctx_workflow_run_id,
    spawn_index_lock,
    task_count,
    workflow_spawn_indices,
)
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import R, TWorkflowInput
from hatchet_sdk.utils.serde import remove_null_unicode_character
from hatchet_sdk.utils.typing import DataclassInstance
from hatchet_sdk.worker.action_listener_process import ActionEvent
from hatchet_sdk.worker.runner.utils.capture_logs import (
    AsyncLogSender,
    ContextVarToCopy,
    ContextVarToCopyDict,
    ContextVarToCopyStr,
    copy_context_vars,
)


class WorkerStatus(Enum):
    INITIALIZED = 1
    STARTING = 2
    HEALTHY = 3
    UNHEALTHY = 4


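# Runner executes actions dispatched to this worker: async tasks run directly on
# the event loop, synchronous tasks run in a thread pool, and step lifecycle
# events (started / completed / failed) are reported back over the event queue.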
class Runner:
    def __init__(
        self,
        event_queue: "Queue[ActionEvent]",
        config: ClientConfig,
        slots: int,
        handle_kill: bool,
        action_registry: dict[str, Task[TWorkflowInput, R]],
        labels: dict[str, str | int] | None,
        lifespan_context: Any | None,
        log_sender: AsyncLogSender,
    ):
        # We store the config so we can dynamically create clients for the dispatcher client.
        self.config = config

        self.slots = slots
        self.tasks: dict[ActionKey, asyncio.Task[Any]] = {}  # Store run ids and futures
        self.contexts: dict[ActionKey, Context] = {}  # Store run ids and contexts
        self.action_registry = action_registry or {}

        self.event_queue = event_queue

        # The thread pool is used for synchronous functions which need to run concurrently
        self.thread_pool = ThreadPoolExecutor(max_workers=slots)
        self.threads: dict[ActionKey, Thread] = {}  # Store run ids and threads
        self.running_tasks = set[asyncio.Task[Exception | None]]()

        self.killing = False
        self.handle_kill = handle_kill

        self.dispatcher_client = DispatcherClient(self.config)
        self.workflow_run_event_listener = RunEventListenerClient(self.config)
        self.workflow_listener = PooledWorkflowRunListener(self.config)
        self.runs_client = RunsClient(
            config=self.config,
            workflow_run_event_listener=self.workflow_run_event_listener,
            workflow_run_listener=self.workflow_listener,
        )
        self.admin_client = AdminClient(
            self.config,
            self.workflow_listener,
            self.workflow_run_event_listener,
            self.runs_client,
        )
        self.event_client = EventClient(self.config)
        self.durable_event_listener = DurableEventListener(self.config)

        self.worker_context = WorkerContext(
            labels=labels or {}, client=Client(config=config).dispatcher
        )

        self.lifespan_context = lifespan_context
        self.log_sender = log_sender

        if self.config.enable_thread_pool_monitoring:
            self.start_background_monitoring()

    def create_workflow_run_url(self, action: Action) -> str:
        return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"

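    # Entry point called for each decoded action: dispatches to
    # handle_start_step_run or handle_cancel_action, and keeps a reference to the
    # resulting asyncio task in running_tasks so it is not garbage-collected
    # before it completes.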
    def run(self, action: Action) -> None:
        if self.worker_context.id() is None:
            self.worker_context._worker_id = action.worker_id

        t: asyncio.Task[Exception | None] | None = None
        match action.action_type:
            case ActionType.START_STEP_RUN:
                log = f"run: start step: {action.action_id}/{action.step_run_id}"
                logger.info(log)
                t = asyncio.create_task(self.handle_start_step_run(action))
            case ActionType.CANCEL_STEP_RUN:
                log = f"cancel: step run: {action.action_id}/{action.step_run_id}/{action.retry_count}"
                logger.info(log)
                t = asyncio.create_task(self.handle_cancel_action(action))
            case _:
                log = f"unknown action type: {action.action_type}"
                logger.error(log)

        if t is not None:
            self.running_tasks.add(t)
            t.add_done_callback(lambda task: self.running_tasks.discard(task))

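    # Builds the done-callback attached to each step-run task: it cleans up the
    # run's bookkeeping, then reports either a COMPLETED event (with the
    # serialized output) or a FAILED event back over the event queue.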
    def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
        def inner_callback(task: asyncio.Task[Any]) -> None:
            self.cleanup_run_id(action.key)

            if task.cancelled():
                return

            try:
                output = task.result()
            except Exception as e:
                should_not_retry = isinstance(e, NonRetryableException)

                exc = TaskRunError.from_exception(e, action.step_run_id)

                # This exception is coming from the application itself, so we want to send that to the Hatchet instance
                self.event_queue.put(
                    ActionEvent(
                        action=action,
                        type=STEP_EVENT_TYPE_FAILED,
                        payload=exc.serialize(include_metadata=True),
                        should_not_retry=should_not_retry,
                    )
                )

                log_with_level = logger.info if should_not_retry else logger.exception

                log_with_level(
                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize(include_metadata=False)}"
                )

                return

            try:
                output = self.serialize_output(output)

                self.event_queue.put(
                    ActionEvent(
                        action=action,
                        type=STEP_EVENT_TYPE_COMPLETED,
                        payload=output,
                        should_not_retry=False,
                    )
                )
            except IllegalTaskOutputError as e:
                exc = TaskRunError.from_exception(e, action.step_run_id)
                self.event_queue.put(
                    ActionEvent(
                        action=action,
                        type=STEP_EVENT_TYPE_FAILED,
                        payload=exc.serialize(include_metadata=True),
                        should_not_retry=False,
                    )
                )

                logger.exception(
                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize(include_metadata=False)}"
                )

                return

            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")

        return inner_callback

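    # Runs the synchronous task body inside a worker thread from the pool and
    # records the thread so a later cancellation can find (and force-kill) it.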
    def thread_action_func(
        self,
        ctx: Context,
        task: Task[TWorkflowInput, R],
        action: Action,
        dependencies: dict[str, Any],
    ) -> R:
        if action.step_run_id:
            self.threads[action.key] = current_thread()

        return task.call(ctx, dependencies)

    # We wrap all actions in an async func
    async def async_wrapped_action_func(
        self,
        ctx: Context,
        task: Task[TWorkflowInput, R],
        action: Action,
    ) -> R:
        ctx_step_run_id.set(action.step_run_id)
        ctx_workflow_run_id.set(action.workflow_run_id)
        ctx_worker_id.set(action.worker_id)
        ctx_action_key.set(action.key)
        ctx_additional_metadata.set(action.additional_metadata)

        dependencies = await task._unpack_dependencies(ctx)

        try:
            if task.is_async_function:
                return await task.aio_call(ctx, dependencies)

            pfunc = functools.partial(
                # we must copy the context vars to the new thread, as only asyncio natively supports
                # contextvars
                copy_context_vars,
                [
                    ContextVarToCopy(
                        var=ContextVarToCopyStr(
                            name="ctx_step_run_id",
                            value=action.step_run_id,
                        )
                    ),
                    ContextVarToCopy(
                        var=ContextVarToCopyStr(
                            name="ctx_workflow_run_id",
                            value=action.workflow_run_id,
                        )
                    ),
                    ContextVarToCopy(
                        var=ContextVarToCopyStr(
                            name="ctx_worker_id",
                            value=action.worker_id,
                        )
                    ),
                    ContextVarToCopy(
                        var=ContextVarToCopyStr(
                            name="ctx_action_key",
                            value=action.key,
                        )
                    ),
                    ContextVarToCopy(
                        var=ContextVarToCopyDict(
                            name="ctx_additional_metadata",
                            value=action.additional_metadata,
                        )
                    ),
                ],
                self.thread_action_func,
                ctx,
                task,
                action,
                dependencies,
            )

            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(self.thread_pool, pfunc)
        finally:
            self.cleanup_run_id(action.key)

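    # NOTE: the fields below read private ThreadPoolExecutor attributes
    # (_threads, _idle_semaphore, _work_queue); these are CPython implementation
    # details and may change between Python versions.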
    async def log_thread_pool_status(self) -> None:
        thread_pool_details = {
            "max_workers": self.slots,
            "total_threads": len(self.thread_pool._threads),
            "idle_threads": self.thread_pool._idle_semaphore._value,
            "active_threads": len(self.threads),
            "pending_tasks": len(self.tasks),
            "queue_size": self.thread_pool._work_queue.qsize(),
            "threads_alive": sum(1 for t in self.thread_pool._threads if t.is_alive()),
            "threads_daemon": sum(1 for t in self.thread_pool._threads if t.daemon),
        }

        logger.warning("thread pool detailed status %s", thread_pool_details)

    async def _start_monitoring(self) -> None:
        logger.debug("thread pool monitoring started")
        try:
            while True:
                await self.log_thread_pool_status()

                for key in self.threads:
                    if key not in self.tasks:
                        logger.debug(f"potential zombie thread found for key {key}")

                for key, task in self.tasks.items():
                    if task.done() and key in self.threads:
                        logger.debug(
                            f"task is done but thread still exists for key {key}"
                        )

                await asyncio.sleep(60)
        except asyncio.CancelledError:
            logger.warning("thread pool monitoring task cancelled")
        except Exception as e:
            logger.exception(f"error in thread pool monitoring: {e}")

    def start_background_monitoring(self) -> None:
        loop = asyncio.get_event_loop()
        self.monitoring_task = loop.create_task(self._start_monitoring())
        logger.debug("started thread pool monitoring background task")

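    # Drops the task/thread/context bookkeeping for a finished or cancelled run
    # so these dicts do not grow without bound over the worker's lifetime.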
    def cleanup_run_id(self, key: ActionKey) -> None:
        if key in self.tasks:
            del self.tasks[key]

        if key in self.threads:
            del self.threads[key]

        if key in self.contexts:
            del self.contexts[key]

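    # The overloads below narrow the return type based on the `is_durable`
    # literal, so callers get a DurableContext for durable tasks and a plain
    # Context otherwise.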
    @overload
    def create_context(
        self, action: Action, is_durable: Literal[True] = True
    ) -> DurableContext: ...

    @overload
    def create_context(
        self, action: Action, is_durable: Literal[False] = False
    ) -> Context: ...

    def create_context(
        self, action: Action, is_durable: bool = True
    ) -> Context | DurableContext:
        constructor = DurableContext if is_durable else Context

        return constructor(
            action=action,
            dispatcher_client=self.dispatcher_client,
            admin_client=self.admin_client,
            event_client=self.event_client,
            durable_event_listener=self.durable_event_listener,
            worker=self.worker_context,
            runs_client=self.runs_client,
            lifespan_context=self.lifespan_context,
            log_sender=self.log_sender,
        )

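    # Start-step-run flow: look up the registered Task, create its context, emit
    # a STARTED event, schedule the wrapped coroutine on the event loop, and
    # return any raised exception so the OTel wrapper can record it.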
    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    async def handle_start_step_run(self, action: Action) -> Exception | None:
        action_name = action.action_id

        # Find the corresponding action function from the registry
        action_func = self.action_registry.get(action_name)

        if action_func:
            context = self.create_context(
                action,
                True if action_func.is_durable else False,  # noqa: SIM210
            )

            self.contexts[action.key] = context
            self.event_queue.put(
                ActionEvent(
                    action=action,
                    type=STEP_EVENT_TYPE_STARTED,
                    payload="",
                    should_not_retry=False,
                )
            )

            loop = asyncio.get_event_loop()
            task = loop.create_task(
                self.async_wrapped_action_func(context, action_func, action)
            )

            task.add_done_callback(self.step_run_callback(action))
            self.tasks[action.key] = task

            task_count.increment()

            ## FIXME: Handle cancelled exceptions and other special exceptions
            ## that we don't want to suppress here
            try:
                await task
            except Exception as e:
                ## Used for the OTel instrumentor to capture exceptions
                return e

        ## Once the step run completes, we need to remove the workflow spawn index
        ## so we don't leak memory
        if action.key in workflow_spawn_indices:
            async with spawn_index_lock:
                workflow_spawn_indices.pop(action.key)

        return None

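    # Force-kill uses CPython's PyThreadState_SetAsyncExc to asynchronously raise
    # SystemExit in the target thread. A return value of 0 means the thread id was
    # invalid; a value greater than 1 means more than one thread was affected and
    # the call must be undone by passing exc=0 before failing.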
    def force_kill_thread(self, thread: Thread) -> None:
        """Terminate a python threading.Thread."""
        try:
            if not thread.is_alive():
                return

            ident = cast(int, thread.ident)

            logger.info(f"forcefully terminating thread {ident}")

            exc = ctypes.py_object(SystemExit)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
            if res == 0:
                raise ValueError("Invalid thread ID")
            if res != 1:
                logger.error("PyThreadState_SetAsyncExc failed")

                # Call with exception set to 0 is needed to cleanup properly.
                ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
                raise SystemError("PyThreadState_SetAsyncExc failed")

            logger.info(f"successfully terminated thread {ident}")

            # Immediately add a new thread to the thread pool, because we've actually killed a worker
            # in the ThreadPoolExecutor
            self.thread_pool.submit(lambda: None)
        except Exception as e:
            logger.exception(f"failed to terminate thread: {e}")

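    # Cancellation is cooperative: flag the context so in-flight code can stop,
    # give it a second, cancel the asyncio task, and (optionally) force-kill a
    # sync thread that is still running.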
    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    async def handle_cancel_action(self, action: Action) -> None:
        key = action.key
        try:
            # call cancel to signal the context to stop
            if key in self.contexts:
                self.contexts[key]._set_cancellation_flag()

            await asyncio.sleep(1)

            if key in self.tasks:
                self.tasks[key].cancel()

            # check if thread is still running, if so, print a warning
            if key in self.threads:
                thread = self.threads[key]

                if self.config.enable_force_kill_sync_threads:
                    self.force_kill_thread(thread)
                    await asyncio.sleep(1)

                logger.warning(
                    f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
                )
        finally:
            self.cleanup_run_id(key)

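    # Normalizes a task's return value into a JSON string: Pydantic models and
    # dataclasses are converted to dicts first, and anything else must already be
    # a dict. The null-character check exists because Postgres cannot store
    # \u0000 in its JSON types (see the linked docs in the error message).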
    def serialize_output(self, output: Any) -> str:
        if not output:
            return ""

        if isinstance(output, BaseModel):
            try:
                output = output.model_dump(mode="json")
            except Exception as e:
                logger.exception("could not serialize pydantic model output")

                raise IllegalTaskOutputError(
                    f"could not serialize Pydantic BaseModel output: {e}"
                ) from e
        elif is_dataclass(output):
            output = asdict(cast(DataclassInstance, output))

        if not isinstance(output, dict):
            raise IllegalTaskOutputError(
                f"Tasks must return either a dictionary, a Pydantic BaseModel, or a dataclass which can be serialized to a JSON object. Got object of type {type(output)} instead."
            )

        if output is None:
            return ""

        try:
            serialized_output = json.dumps(output, default=str)
        except Exception as e:
            logger.exception("could not serialize output")
            raise IllegalTaskOutputError(
                "Task output could not be serialized to JSON. Please ensure that all task outputs are JSON serializable."
            ) from e

        if "\\u0000" in serialized_output:
            raise IllegalTaskOutputError(
                dedent(
                    f"""
                    Task outputs cannot contain the unicode null character \\u0000

                    Please see this Discord thread: https://discord.com/channels/1088927970518909068/1384324576166678710/1386714014565928992
                    Relevant Postgres documentation: https://www.postgresql.org/docs/current/datatype-json.html

                    Use `hatchet_sdk.{remove_null_unicode_character.__name__}` to sanitize your output if you'd like to remove the character.
                    """
                )
            )

        return serialized_output

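    # Blocks (asynchronously) until all tracked tasks have finished, polling once
    # per second.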
    async def wait_for_tasks(self) -> None:
        running = len(self.tasks.keys())
        while running > 0:
            logger.info(f"waiting for {running} tasks to finish...")
            await asyncio.sleep(1)
            running = len(self.tasks.keys())