From 1db5cb5ee68ba98d4c75f22f8d7394d14722bf08 Mon Sep 17 00:00:00 2001
From: Matt Kaye
Date: Tue, 13 May 2025 14:11:36 -0400
Subject: [PATCH] [Python]: Fix: Adding Observability + Failure Handling (#1701)

* feat: simple thread pool monitoring job

* feat: sdk-side workflow pause

* Revert "feat: sdk-side workflow pause"

This reverts commit 83b4a63c208942d8ff2a65d219bccf8e38b55785.

* feat: set stop signal to stop accepting new work

* proposal: no-op healthcheck workflow

* Revert "proposal: no-op healthcheck workflow"

This reverts commit f651a521374aab4086b515c436d3b044bea4bebe.

* fix: rm removed concurrency strats

* chore: ver

* fix: rm heartbeat wf

* fix: rm zombie keys

* fix: hint on invalid token

* fix: flaky test

* fix: rm buggy / dead code
---
 .../python/examples/priority/test_priority.py |  2 +-
 sdks/python/hatchet_sdk/config.py             |  7 +++
 sdks/python/hatchet_sdk/runnables/types.py    | 29 +--------
 sdks/python/hatchet_sdk/runnables/workflow.py | 28 ---------
 .../worker/action_listener_process.py         |  3 +
 .../hatchet_sdk/worker/runner/runner.py       | 59 +++++++++++++++----
 sdks/python/pyproject.toml                    |  2 +-
 7 files changed, 62 insertions(+), 68 deletions(-)

diff --git a/sdks/python/examples/priority/test_priority.py b/sdks/python/examples/priority/test_priority.py
index 6445aaa2b..4790dfb30 100644
--- a/sdks/python/examples/priority/test_priority.py
+++ b/sdks/python/examples/priority/test_priority.py
@@ -278,7 +278,7 @@ async def crons(
 
 def time_until_next_minute() -> float:
     now = datetime.now()
-    next_minute = now.replace(second=0, microsecond=0, minute=now.minute + 1)
+    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
     return (next_minute - now).total_seconds()
 
 
diff --git a/sdks/python/hatchet_sdk/config.py b/sdks/python/hatchet_sdk/config.py
index 67062f97c..4e502db7f 100644
--- a/sdks/python/hatchet_sdk/config.py
+++ b/sdks/python/hatchet_sdk/config.py
@@ -64,13 +64,20 @@ class ClientConfig(BaseSettings):
     )
 
     worker_preset_labels: dict[str, str] = Field(default_factory=dict)
+    enable_force_kill_sync_threads: bool = False
+    enable_thread_pool_monitoring: bool = False
 
     @model_validator(mode="after")
     def validate_token_and_tenant(self) -> "ClientConfig":
         if not self.token:
             raise ValueError("Token must be set")
 
+        if not self.token.startswith("ey"):
+            raise ValueError(
+                f"Token must be a valid JWT. Hint: These are the first few characters of the token provided: {self.token[:5]}"
+            )
+
         if not self.tenant_id:
             self.tenant_id = get_tenant_id_from_jwt(self.token)
 
diff --git a/sdks/python/hatchet_sdk/runnables/types.py b/sdks/python/hatchet_sdk/runnables/types.py
index 07da14433..9c19c5a38 100644
--- a/sdks/python/hatchet_sdk/runnables/types.py
+++ b/sdks/python/hatchet_sdk/runnables/types.py
@@ -2,7 +2,7 @@ import asyncio
 from enum import Enum
 from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeGuard, TypeVar, Union
 
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field
 
 from hatchet_sdk.context.context import Context, DurableContext
 from hatchet_sdk.contracts.v1.workflows_pb2 import Concurrency
@@ -26,8 +26,6 @@ class StickyStrategy(str, Enum):
 
 class ConcurrencyLimitStrategy(str, Enum):
     CANCEL_IN_PROGRESS = "CANCEL_IN_PROGRESS"
-    DROP_NEWEST = "DROP_NEWEST"
-    QUEUE_NEWEST = "QUEUE_NEWEST"
     GROUP_ROUND_ROBIN = "GROUP_ROUND_ROBIN"
     CANCEL_NEWEST = "CANCEL_NEWEST"
 
@@ -82,31 +80,6 @@ class WorkflowConfig(BaseModel):
 
     task_defaults: TaskDefaults = TaskDefaults()
 
-    def _raise_for_invalid_expression(self, expr: str) -> None:
-        if not expr.startswith("input."):
-            return None
-
-        _, field = expr.split(".", maxsplit=2)
-
-        if field not in self.input_validator.model_fields.keys():
-            raise ValueError(
-                f"The concurrency expression provided relies on the `{field}` field, which was not present in `{self.input_validator.__name__}`."
-            )
-
-    @model_validator(mode="after")
-    def validate_concurrency_expression(self) -> "WorkflowConfig":
-        if not self.concurrency:
-            return self
-
-        if isinstance(self.concurrency, list):
-            for item in self.concurrency:
-                self._raise_for_invalid_expression(item.expression)
-
-        if isinstance(self.concurrency, ConcurrencyExpression):
-            self._raise_for_invalid_expression(self.concurrency.expression)
-
-        return self
-
 
 class StepType(str, Enum):
     DEFAULT = "default"
diff --git a/sdks/python/hatchet_sdk/runnables/workflow.py b/sdks/python/hatchet_sdk/runnables/workflow.py
index f8a729bb8..e4b7ebcad 100644
--- a/sdks/python/hatchet_sdk/runnables/workflow.py
+++ b/sdks/python/hatchet_sdk/runnables/workflow.py
@@ -132,34 +132,6 @@ class BaseWorkflow(Generic[TWorkflowInput]):
     def _create_action_name(self, step: Task[TWorkflowInput, Any]) -> str:
         return self.service_name + ":" + step.name
 
-    def _raise_for_invalid_concurrency(
-        self, concurrency: ConcurrencyExpression
-    ) -> bool:
-        expr = concurrency.expression
-
-        if not expr.startswith("input."):
-            return True
-
-        _, field = expr.split(".", maxsplit=2)
-
-        if field not in self.config.input_validator.model_fields.keys():
-            raise ValueError(
-                f"The concurrency expression provided relies on the `{field}` field, which was not present in `{self.config.input_validator.__name__}`."
-            )
-
-        return True
-
-    def _validate_priority(self, default_priority: int | None) -> int | None:
-        validated_priority = (
-            max(1, min(3, default_priority)) if default_priority else None
-        )
-        if validated_priority != default_priority:
-            logger.warning(
-                "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
-            )
-
-        return validated_priority
-
     def _is_leaf_task(self, task: Task[TWorkflowInput, Any]) -> bool:
         return not any(task in t.parents for t in self.tasks if task != t)
 
diff --git a/sdks/python/hatchet_sdk/worker/action_listener_process.py b/sdks/python/hatchet_sdk/worker/action_listener_process.py
index 6ef404854..c725686fb 100644
--- a/sdks/python/hatchet_sdk/worker/action_listener_process.py
+++ b/sdks/python/hatchet_sdk/worker/action_listener_process.py
@@ -291,6 +291,9 @@ class WorkerActionListenerProcess:
         self.event_queue.put(STOP_LOOP)
 
     async def exit_gracefully(self) -> None:
+        if self.listener:
+            self.listener.stop_signal = True
+
         await self.pause_task_assignment()
 
         if self.killing:
diff --git a/sdks/python/hatchet_sdk/worker/runner/runner.py b/sdks/python/hatchet_sdk/worker/runner/runner.py
index d47d9176c..de05de1a4 100644
--- a/sdks/python/hatchet_sdk/worker/runner/runner.py
+++ b/sdks/python/hatchet_sdk/worker/runner/runner.py
@@ -105,6 +105,9 @@ class Runner:
 
         self.lifespan_context = lifespan_context
 
+        if self.config.enable_thread_pool_monitoring:
+            self.start_background_monitoring()
+
     def create_workflow_run_url(self, action: Action) -> str:
         return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"
 
@@ -270,6 +273,47 @@ class Runner:
         finally:
             self.cleanup_run_id(action.key)
 
+    async def log_thread_pool_status(self) -> None:
+        thread_pool_details = {
+            "max_workers": self.slots,
+            "total_threads": len(self.thread_pool._threads),
+            "idle_threads": self.thread_pool._idle_semaphore._value,
+            "active_threads": len(self.threads),
+            "pending_tasks": len(self.tasks),
+            "queue_size": self.thread_pool._work_queue.qsize(),
+            "threads_alive": sum(1 for t in self.thread_pool._threads if t.is_alive()),
+            "threads_daemon": sum(1 for t in self.thread_pool._threads if t.daemon),
+        }
+
+        logger.warning("Thread pool detailed status %s", thread_pool_details)
+
+    async def _start_monitoring(self) -> None:
+        logger.debug("Thread pool monitoring started")
+        try:
+            while True:
+                await self.log_thread_pool_status()
+
+                for key in self.threads.keys():
+                    if key not in self.tasks:
+                        logger.debug(f"Potential zombie thread found for key {key}")
+
+                for key, task in self.tasks.items():
+                    if task.done() and key in self.threads:
+                        logger.debug(
+                            f"Task is done but thread still exists for key {key}"
+                        )
+
+                await asyncio.sleep(60)
+        except asyncio.CancelledError:
+            logger.warning("Thread pool monitoring task cancelled")
+        except Exception as e:
+            logger.exception(f"Error in thread pool monitoring: {e}")
+
+    def start_background_monitoring(self) -> None:
+        loop = asyncio.get_event_loop()
+        self.monitoring_task = loop.create_task(self._start_monitoring())
+        logger.debug("Started thread pool monitoring background task")
+
     def cleanup_run_id(self, key: ActionKey) -> None:
         if key in self.tasks:
             del self.tasks[key]
@@ -419,23 +463,18 @@ class Runner:
         try:
             # call cancel to signal the context to stop
            if key in self.contexts:
-                context = self.contexts.get(key)
-
-                if context:
-                    context._set_cancellation_flag()
+                self.contexts[key]._set_cancellation_flag()
 
             await asyncio.sleep(1)
 
             if key in self.tasks:
-                future = self.tasks.get(key)
-
-                if future:
-                    future.cancel()
+                self.tasks[key].cancel()
 
             # check if thread is still running, if so, print a warning
             if key in self.threads:
-                thread = self.threads.get(key)
-                if thread and self.config.enable_force_kill_sync_threads:
+                thread = self.threads[key]
+
+                if self.config.enable_force_kill_sync_threads:
                     self.force_kill_thread(thread)
 
                     await asyncio.sleep(1)
diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml
index af3dc5601..19ab59d18 100644
--- a/sdks/python/pyproject.toml
+++ b/sdks/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hatchet-sdk"
-version = "1.8.1"
+version = "1.8.2"
 description = ""
 authors = ["Alexander Belanger "]
 readme = "README.md"
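
A note on the test_priority.py change above: the old now.replace(minute=now.minute + 1) raises ValueError whenever the test starts during minute 59 of an hour (59 + 1 = 60 is not a valid minute), which is presumably the flakiness the "fix: flaky test" entry refers to; adding a timedelta first lets the hour (and day) roll over before truncating. A quick self-contained check (the timestamp is an arbitrary example):

from datetime import datetime, timedelta

now = datetime(2025, 5, 13, 14, 59, 30)

# Old approach: fails at minute 59, since replace() validates the minute field.
try:
    now.replace(second=0, microsecond=0, minute=now.minute + 1)
except ValueError as e:
    print(e)  # minute must be in 0..59

# New approach: timedelta arithmetic carries into the next hour, then truncate.
next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
print(next_minute)  # 2025-05-13 15:00:00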
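
A note on the two new ClientConfig flags: both default to False, so the monitoring loop and the force-kill path are strictly opt-in and existing workers keep their current behavior. A minimal sketch of enabling them, assuming the usual pattern of passing a ClientConfig into the Hatchet constructor (the constructor call and the environment-variable comment are assumptions for illustration, not part of this patch):

from hatchet_sdk import Hatchet
from hatchet_sdk.config import ClientConfig

# The token is typically supplied via the HATCHET_CLIENT_TOKEN environment
# variable; it must be set and, with this patch, must look like a JWT (start
# with "ey"), otherwise ClientConfig raises a ValueError that echoes the
# token's first few characters.
config = ClientConfig(
    # Log detailed thread pool status (and flag potential zombie threads) every 60s.
    enable_thread_pool_monitoring=True,
    # Force-kill a sync task's thread if it is still alive after cancellation.
    enable_force_kill_sync_threads=True,
)

hatchet = Hatchet(config=config)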