[Python]: Refactor: Remove validator registry (#1528)

* feat: remove validator registry!

* chore: version

* refactor: rename `WorkflowValidator` -> `TaskIOValidator`

* fix: rm unnecessary variable
This commit is contained in:
Matt Kaye
2025-04-10 17:45:35 -04:00
committed by GitHub
parent 7781200123
commit e4e57e7951
7 changed files with 15 additions and 44 deletions
+2 -19
View File
@@ -19,7 +19,7 @@ from hatchet_sdk.context.worker_context import WorkerContext
from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.logger import logger
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
from hatchet_sdk.utils.typing import JSONSerializableMapping, WorkflowValidator
from hatchet_sdk.utils.typing import JSONSerializableMapping
from hatchet_sdk.waits import SleepCondition, UserEventCondition
if TYPE_CHECKING:
@@ -37,10 +37,8 @@ class Context:
durable_event_listener: DurableEventListener | None,
worker: WorkerContext,
runs_client: RunsClient,
validator_registry: dict[str, WorkflowValidator] = {},
):
self.worker = worker
self.validator_registry = validator_registry
self.data = action.action_payload
@@ -74,27 +72,12 @@ class Context:
if self.was_skipped(task):
raise ValueError("{task.name} was skipped")
action_prefix = self.action.action_id.split(":")[0]
workflow_validator = next(
(
v
for k, v in self.validator_registry.items()
if k == f"{action_prefix}:{task.name}"
),
None,
)
try:
parent_step_data = cast(R, self.data.parents[task.name])
except KeyError:
raise ValueError(f"Step output for '{task.name}' not found")
if (
parent_step_data
and workflow_validator
and (v := workflow_validator.step_output)
):
if parent_step_data and (v := task.validators.step_output):
return cast(R, v.model_validate(parent_step_data))
return parent_step_data
@@ -7,6 +7,7 @@ from typing import (
TypeVar,
Union,
cast,
get_type_hints,
)
from hatchet_sdk.context.context import Context, DurableContext
@@ -28,6 +29,7 @@ from hatchet_sdk.runnables.types import (
is_sync_fn,
)
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
from hatchet_sdk.utils.typing import TaskIOValidator, is_basemodel_subclass
from hatchet_sdk.waits import (
Action,
Condition,
@@ -106,6 +108,13 @@ class Task(Generic[TWorkflowInput, R]):
self.skip_if = self._flatten_conditions(skip_if)
self.cancel_if = self._flatten_conditions(cancel_if)
return_type = get_type_hints(_fn).get("return")
self.validators: TaskIOValidator = TaskIOValidator(
workflow_input=workflow.config.input_validator,
step_output=return_type if is_basemodel_subclass(return_type) else None,
)
def _flatten_conditions(
self, conditions: list[Condition | OrGroup]
) -> list[Condition]:
+2 -4
View File
@@ -1,9 +1,7 @@
from typing import Any, Mapping, Type, TypeGuard, TypeVar
from typing import Any, Mapping, Type, TypeGuard
from pydantic import BaseModel
T = TypeVar("T", bound=BaseModel)
def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
try:
@@ -12,7 +10,7 @@ def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
return False
class WorkflowValidator(BaseModel):
class TaskIOValidator(BaseModel):
workflow_input: Type[BaseModel] | None = None
step_output: Type[BaseModel] | None = None
@@ -8,7 +8,6 @@ from hatchet_sdk.clients.dispatcher.action_listener import Action
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.logger import logger
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.utils.typing import WorkflowValidator
from hatchet_sdk.worker.action_listener_process import ActionEvent
from hatchet_sdk.worker.runner.runner import Runner
from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
@@ -24,7 +23,6 @@ class WorkerActionRunLoopManager:
self,
name: str,
action_registry: dict[str, Task[Any, Any]],
validator_registry: dict[str, WorkflowValidator],
slots: int | None,
config: ClientConfig,
action_queue: "Queue[Action | STOP_LOOP_TYPE]",
@@ -36,7 +34,6 @@ class WorkerActionRunLoopManager:
) -> None:
self.name = name
self.action_registry = action_registry
self.validator_registry = validator_registry
self.slots = slots
self.config = config
self.action_queue = action_queue
@@ -88,7 +85,6 @@ class WorkerActionRunLoopManager:
self.slots,
self.handle_kill,
self.action_registry,
self.validator_registry,
self.labels,
)
@@ -43,7 +43,6 @@ from hatchet_sdk.runnables.contextvars import (
)
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import R, TWorkflowInput
from hatchet_sdk.utils.typing import WorkflowValidator
from hatchet_sdk.worker.action_listener_process import ActionEvent
from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars
@@ -63,7 +62,6 @@ class Runner:
slots: int | None = None,
handle_kill: bool = True,
action_registry: dict[str, Task[TWorkflowInput, R]] = {},
validator_registry: dict[str, WorkflowValidator] = {},
labels: dict[str, str | int] = {},
):
# We store the config so we can dynamically create clients for the dispatcher client.
@@ -73,7 +71,6 @@ class Runner:
self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
self.contexts: dict[str, Context] = {} # Store run ids and contexts
self.action_registry = action_registry
self.validator_registry = validator_registry
self.event_queue = event_queue
@@ -305,7 +302,6 @@ class Runner:
event_client=self.event_client,
durable_event_listener=self.durable_event_listener,
worker=self.worker_context,
validator_registry=self.validator_registry,
runs_client=self.runs_client,
)
+1 -12
View File
@@ -10,7 +10,7 @@ from enum import Enum
from multiprocessing import Queue
from multiprocessing.process import BaseProcess
from types import FrameType
from typing import Any, TypeVar, get_type_hints
from typing import Any, TypeVar
from warnings import warn
from aiohttp import web
@@ -26,7 +26,6 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import CreateWorkflowVersionRequest
from hatchet_sdk.logger import logger
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.workflow import BaseWorkflow
from hatchet_sdk.utils.typing import WorkflowValidator, is_basemodel_subclass
from hatchet_sdk.worker.action_listener_process import (
ActionEvent,
worker_action_listener_process,
@@ -87,8 +86,6 @@ class Worker:
self.action_registry: dict[str, Task[Any, Any]] = {}
self.durable_action_registry: dict[str, Task[Any, Any]] = {}
self.validator_registry: dict[str, WorkflowValidator] = {}
self.killing: bool = False
self._status: WorkerStatus
@@ -153,13 +150,6 @@ class Worker:
self.has_any_non_durable = True
self.action_registry[action_name] = step
return_type = get_type_hints(step.fn).get("return")
self.validator_registry[action_name] = WorkflowValidator(
workflow_input=workflow.config.input_validator,
step_output=return_type if is_basemodel_subclass(return_type) else None,
)
def register_workflows(self, workflows: list[BaseWorkflow[Any]]) -> None:
for workflow in workflows:
self.register_workflow(workflow)
@@ -285,7 +275,6 @@ class Worker:
return WorkerActionRunLoopManager(
self.name + ("_durable" if is_durable else ""),
self.durable_action_registry if is_durable else self.action_registry,
self.validator_registry,
1_000 if is_durable else self.slots,
self.config,
self.durable_action_queue if is_durable else self.action_queue,
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "1.3.1"
version = "1.3.2"
description = ""
authors = ["Alexander Belanger <alexander@hatchet.run>"]
readme = "README.md"