mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-04-25 03:39:38 -05:00
[Python]: Refactor: Remove validator registry (#1528)
* feat: remove validator registry! * chore: version * refactor: rename `WorkflowValidator` -> `TaskIOValidator` * fix: rm unnecessary variable
This commit is contained in:
@@ -19,7 +19,7 @@ from hatchet_sdk.context.worker_context import WorkerContext
|
||||
from hatchet_sdk.features.runs import RunsClient
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping, WorkflowValidator
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
from hatchet_sdk.waits import SleepCondition, UserEventCondition
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -37,10 +37,8 @@ class Context:
|
||||
durable_event_listener: DurableEventListener | None,
|
||||
worker: WorkerContext,
|
||||
runs_client: RunsClient,
|
||||
validator_registry: dict[str, WorkflowValidator] = {},
|
||||
):
|
||||
self.worker = worker
|
||||
self.validator_registry = validator_registry
|
||||
|
||||
self.data = action.action_payload
|
||||
|
||||
@@ -74,27 +72,12 @@ class Context:
|
||||
if self.was_skipped(task):
|
||||
raise ValueError("{task.name} was skipped")
|
||||
|
||||
action_prefix = self.action.action_id.split(":")[0]
|
||||
|
||||
workflow_validator = next(
|
||||
(
|
||||
v
|
||||
for k, v in self.validator_registry.items()
|
||||
if k == f"{action_prefix}:{task.name}"
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
try:
|
||||
parent_step_data = cast(R, self.data.parents[task.name])
|
||||
except KeyError:
|
||||
raise ValueError(f"Step output for '{task.name}' not found")
|
||||
|
||||
if (
|
||||
parent_step_data
|
||||
and workflow_validator
|
||||
and (v := workflow_validator.step_output)
|
||||
):
|
||||
if parent_step_data and (v := task.validators.step_output):
|
||||
return cast(R, v.model_validate(parent_step_data))
|
||||
|
||||
return parent_step_data
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from hatchet_sdk.context.context import Context, DurableContext
|
||||
@@ -28,6 +29,7 @@ from hatchet_sdk.runnables.types import (
|
||||
is_sync_fn,
|
||||
)
|
||||
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
|
||||
from hatchet_sdk.utils.typing import TaskIOValidator, is_basemodel_subclass
|
||||
from hatchet_sdk.waits import (
|
||||
Action,
|
||||
Condition,
|
||||
@@ -106,6 +108,13 @@ class Task(Generic[TWorkflowInput, R]):
|
||||
self.skip_if = self._flatten_conditions(skip_if)
|
||||
self.cancel_if = self._flatten_conditions(cancel_if)
|
||||
|
||||
return_type = get_type_hints(_fn).get("return")
|
||||
|
||||
self.validators: TaskIOValidator = TaskIOValidator(
|
||||
workflow_input=workflow.config.input_validator,
|
||||
step_output=return_type if is_basemodel_subclass(return_type) else None,
|
||||
)
|
||||
|
||||
def _flatten_conditions(
|
||||
self, conditions: list[Condition | OrGroup]
|
||||
) -> list[Condition]:
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import Any, Mapping, Type, TypeGuard, TypeVar
|
||||
from typing import Any, Mapping, Type, TypeGuard
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
|
||||
try:
|
||||
@@ -12,7 +10,7 @@ def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
|
||||
return False
|
||||
|
||||
|
||||
class WorkflowValidator(BaseModel):
|
||||
class TaskIOValidator(BaseModel):
|
||||
workflow_input: Type[BaseModel] | None = None
|
||||
step_output: Type[BaseModel] | None = None
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from hatchet_sdk.clients.dispatcher.action_listener import Action
|
||||
from hatchet_sdk.config import ClientConfig
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.runnables.task import Task
|
||||
from hatchet_sdk.utils.typing import WorkflowValidator
|
||||
from hatchet_sdk.worker.action_listener_process import ActionEvent
|
||||
from hatchet_sdk.worker.runner.runner import Runner
|
||||
from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
|
||||
@@ -24,7 +23,6 @@ class WorkerActionRunLoopManager:
|
||||
self,
|
||||
name: str,
|
||||
action_registry: dict[str, Task[Any, Any]],
|
||||
validator_registry: dict[str, WorkflowValidator],
|
||||
slots: int | None,
|
||||
config: ClientConfig,
|
||||
action_queue: "Queue[Action | STOP_LOOP_TYPE]",
|
||||
@@ -36,7 +34,6 @@ class WorkerActionRunLoopManager:
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.action_registry = action_registry
|
||||
self.validator_registry = validator_registry
|
||||
self.slots = slots
|
||||
self.config = config
|
||||
self.action_queue = action_queue
|
||||
@@ -88,7 +85,6 @@ class WorkerActionRunLoopManager:
|
||||
self.slots,
|
||||
self.handle_kill,
|
||||
self.action_registry,
|
||||
self.validator_registry,
|
||||
self.labels,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ from hatchet_sdk.runnables.contextvars import (
|
||||
)
|
||||
from hatchet_sdk.runnables.task import Task
|
||||
from hatchet_sdk.runnables.types import R, TWorkflowInput
|
||||
from hatchet_sdk.utils.typing import WorkflowValidator
|
||||
from hatchet_sdk.worker.action_listener_process import ActionEvent
|
||||
from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars
|
||||
|
||||
@@ -63,7 +62,6 @@ class Runner:
|
||||
slots: int | None = None,
|
||||
handle_kill: bool = True,
|
||||
action_registry: dict[str, Task[TWorkflowInput, R]] = {},
|
||||
validator_registry: dict[str, WorkflowValidator] = {},
|
||||
labels: dict[str, str | int] = {},
|
||||
):
|
||||
# We store the config so we can dynamically create clients for the dispatcher client.
|
||||
@@ -73,7 +71,6 @@ class Runner:
|
||||
self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
|
||||
self.contexts: dict[str, Context] = {} # Store run ids and contexts
|
||||
self.action_registry = action_registry
|
||||
self.validator_registry = validator_registry
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
@@ -305,7 +302,6 @@ class Runner:
|
||||
event_client=self.event_client,
|
||||
durable_event_listener=self.durable_event_listener,
|
||||
worker=self.worker_context,
|
||||
validator_registry=self.validator_registry,
|
||||
runs_client=self.runs_client,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from enum import Enum
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing.process import BaseProcess
|
||||
from types import FrameType
|
||||
from typing import Any, TypeVar, get_type_hints
|
||||
from typing import Any, TypeVar
|
||||
from warnings import warn
|
||||
|
||||
from aiohttp import web
|
||||
@@ -26,7 +26,6 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import CreateWorkflowVersionRequest
|
||||
from hatchet_sdk.logger import logger
|
||||
from hatchet_sdk.runnables.task import Task
|
||||
from hatchet_sdk.runnables.workflow import BaseWorkflow
|
||||
from hatchet_sdk.utils.typing import WorkflowValidator, is_basemodel_subclass
|
||||
from hatchet_sdk.worker.action_listener_process import (
|
||||
ActionEvent,
|
||||
worker_action_listener_process,
|
||||
@@ -87,8 +86,6 @@ class Worker:
|
||||
self.action_registry: dict[str, Task[Any, Any]] = {}
|
||||
self.durable_action_registry: dict[str, Task[Any, Any]] = {}
|
||||
|
||||
self.validator_registry: dict[str, WorkflowValidator] = {}
|
||||
|
||||
self.killing: bool = False
|
||||
self._status: WorkerStatus
|
||||
|
||||
@@ -153,13 +150,6 @@ class Worker:
|
||||
self.has_any_non_durable = True
|
||||
self.action_registry[action_name] = step
|
||||
|
||||
return_type = get_type_hints(step.fn).get("return")
|
||||
|
||||
self.validator_registry[action_name] = WorkflowValidator(
|
||||
workflow_input=workflow.config.input_validator,
|
||||
step_output=return_type if is_basemodel_subclass(return_type) else None,
|
||||
)
|
||||
|
||||
def register_workflows(self, workflows: list[BaseWorkflow[Any]]) -> None:
|
||||
for workflow in workflows:
|
||||
self.register_workflow(workflow)
|
||||
@@ -285,7 +275,6 @@ class Worker:
|
||||
return WorkerActionRunLoopManager(
|
||||
self.name + ("_durable" if is_durable else ""),
|
||||
self.durable_action_registry if is_durable else self.action_registry,
|
||||
self.validator_registry,
|
||||
1_000 if is_durable else self.slots,
|
||||
self.config,
|
||||
self.durable_action_queue if is_durable else self.action_queue,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "hatchet-sdk"
|
||||
version = "1.3.1"
|
||||
version = "1.3.2"
|
||||
description = ""
|
||||
authors = ["Alexander Belanger <alexander@hatchet.run>"]
|
||||
readme = "README.md"
|
||||
|
||||
Reference in New Issue
Block a user