mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-04-23 02:34:48 -05:00
[Python] Fix: Remove global event loop setters everywhere (#1452)
* feat: factor out async to sync * feat: remove global event loop setter in workflow listener * fix: rm more global loop sets * fix: more loop sets * fix: more * fix: more * fix: rm one more * fix: stream from thread * fix: dispatcher * fix: make tests have independent loop scopes (woohoo!) * feat: use default loop scope * fix: try adding back tests * Revert "fix: try adding back tests" This reverts commit bed34a9bae539650e4fe32e0518aa9d1c5c0af5c. * fix: rm dead code * fix: remove redundant `_utils` * fix: add typing to client stubs + regenerate * fix: create more clients lazily * fix: print cruft * chore: version * fix: lint
This commit is contained in:
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
result = await bulk_parent_wf.aio_run(input=ParentInput(n=12))
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
with pytest.raises(Exception, match="(Task exceeded timeout|TIMED_OUT)"):
|
||||
await wf.aio_run()
|
||||
|
||||
@@ -6,7 +6,7 @@ from hatchet_sdk.workflow_run import WorkflowRunRef
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.skip(reason="The timing for this test is not reliable")
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
num_runs = 6
|
||||
|
||||
@@ -9,7 +9,7 @@ from hatchet_sdk.workflow_run import WorkflowRunRef
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.skip(reason="The timing for this test is not reliable")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
num_groups = 2
|
||||
runs: list[WorkflowRunRef] = []
|
||||
|
||||
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
result = await dag_workflow.aio_run()
|
||||
|
||||
|
||||
@@ -7,6 +7,6 @@
|
||||
# worker = fixture_bg_worker(["poetry", "run", "manual_trigger"])
|
||||
|
||||
# # requires scope module or higher for shared event loop
|
||||
# @pytest.mark.asyncio(loop_scope="session")
|
||||
# @pytest.mark.asyncio()
|
||||
# async def test_run(hatchet: Hatchet):
|
||||
# # TODO
|
||||
|
||||
@@ -11,7 +11,7 @@ from hatchet_sdk import Hatchet
|
||||
os.getenv("CI", "false").lower() == "true",
|
||||
reason="Skipped in CI because of unreliability",
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_durable(hatchet: Hatchet) -> None:
|
||||
ref = durable_workflow.run_no_wait()
|
||||
|
||||
|
||||
@@ -5,21 +5,21 @@ from hatchet_sdk.hatchet import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_event_push(hatchet: Hatchet) -> None:
|
||||
e = hatchet.event.push("user:create", {"test": "test"})
|
||||
|
||||
assert e.eventId is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_async_event_push(aiohatchet: Hatchet) -> None:
|
||||
e = await aiohatchet.event.aio_push("user:create", {"test": "test"})
|
||||
|
||||
assert e.eventId is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_async_event_bulk_push(aiohatchet: Hatchet) -> None:
|
||||
|
||||
events = [
|
||||
|
||||
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
result = await parent_wf.aio_run(ParentInput(n=2))
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
result = await logging_workflow.aio_run()
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run_timeout(aiohatchet: Hatchet, worker: Worker) -> None:
|
||||
run = on_failure_wf.run_no_wait()
|
||||
try:
|
||||
|
||||
@@ -9,7 +9,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.skip(reason="The timing for this test is not reliable")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
|
||||
run1 = rate_limit_workflow.run_no_wait()
|
||||
|
||||
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
# requires scope module or higher for shared event loop
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_execution_timeout(hatchet: Hatchet) -> None:
|
||||
run = timeout_wf.run_no_wait()
|
||||
|
||||
@@ -13,7 +13,7 @@ async def test_execution_timeout(hatchet: Hatchet) -> None:
|
||||
await run.aio_result()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_run_refresh_timeout(hatchet: Hatchet) -> None:
|
||||
result = await refresh_timeout_wf.aio_run()
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from hatchet_sdk import Hatchet
|
||||
os.getenv("CI", "false").lower() == "true",
|
||||
reason="Skipped in CI because of unreliability",
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_waits(hatchet: Hatchet) -> None:
|
||||
|
||||
ref = task_condition_workflow.run_no_wait()
|
||||
|
||||
@@ -105,6 +105,9 @@ find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i
|
||||
find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i '' 's/import events_pb2 as events__pb2/from hatchet_sdk.contracts import events_pb2 as events__pb2/g'
|
||||
find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i '' 's/import workflows_pb2 as workflows__pb2/from hatchet_sdk.contracts import workflows_pb2 as workflows__pb2/g'
|
||||
|
||||
find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i '' 's/def __init__(self, channel):/def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:/g'
|
||||
|
||||
|
||||
# ensure that pre-commit is applied without errors
|
||||
./lint.sh
|
||||
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import asyncio
|
||||
|
||||
import grpc
|
||||
|
||||
from hatchet_sdk.clients.admin import AdminClient
|
||||
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
|
||||
from hatchet_sdk.clients.events import EventClient, new_event
|
||||
from hatchet_sdk.clients.events import EventClient
|
||||
from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
|
||||
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
|
||||
from hatchet_sdk.config import ClientConfig
|
||||
from hatchet_sdk.connection import new_conn
|
||||
from hatchet_sdk.features.cron import CronClient
|
||||
from hatchet_sdk.features.logs import LogsClient
|
||||
from hatchet_sdk.features.metrics import MetricsClient
|
||||
@@ -29,21 +24,13 @@ class Client:
|
||||
workflow_listener: PooledWorkflowRunListener | None | None = None,
|
||||
debug: bool = False,
|
||||
):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
conn: grpc.Channel = new_conn(config, False)
|
||||
|
||||
self.config = config
|
||||
self.admin = admin_client or AdminClient(config)
|
||||
self.dispatcher = dispatcher_client or DispatcherClient(config)
|
||||
self.event = event_client or new_event(conn, config)
|
||||
self.event = event_client or EventClient(config)
|
||||
self.listener = RunEventListenerClient(config)
|
||||
self.workflow_listener = workflow_listener
|
||||
self.logInterceptor = config.logger
|
||||
self.log_interceptor = config.logger
|
||||
self.debug = debug
|
||||
|
||||
self.cron = CronClient(self.config)
|
||||
|
||||
@@ -8,8 +8,6 @@ from google.protobuf import timestamp_pb2
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
|
||||
from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
|
||||
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
|
||||
from hatchet_sdk.config import ClientConfig
|
||||
from hatchet_sdk.connection import new_conn
|
||||
from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
|
||||
@@ -64,14 +62,11 @@ class AdminClient:
|
||||
def __init__(self, config: ClientConfig):
|
||||
conn = new_conn(config, False)
|
||||
self.config = config
|
||||
self.client = AdminServiceStub(conn) # type: ignore[no-untyped-call]
|
||||
self.v0_client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call]
|
||||
self.client = AdminServiceStub(conn)
|
||||
self.v0_client = WorkflowServiceStub(conn)
|
||||
self.token = config.token
|
||||
self.listener_client = RunEventListenerClient(config=config)
|
||||
self.namespace = config.namespace
|
||||
|
||||
self.pooled_workflow_listener: PooledWorkflowRunListener | None = None
|
||||
|
||||
class TriggerWorkflowRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@@ -307,9 +302,6 @@ class AdminClient:
|
||||
) -> WorkflowRunRef:
|
||||
request = self._create_workflow_run_request(workflow_name, input, options)
|
||||
|
||||
if not self.pooled_workflow_listener:
|
||||
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
|
||||
|
||||
try:
|
||||
resp = cast(
|
||||
v0_workflow_protos.TriggerWorkflowResponse,
|
||||
@@ -325,8 +317,7 @@ class AdminClient:
|
||||
|
||||
return WorkflowRunRef(
|
||||
workflow_run_id=resp.workflow_run_id,
|
||||
workflow_listener=self.pooled_workflow_listener,
|
||||
workflow_run_event_listener=self.listener_client,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
|
||||
@@ -343,9 +334,6 @@ class AdminClient:
|
||||
async with spawn_index_lock:
|
||||
request = self._create_workflow_run_request(workflow_name, input, options)
|
||||
|
||||
if not self.pooled_workflow_listener:
|
||||
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
|
||||
|
||||
try:
|
||||
resp = cast(
|
||||
v0_workflow_protos.TriggerWorkflowResponse,
|
||||
@@ -362,8 +350,7 @@ class AdminClient:
|
||||
|
||||
return WorkflowRunRef(
|
||||
workflow_run_id=resp.workflow_run_id,
|
||||
workflow_listener=self.pooled_workflow_listener,
|
||||
workflow_run_event_listener=self.listener_client,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
|
||||
@@ -372,9 +359,6 @@ class AdminClient:
|
||||
self,
|
||||
workflows: list[WorkflowRunTriggerConfig],
|
||||
) -> list[WorkflowRunRef]:
|
||||
if not self.pooled_workflow_listener:
|
||||
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
|
||||
|
||||
bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
|
||||
workflows=[
|
||||
self._create_workflow_run_request(
|
||||
@@ -395,8 +379,7 @@ class AdminClient:
|
||||
return [
|
||||
WorkflowRunRef(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_listener=self.pooled_workflow_listener,
|
||||
workflow_run_event_listener=self.listener_client,
|
||||
config=self.config,
|
||||
)
|
||||
for workflow_run_id in resp.workflow_run_ids
|
||||
]
|
||||
@@ -409,9 +392,6 @@ class AdminClient:
|
||||
## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
|
||||
## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
|
||||
## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
|
||||
if not self.pooled_workflow_listener:
|
||||
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
|
||||
|
||||
async with spawn_index_lock:
|
||||
bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
|
||||
workflows=[
|
||||
@@ -433,18 +413,13 @@ class AdminClient:
|
||||
return [
|
||||
WorkflowRunRef(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_listener=self.pooled_workflow_listener,
|
||||
workflow_run_event_listener=self.listener_client,
|
||||
config=self.config,
|
||||
)
|
||||
for workflow_run_id in resp.workflow_run_ids
|
||||
]
|
||||
|
||||
def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
|
||||
if not self.pooled_workflow_listener:
|
||||
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
|
||||
|
||||
return WorkflowRunRef(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_listener=self.pooled_workflow_listener,
|
||||
workflow_run_event_listener=self.listener_client,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class ActionListener:
|
||||
self.config = config
|
||||
self.worker_id = worker_id
|
||||
|
||||
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
|
||||
self.aio_client = DispatcherStub(new_conn(self.config, True))
|
||||
self.token = self.config.token
|
||||
|
||||
self.retries = 0
|
||||
@@ -232,14 +232,8 @@ class ActionListener:
|
||||
if self.heartbeat_task is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError as e:
|
||||
if str(e).startswith("There is no current event loop in thread"):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
else:
|
||||
raise e
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self.heartbeat_task = loop.create_task(self.heartbeat())
|
||||
|
||||
def __aiter__(self) -> AsyncGenerator[Action | None, None]:
|
||||
@@ -386,7 +380,7 @@ class ActionListener:
|
||||
f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
|
||||
)
|
||||
|
||||
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
|
||||
self.aio_client = DispatcherStub(new_conn(self.config, True))
|
||||
|
||||
if self.listen_strategy == "v2":
|
||||
# we should await for the listener to be established before
|
||||
|
||||
@@ -34,20 +34,23 @@ DEFAULT_REGISTER_TIMEOUT = 30
|
||||
|
||||
|
||||
class DispatcherClient:
|
||||
config: ClientConfig
|
||||
|
||||
def __init__(self, config: ClientConfig):
|
||||
conn = new_conn(config, False)
|
||||
self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
|
||||
self.client = DispatcherStub(conn)
|
||||
|
||||
aio_conn = new_conn(config, True)
|
||||
self.aio_client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
|
||||
self.token = config.token
|
||||
self.config = config
|
||||
|
||||
## IMPORTANT: This needs to be created lazily so we don't require
|
||||
## an event loop to instantiate the client.
|
||||
self.aio_client: DispatcherStub | None = None
|
||||
|
||||
async def get_action_listener(
|
||||
self, req: GetActionListenerRequest
|
||||
) -> ActionListener:
|
||||
if not self.aio_client:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.aio_client = DispatcherStub(aio_conn)
|
||||
|
||||
# Override labels with the preset labels
|
||||
preset_labels = self.config.worker_preset_labels
|
||||
@@ -95,6 +98,10 @@ class DispatcherClient:
|
||||
async def _try_send_step_action_event(
|
||||
self, action: Action, event_type: StepActionEventType, payload: str
|
||||
) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]:
|
||||
if not self.aio_client:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.aio_client = DispatcherStub(aio_conn)
|
||||
|
||||
event_timestamp = Timestamp()
|
||||
event_timestamp.GetCurrentTime()
|
||||
|
||||
@@ -122,6 +129,10 @@ class DispatcherClient:
|
||||
async def send_group_key_action_event(
|
||||
self, action: Action, event_type: GroupKeyActionEventType, payload: str
|
||||
) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]:
|
||||
if not self.aio_client:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.aio_client = DispatcherStub(aio_conn)
|
||||
|
||||
event_timestamp = Timestamp()
|
||||
event_timestamp.GetCurrentTime()
|
||||
|
||||
@@ -191,6 +202,10 @@ class DispatcherClient:
|
||||
worker_id: str | None,
|
||||
labels: dict[str, str | int],
|
||||
) -> None:
|
||||
if not self.aio_client:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.aio_client = DispatcherStub(aio_conn)
|
||||
|
||||
worker_labels = {}
|
||||
|
||||
for key, value in labels.items():
|
||||
|
||||
@@ -84,14 +84,6 @@ class RegisterDurableEventRequest(BaseModel):
|
||||
|
||||
class DurableEventListener:
|
||||
def __init__(self, config: ClientConfig):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
conn = new_conn(config, True)
|
||||
self.client = V1DispatcherStub(conn) # type: ignore[no-untyped-call]
|
||||
self.token = config.token
|
||||
self.config = config
|
||||
|
||||
@@ -129,11 +121,14 @@ class DurableEventListener:
|
||||
self.interrupt.set()
|
||||
|
||||
async def _init_producer(self) -> None:
|
||||
conn = new_conn(self.config, True)
|
||||
client = V1DispatcherStub(conn)
|
||||
|
||||
try:
|
||||
if not self.listener:
|
||||
while True:
|
||||
try:
|
||||
self.listener = await self._retry_subscribe()
|
||||
self.listener = await self._retry_subscribe(client)
|
||||
|
||||
logger.debug("Workflow run listener connected.")
|
||||
|
||||
@@ -282,6 +277,7 @@ class DurableEventListener:
|
||||
|
||||
async def _retry_subscribe(
|
||||
self,
|
||||
client: V1DispatcherStub,
|
||||
) -> grpc.aio.UnaryStreamCall[ListenForDurableEventRequest, DurableEvent]:
|
||||
retries = 0
|
||||
|
||||
@@ -298,8 +294,8 @@ class DurableEventListener:
|
||||
grpc.aio.UnaryStreamCall[
|
||||
ListenForDurableEventRequest, DurableEvent
|
||||
],
|
||||
self.client.ListenForDurableEvent(
|
||||
self._request(),
|
||||
client.ListenForDurableEvent(
|
||||
self._request(), # type: ignore[arg-type]
|
||||
metadata=get_metadata(self.token),
|
||||
),
|
||||
)
|
||||
@@ -315,7 +311,10 @@ class DurableEventListener:
|
||||
def register_durable_event(
|
||||
self, request: RegisterDurableEventRequest
|
||||
) -> Literal[True]:
|
||||
self.client.RegisterDurableEvent(
|
||||
conn = new_conn(self.config, True)
|
||||
client = V1DispatcherStub(conn)
|
||||
|
||||
client.RegisterDurableEvent(
|
||||
request.to_proto(),
|
||||
timeout=5,
|
||||
metadata=get_metadata(self.token),
|
||||
|
||||
@@ -3,15 +3,16 @@ import datetime
|
||||
import json
|
||||
from typing import List, cast
|
||||
|
||||
import grpc
|
||||
from google.protobuf import timestamp_pb2
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
|
||||
from hatchet_sdk.config import ClientConfig
|
||||
from hatchet_sdk.connection import new_conn
|
||||
from hatchet_sdk.contracts.events_pb2 import (
|
||||
BulkPushEventRequest,
|
||||
Event,
|
||||
Events,
|
||||
PushEventRequest,
|
||||
PutLogRequest,
|
||||
PutStreamEventRequest,
|
||||
@@ -21,13 +22,6 @@ from hatchet_sdk.metadata import get_metadata
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
def new_event(conn: grpc.Channel, config: ClientConfig) -> "EventClient":
|
||||
return EventClient(
|
||||
client=EventsServiceStub(conn), # type: ignore[no-untyped-call]
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def proto_timestamp_now() -> timestamp_pb2.Timestamp:
|
||||
t = datetime.datetime.now().timestamp()
|
||||
seconds = int(t)
|
||||
@@ -52,8 +46,10 @@ class BulkPushEventWithMetadata(BaseModel):
|
||||
|
||||
|
||||
class EventClient:
|
||||
def __init__(self, client: EventsServiceStub, config: ClientConfig):
|
||||
self.client = client
|
||||
def __init__(self, config: ClientConfig):
|
||||
conn = new_conn(config, False)
|
||||
self.client = EventsServiceStub(conn)
|
||||
|
||||
self.token = config.token
|
||||
self.namespace = config.namespace
|
||||
|
||||
@@ -146,11 +142,11 @@ class EventClient:
|
||||
]
|
||||
)
|
||||
|
||||
response = self.client.BulkPush(bulk_request, metadata=get_metadata(self.token))
|
||||
|
||||
return cast(
|
||||
list[Event],
|
||||
response.events,
|
||||
return list(
|
||||
cast(
|
||||
Events,
|
||||
self.client.BulkPush(bulk_request, metadata=get_metadata(self.token)),
|
||||
).events
|
||||
)
|
||||
|
||||
def log(self, message: str, step_run_id: str) -> None:
|
||||
|
||||
@@ -29,8 +29,10 @@ class TenantResource(str, Enum):
|
||||
allowed enum values
|
||||
"""
|
||||
WORKER = "WORKER"
|
||||
WORKER_SLOT = "WORKER_SLOT"
|
||||
EVENT = "EVENT"
|
||||
WORKFLOW_RUN = "WORKFLOW_RUN"
|
||||
TASK_RUN = "TASK_RUN"
|
||||
CRON = "CRON"
|
||||
SCHEDULE = "SCHEDULE"
|
||||
|
||||
|
||||
@@ -22,17 +22,13 @@ from typing import Any, ClassVar, Dict, List, Optional, Set
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import Self
|
||||
|
||||
from hatchet_sdk.clients.rest.models.workflow_runs_metrics_counts import (
|
||||
WorkflowRunsMetricsCounts,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunsMetrics(BaseModel):
|
||||
"""
|
||||
WorkflowRunsMetrics
|
||||
""" # noqa: E501
|
||||
|
||||
counts: Optional[WorkflowRunsMetricsCounts] = None
|
||||
counts: Optional[Dict[str, Any]] = None
|
||||
__properties: ClassVar[List[str]] = ["counts"]
|
||||
|
||||
model_config = ConfigDict(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, Callable, Generator, cast
|
||||
from queue import Empty, Queue
|
||||
from threading import Thread
|
||||
from typing import Any, AsyncGenerator, Callable, Generator, Literal, TypeVar, cast
|
||||
|
||||
import grpc
|
||||
from pydantic import BaseModel
|
||||
@@ -55,6 +57,8 @@ workflow_run_event_type_mapping = {
|
||||
ResourceEventType.RESOURCE_EVENT_TYPE_TIMED_OUT: WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_TIMED_OUT,
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class StepRunEvent(BaseModel):
|
||||
type: StepRunEventType
|
||||
@@ -64,18 +68,20 @@ class StepRunEvent(BaseModel):
|
||||
class RunEventListener:
|
||||
def __init__(
|
||||
self,
|
||||
client: DispatcherStub,
|
||||
token: str,
|
||||
config: ClientConfig,
|
||||
workflow_run_id: str | None = None,
|
||||
additional_meta_kv: tuple[str, str] | None = None,
|
||||
):
|
||||
self.client = client
|
||||
self.config = config
|
||||
self.stop_signal = False
|
||||
self.token = token
|
||||
|
||||
self.workflow_run_id = workflow_run_id
|
||||
self.additional_meta_kv = additional_meta_kv
|
||||
|
||||
## IMPORTANT: This needs to be created lazily so we don't require
|
||||
## an event loop to instantiate the client.
|
||||
self.client: DispatcherStub | None = None
|
||||
|
||||
def abort(self) -> None:
|
||||
self.stop_signal = True
|
||||
|
||||
@@ -85,27 +91,46 @@ class RunEventListener:
|
||||
async def __anext__(self) -> StepRunEvent:
|
||||
return await self._generator().__anext__()
|
||||
|
||||
def __iter__(self) -> Generator[StepRunEvent, None, None]:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError as e:
|
||||
if str(e).startswith("There is no current event loop in thread"):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
else:
|
||||
raise e
|
||||
def async_to_sync_thread(
|
||||
self, async_iter: AsyncGenerator[T, None]
|
||||
) -> Generator[T, None, None]:
|
||||
q = Queue[T | Literal["DONE"]]()
|
||||
done_sentinel: Literal["DONE"] = "DONE"
|
||||
|
||||
async_iter = self.__aiter__()
|
||||
def runner() -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def consume() -> None:
|
||||
try:
|
||||
async for item in async_iter:
|
||||
q.put(item)
|
||||
finally:
|
||||
q.put(done_sentinel)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(consume())
|
||||
finally:
|
||||
loop.stop()
|
||||
loop.close()
|
||||
|
||||
thread = Thread(target=runner)
|
||||
thread.start()
|
||||
|
||||
while True:
|
||||
try:
|
||||
future = asyncio.ensure_future(async_iter.__anext__())
|
||||
yield loop.run_until_complete(future)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error in synchronous iterator: {e}")
|
||||
break
|
||||
item = q.get(timeout=1)
|
||||
if item == "DONE":
|
||||
break
|
||||
yield item
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
thread.join()
|
||||
|
||||
def __iter__(self) -> Generator[StepRunEvent, None, None]:
|
||||
for item in self.async_to_sync_thread(self.__aiter__()):
|
||||
yield item
|
||||
|
||||
async def _generator(self) -> AsyncGenerator[StepRunEvent, None]:
|
||||
while True:
|
||||
@@ -172,6 +197,10 @@ class RunEventListener:
|
||||
async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
retries = 0
|
||||
|
||||
if self.client is None:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.client = DispatcherStub(aio_conn)
|
||||
|
||||
while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT:
|
||||
try:
|
||||
if retries > 0:
|
||||
@@ -184,7 +213,7 @@ class RunEventListener:
|
||||
SubscribeToWorkflowEventsRequest(
|
||||
workflowRunId=self.workflow_run_id,
|
||||
),
|
||||
metadata=get_metadata(self.token),
|
||||
metadata=get_metadata(self.config.token),
|
||||
),
|
||||
)
|
||||
elif self.additional_meta_kv is not None:
|
||||
@@ -195,7 +224,7 @@ class RunEventListener:
|
||||
additionalMetaKey=self.additional_meta_kv[0],
|
||||
additionalMetaValue=self.additional_meta_kv[1],
|
||||
),
|
||||
metadata=get_metadata(self.token),
|
||||
metadata=get_metadata(self.config.token),
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -212,30 +241,16 @@ class RunEventListener:
|
||||
|
||||
class RunEventListenerClient:
|
||||
def __init__(self, config: ClientConfig):
|
||||
self.token = config.token
|
||||
self.config = config
|
||||
self.client: DispatcherStub | None = None
|
||||
|
||||
def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener:
|
||||
return self.stream(workflow_run_id)
|
||||
|
||||
def stream(self, workflow_run_id: str) -> RunEventListener:
|
||||
if not self.client:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
|
||||
|
||||
return RunEventListener(
|
||||
client=self.client, token=self.token, workflow_run_id=workflow_run_id
|
||||
)
|
||||
return RunEventListener(config=self.config, workflow_run_id=workflow_run_id)
|
||||
|
||||
def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener:
|
||||
if not self.client:
|
||||
aio_conn = new_conn(self.config, True)
|
||||
self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
|
||||
|
||||
return RunEventListener(
|
||||
client=self.client, token=self.token, additional_meta_kv=(key, value)
|
||||
)
|
||||
return RunEventListener(config=self.config, additional_meta_kv=(key, value))
|
||||
|
||||
async def on(
|
||||
self, workflow_run_id: str, handler: Callable[[StepRunEvent], Any] | None = None
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar
|
||||
from typing import AsyncContextManager, ParamSpec, TypeVar
|
||||
|
||||
from hatchet_sdk.clients.rest.api_client import ApiClient
|
||||
from hatchet_sdk.clients.rest.configuration import Configuration
|
||||
@@ -44,38 +42,3 @@ class BaseRestClient:
|
||||
|
||||
def client(self) -> AsyncContextManager[ApiClient]:
|
||||
return ApiClient(self.api_config)
|
||||
|
||||
def _run_async_function_do_not_use_directly(
|
||||
self,
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def _run_async_from_sync(
|
||||
self,
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
else:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
lambda: self._run_async_function_do_not_use_directly(
|
||||
async_func, *args, **kwargs
|
||||
)
|
||||
)
|
||||
return future.result()
|
||||
|
||||
@@ -54,14 +54,6 @@ class _Subscription:
|
||||
|
||||
class PooledWorkflowRunListener:
|
||||
def __init__(self, config: ClientConfig):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
conn = new_conn(config, True)
|
||||
self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
|
||||
self.token = config.token
|
||||
self.config = config
|
||||
|
||||
@@ -91,6 +83,10 @@ class PooledWorkflowRunListener:
|
||||
|
||||
self.interrupter: asyncio.Task[None] | None = None
|
||||
|
||||
## IMPORTANT: This needs to be created lazily so we don't require
|
||||
## an event loop to instantiate the client.
|
||||
self.client: DispatcherStub | None = None
|
||||
|
||||
async def _interrupter(self) -> None:
|
||||
"""
|
||||
_interrupter runs in a separate thread and interrupts the listener according to a configurable duration.
|
||||
@@ -239,7 +235,7 @@ class PooledWorkflowRunListener:
|
||||
if subscription_id:
|
||||
self.cleanup_subscription(subscription_id)
|
||||
|
||||
async def result(self, workflow_run_id: str) -> dict[str, Any]:
|
||||
async def aio_result(self, workflow_run_id: str) -> dict[str, Any]:
|
||||
from hatchet_sdk.clients.admin import DedupeViolationErr
|
||||
|
||||
event = await self.subscribe(workflow_run_id)
|
||||
@@ -261,6 +257,9 @@ class PooledWorkflowRunListener:
|
||||
self,
|
||||
) -> grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent]:
|
||||
retries = 0
|
||||
if self.client is None:
|
||||
conn = new_conn(self.config, True)
|
||||
self.client = DispatcherStub(conn)
|
||||
|
||||
while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT:
|
||||
try:
|
||||
@@ -276,7 +275,7 @@ class PooledWorkflowRunListener:
|
||||
SubscribeToWorkflowRunsRequest, WorkflowRunEvent
|
||||
],
|
||||
self.client.SubscribeToWorkflowRuns(
|
||||
self._request(),
|
||||
self._request(), # type: ignore[arg-type]
|
||||
metadata=get_metadata(self.token),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ if _version_not_supported:
|
||||
class DispatcherStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -33,7 +33,7 @@ if _version_not_supported:
|
||||
class EventsServiceStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -33,7 +33,7 @@ if _version_not_supported:
|
||||
class V1DispatcherStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -34,7 +34,7 @@ class AdminServiceStub(object):
|
||||
"""AdminService represents a set of RPCs for admin management of tasks, workflows, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -34,7 +34,7 @@ class WorkflowServiceStub(object):
|
||||
"""WorkflowService represents a set of RPCs for managing workflows.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -18,6 +18,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -121,7 +122,7 @@ class CronClient(BaseRestClient):
|
||||
input: JSONSerializableMapping,
|
||||
additional_metadata: JSONSerializableMapping,
|
||||
) -> CronWorkflows:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_create,
|
||||
workflow_name,
|
||||
cron_name,
|
||||
@@ -143,7 +144,7 @@ class CronClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def delete(self, cron_id: str) -> None:
|
||||
return self._run_async_from_sync(self.aio_delete, cron_id)
|
||||
return run_async_from_sync(self.aio_delete, cron_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -204,7 +205,7 @@ class CronClient(BaseRestClient):
|
||||
Returns:
|
||||
CronWorkflowsList: A list of cron workflows.
|
||||
"""
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_list,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -239,4 +240,4 @@ class CronClient(BaseRestClient):
|
||||
Returns:
|
||||
CronWorkflows: The requested cron workflow instance.
|
||||
"""
|
||||
return self._run_async_from_sync(self.aio_get, cron_id)
|
||||
return run_async_from_sync(self.aio_get, cron_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ from hatchet_sdk.clients.rest.api.log_api import LogApi
|
||||
from hatchet_sdk.clients.rest.api_client import ApiClient
|
||||
from hatchet_sdk.clients.rest.models.v1_log_line_list import V1LogLineList
|
||||
from hatchet_sdk.clients.v1.api_client import BaseRestClient
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
|
||||
|
||||
class LogsClient(BaseRestClient):
|
||||
@@ -13,4 +14,4 @@ class LogsClient(BaseRestClient):
|
||||
return await self._la(client).v1_log_line_list(task=task_run_id)
|
||||
|
||||
def list(self, task_run_id: str) -> V1LogLineList:
|
||||
return self._run_async_from_sync(self.aio_list, task_run_id)
|
||||
return run_async_from_sync(self.aio_list, task_run_id)
|
||||
|
||||
@@ -11,6 +11,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ class MetricsClient(BaseRestClient):
|
||||
status: WorkflowRunStatus | None = None,
|
||||
group_key: str | None = None,
|
||||
) -> WorkflowMetrics:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_get_workflow_metrics, workflow_id, status, group_key
|
||||
)
|
||||
|
||||
@@ -61,7 +62,7 @@ class MetricsClient(BaseRestClient):
|
||||
workflow_ids: list[str] | None = None,
|
||||
additional_metadata: JSONSerializableMapping | None = None,
|
||||
) -> TenantQueueMetrics:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_get_queue_metrics, workflow_ids, additional_metadata
|
||||
)
|
||||
|
||||
@@ -72,4 +73,4 @@ class MetricsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def get_task_metrics(self) -> TenantStepRunQueueMetrics:
|
||||
return self._run_async_from_sync(self.aio_get_task_metrics)
|
||||
return run_async_from_sync(self.aio_get_task_metrics)
|
||||
|
||||
@@ -24,7 +24,7 @@ class RateLimitsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
conn = new_conn(self.client_config, False)
|
||||
client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call]
|
||||
client = WorkflowServiceStub(conn)
|
||||
|
||||
client.PutRateLimit(
|
||||
v0_workflow_protos.PutRateLimitRequest(
|
||||
|
||||
@@ -19,6 +19,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -93,7 +94,7 @@ class RunsClient(BaseRestClient):
|
||||
return await self._wra(client).v1_workflow_run_get(str(workflow_run_id))
|
||||
|
||||
def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
|
||||
return self._run_async_from_sync(self.aio_get, workflow_run_id)
|
||||
return run_async_from_sync(self.aio_get, workflow_run_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -138,7 +139,7 @@ class RunsClient(BaseRestClient):
|
||||
worker_id: str | None = None,
|
||||
parent_task_external_id: str | None = None,
|
||||
) -> V1TaskSummaryList:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_list,
|
||||
since=since,
|
||||
only_tasks=only_tasks,
|
||||
@@ -174,7 +175,7 @@ class RunsClient(BaseRestClient):
|
||||
input: JSONSerializableMapping,
|
||||
additional_metadata: JSONSerializableMapping = {},
|
||||
) -> V1WorkflowRunDetails:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_create, workflow_name, input, additional_metadata
|
||||
)
|
||||
|
||||
@@ -182,7 +183,7 @@ class RunsClient(BaseRestClient):
|
||||
await self.aio_bulk_replay(opts=BulkCancelReplayOpts(ids=[run_id]))
|
||||
|
||||
def replay(self, run_id: str) -> None:
|
||||
return self._run_async_from_sync(self.aio_replay, run_id)
|
||||
return run_async_from_sync(self.aio_replay, run_id)
|
||||
|
||||
async def aio_bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
|
||||
async with self.client() as client:
|
||||
@@ -192,13 +193,13 @@ class RunsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
|
||||
return self._run_async_from_sync(self.aio_bulk_replay, opts)
|
||||
return run_async_from_sync(self.aio_bulk_replay, opts)
|
||||
|
||||
async def aio_cancel(self, run_id: str) -> None:
|
||||
await self.aio_bulk_cancel(opts=BulkCancelReplayOpts(ids=[run_id]))
|
||||
|
||||
def cancel(self, run_id: str) -> None:
|
||||
return self._run_async_from_sync(self.aio_cancel, run_id)
|
||||
return run_async_from_sync(self.aio_cancel, run_id)
|
||||
|
||||
async def aio_bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
|
||||
async with self.client() as client:
|
||||
@@ -208,7 +209,7 @@ class RunsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
|
||||
return self._run_async_from_sync(self.aio_bulk_cancel, opts)
|
||||
return run_async_from_sync(self.aio_bulk_cancel, opts)
|
||||
|
||||
async def aio_get_result(self, run_id: str) -> JSONSerializableMapping:
|
||||
details = await self.aio_get(run_id)
|
||||
|
||||
@@ -22,6 +22,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -82,7 +83,7 @@ class ScheduledClient(BaseRestClient):
|
||||
ScheduledWorkflows: The created scheduled workflow instance.
|
||||
"""
|
||||
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_create,
|
||||
workflow_name,
|
||||
trigger_at,
|
||||
@@ -104,7 +105,7 @@ class ScheduledClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def delete(self, scheduled_id: str) -> None:
|
||||
self._run_async_from_sync(self.aio_delete, scheduled_id)
|
||||
run_async_from_sync(self.aio_delete, scheduled_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -175,7 +176,7 @@ class ScheduledClient(BaseRestClient):
|
||||
Returns:
|
||||
List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria.
|
||||
"""
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_list,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -214,4 +215,4 @@ class ScheduledClient(BaseRestClient):
|
||||
Returns:
|
||||
ScheduledWorkflows: The requested scheduled workflow instance.
|
||||
"""
|
||||
return self._run_async_from_sync(self.aio_get, scheduled_id)
|
||||
return run_async_from_sync(self.aio_get, scheduled_id)
|
||||
|
||||
@@ -4,6 +4,7 @@ from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRe
|
||||
from hatchet_sdk.clients.rest.models.worker import Worker
|
||||
from hatchet_sdk.clients.rest.models.worker_list import WorkerList
|
||||
from hatchet_sdk.clients.v1.api_client import BaseRestClient
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
|
||||
|
||||
class WorkersClient(BaseRestClient):
|
||||
@@ -15,7 +16,7 @@ class WorkersClient(BaseRestClient):
|
||||
return await self._wa(client).worker_get(worker_id)
|
||||
|
||||
def get(self, worker_id: str) -> Worker:
|
||||
return self._run_async_from_sync(self.aio_get, worker_id)
|
||||
return run_async_from_sync(self.aio_get, worker_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -28,7 +29,7 @@ class WorkersClient(BaseRestClient):
|
||||
def list(
|
||||
self,
|
||||
) -> WorkerList:
|
||||
return self._run_async_from_sync(self.aio_list)
|
||||
return run_async_from_sync(self.aio_list)
|
||||
|
||||
async def aio_update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
|
||||
async with self.client() as client:
|
||||
@@ -38,4 +39,4 @@ class WorkersClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
|
||||
return self._run_async_from_sync(self.aio_update, worker_id, opts)
|
||||
return run_async_from_sync(self.aio_update, worker_id, opts)
|
||||
|
||||
@@ -5,6 +5,7 @@ from hatchet_sdk.clients.rest.models.workflow import Workflow
|
||||
from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList
|
||||
from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion
|
||||
from hatchet_sdk.clients.v1.api_client import BaseRestClient
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
|
||||
|
||||
class WorkflowsClient(BaseRestClient):
|
||||
@@ -19,7 +20,7 @@ class WorkflowsClient(BaseRestClient):
|
||||
return await self._wa(client).workflow_get(workflow_id)
|
||||
|
||||
def get(self, workflow_id: str) -> Workflow:
|
||||
return self._run_async_from_sync(self.aio_get, workflow_id)
|
||||
return run_async_from_sync(self.aio_get, workflow_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -41,7 +42,7 @@ class WorkflowsClient(BaseRestClient):
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> WorkflowList:
|
||||
return self._run_async_from_sync(self.aio_list, workflow_name, limit, offset)
|
||||
return run_async_from_sync(self.aio_list, workflow_name, limit, offset)
|
||||
|
||||
async def aio_get_version(
|
||||
self, workflow_id: str, version: str | None = None
|
||||
@@ -52,4 +53,4 @@ class WorkflowsClient(BaseRestClient):
|
||||
def get_version(
|
||||
self, workflow_id: str, version: str | None = None
|
||||
) -> WorkflowVersion:
|
||||
return self._run_async_from_sync(self.aio_get_version, workflow_id, version)
|
||||
return run_async_from_sync(self.aio_get_version, workflow_id, version)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
def get_metadata(token: str) -> list[tuple[str, str]]:
|
||||
return [("authorization", "bearer " + token)]
|
||||
def get_metadata(token: str) -> tuple[tuple[str, str]]:
|
||||
return (("authorization", "bearer " + token),)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Generic, cast, get_type_hints
|
||||
|
||||
@@ -12,7 +11,7 @@ from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
|
||||
from hatchet_sdk.runnables.task import Task
|
||||
from hatchet_sdk.runnables.types import EmptyModel, R, TWorkflowInput
|
||||
from hatchet_sdk.runnables.workflow import BaseWorkflow, Workflow
|
||||
from hatchet_sdk.utils.aio_utils import get_active_event_loop
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping, is_basemodel_subclass
|
||||
from hatchet_sdk.workflow_run import WorkflowRunRef
|
||||
|
||||
@@ -27,25 +26,11 @@ class TaskRunRef(Generic[TWorkflowInput, R]):
|
||||
self._wrr = workflow_run_ref
|
||||
|
||||
async def aio_result(self) -> R:
|
||||
result = await self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
|
||||
result = await self._wrr.workflow_listener.aio_result(self._wrr.workflow_run_id)
|
||||
return self._s._extract_result(result)
|
||||
|
||||
def result(self) -> R:
|
||||
coro = self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
|
||||
|
||||
loop = get_active_event_loop()
|
||||
|
||||
if loop is None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(coro)
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
else:
|
||||
result = loop.run_until_complete(coro)
|
||||
|
||||
return self._s._extract_result(result)
|
||||
return run_async_from_sync(self.aio_result)
|
||||
|
||||
|
||||
class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Coroutine, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
Y = TypeVar("Y")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
def _run_async_function_do_not_use_directly(
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_async_from_sync(
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
else:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
lambda: _run_async_function_do_not_use_directly(
|
||||
async_func, *args, **kwargs
|
||||
)
|
||||
)
|
||||
return future.result()
|
||||
@@ -1,18 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
|
||||
"""
|
||||
Get the active event loop.
|
||||
|
||||
Returns:
|
||||
asyncio.AbstractEventLoop: The active event loop, or None if there is no active
|
||||
event loop in the current thread.
|
||||
"""
|
||||
try:
|
||||
return asyncio.get_event_loop()
|
||||
except RuntimeError as e:
|
||||
if str(e).startswith("There is no current event loop in thread"):
|
||||
return None
|
||||
else:
|
||||
raise e
|
||||
@@ -60,7 +60,7 @@ class WorkerActionRunLoopManager:
|
||||
|
||||
async def aio_start(self, retry_count: int = 1) -> None:
|
||||
await capture_logs(
|
||||
self.client.logInterceptor,
|
||||
self.client.log_interceptor,
|
||||
self.client.event,
|
||||
self._async_start,
|
||||
)(retry_count=retry_count)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from hatchet_sdk.clients.run_event_listener import (
|
||||
@@ -6,19 +5,19 @@ from hatchet_sdk.clients.run_event_listener import (
|
||||
RunEventListenerClient,
|
||||
)
|
||||
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
|
||||
from hatchet_sdk.utils.aio_utils import get_active_event_loop
|
||||
from hatchet_sdk.config import ClientConfig
|
||||
from hatchet_sdk.utils.aio import run_async_from_sync
|
||||
|
||||
|
||||
class WorkflowRunRef:
|
||||
def __init__(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
workflow_listener: PooledWorkflowRunListener,
|
||||
workflow_run_event_listener: RunEventListenerClient,
|
||||
config: ClientConfig,
|
||||
):
|
||||
self.workflow_run_id = workflow_run_id
|
||||
self.workflow_listener = workflow_listener
|
||||
self.workflow_run_event_listener = workflow_run_event_listener
|
||||
self.workflow_listener = PooledWorkflowRunListener(config)
|
||||
self.workflow_run_event_listener = RunEventListenerClient(config=config)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.workflow_run_id
|
||||
@@ -27,19 +26,7 @@ class WorkflowRunRef:
|
||||
return self.workflow_run_event_listener.stream(self.workflow_run_id)
|
||||
|
||||
async def aio_result(self) -> dict[str, Any]:
|
||||
return await self.workflow_listener.result(self.workflow_run_id)
|
||||
return await self.workflow_listener.aio_result(self.workflow_run_id)
|
||||
|
||||
def result(self) -> dict[str, Any]:
|
||||
coro = self.workflow_listener.result(self.workflow_run_id)
|
||||
|
||||
loop = get_active_event_loop()
|
||||
|
||||
if loop is None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
else:
|
||||
return loop.run_until_complete(coro)
|
||||
return run_async_from_sync(self.aio_result)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "hatchet-sdk"
|
||||
version = "1.0.3"
|
||||
version = "1.1.0"
|
||||
description = ""
|
||||
authors = ["Alexander Belanger <alexander@hatchet.run>"]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -6,7 +6,7 @@ from examples.dag.worker import dag_workflow
|
||||
from hatchet_sdk import Hatchet
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_runs(hatchet: Hatchet) -> None:
|
||||
dag_result = await dag_workflow.aio_run()
|
||||
|
||||
@@ -19,7 +19,7 @@ async def test_list_runs(hatchet: Hatchet) -> None:
|
||||
assert v in [r.output for r in runs.rows]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_get_run(hatchet: Hatchet) -> None:
|
||||
dag_ref = await dag_workflow.aio_run_no_wait()
|
||||
|
||||
@@ -33,7 +33,7 @@ async def test_get_run(hatchet: Hatchet) -> None:
|
||||
assert {t.name for t in dag_workflow.tasks} == {t.task_name for t in run.shape}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_workflows(hatchet: Hatchet) -> None:
|
||||
workflows = await hatchet.workflows.aio_list(
|
||||
workflow_name=dag_workflow.config.name, limit=1, offset=0
|
||||
|
||||
Reference in New Issue
Block a user