[Python] Fix: Remove global event loop setters everywhere (#1452)

* feat: factor out async to sync

* feat: remove global event loop setter in workflow listener

* fix: rm more global loop sets

* fix: more loop sets

* fix: more

* fix: more

* fix: rm one more

* fix: stream from thread

* fix: dispatcher

* fix: make tests have independent loop scopes (woohoo!)

* feat: use default loop scope

* fix: try adding back tests

* Revert "fix: try adding back tests"

This reverts commit bed34a9bae539650e4fe32e0518aa9d1c5c0af5c.

* fix: rm dead code

* fix: remove redundant `_utils`

* fix: add typing to client stubs + regenerate

* fix: create more clients lazily

* fix: print cruft

* chore: version

* fix: lint
This commit is contained in:
Matt Kaye
2025-03-31 13:58:50 -04:00
committed by GitHub
parent 8172d59f84
commit 46edb1f0b0
47 changed files with 242 additions and 294 deletions
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
result = await bulk_parent_wf.aio_run(input=ParentInput(n=12))
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
with pytest.raises(Exception, match="(Task exceeded timeout|TIMED_OUT)"):
await wf.aio_run()
@@ -6,7 +6,7 @@ from hatchet_sdk.workflow_run import WorkflowRunRef
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
@pytest.mark.skip(reason="The timing for this test is not reliable")
async def test_run(hatchet: Hatchet) -> None:
num_runs = 6
@@ -9,7 +9,7 @@ from hatchet_sdk.workflow_run import WorkflowRunRef
# requires scope module or higher for shared event loop
@pytest.mark.skip(reason="The timing for this test is not reliable")
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
num_groups = 2
runs: list[WorkflowRunRef] = []
+1 -1
View File
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
result = await dag_workflow.aio_run()
+1 -1
View File
@@ -7,6 +7,6 @@
# worker = fixture_bg_worker(["poetry", "run", "manual_trigger"])
# # requires scope module or higher for shared event loop
# @pytest.mark.asyncio(loop_scope="session")
# @pytest.mark.asyncio()
# async def test_run(hatchet: Hatchet):
# # TODO
+1 -1
View File
@@ -11,7 +11,7 @@ from hatchet_sdk import Hatchet
os.getenv("CI", "false").lower() == "true",
reason="Skipped in CI because of unreliability",
)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_durable(hatchet: Hatchet) -> None:
ref = durable_workflow.run_no_wait()
+3 -3
View File
@@ -5,21 +5,21 @@ from hatchet_sdk.hatchet import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_event_push(hatchet: Hatchet) -> None:
e = hatchet.event.push("user:create", {"test": "test"})
assert e.eventId is not None
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_async_event_push(aiohatchet: Hatchet) -> None:
e = await aiohatchet.event.aio_push("user:create", {"test": "test"})
assert e.eventId is not None
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_async_event_bulk_push(aiohatchet: Hatchet) -> None:
events = [
+1 -1
View File
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
result = await parent_wf.aio_run(ParentInput(n=2))
+1 -1
View File
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
result = await logging_workflow.aio_run()
@@ -8,7 +8,7 @@ from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run_timeout(aiohatchet: Hatchet, worker: Worker) -> None:
run = on_failure_wf.run_no_wait()
try:
@@ -9,7 +9,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.skip(reason="The timing for this test is not reliable")
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run(hatchet: Hatchet) -> None:
run1 = rate_limit_workflow.run_no_wait()
+2 -2
View File
@@ -5,7 +5,7 @@ from hatchet_sdk import Hatchet
# requires scope module or higher for shared event loop
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_execution_timeout(hatchet: Hatchet) -> None:
run = timeout_wf.run_no_wait()
@@ -13,7 +13,7 @@ async def test_execution_timeout(hatchet: Hatchet) -> None:
await run.aio_result()
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_run_refresh_timeout(hatchet: Hatchet) -> None:
result = await refresh_timeout_wf.aio_run()
+1 -1
View File
@@ -11,7 +11,7 @@ from hatchet_sdk import Hatchet
os.getenv("CI", "false").lower() == "true",
reason="Skipped in CI because of unreliability",
)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_waits(hatchet: Hatchet) -> None:
ref = task_condition_workflow.run_no_wait()
+3
View File
@@ -105,6 +105,9 @@ find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i
find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i '' 's/import events_pb2 as events__pb2/from hatchet_sdk.contracts import events_pb2 as events__pb2/g'
find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i '' 's/import workflows_pb2 as workflows__pb2/from hatchet_sdk.contracts import workflows_pb2 as workflows__pb2/g'
find ./hatchet_sdk/contracts -type f -name '*_grpc.py' -print0 | xargs -0 sed -i '' 's/def __init__(self, channel):/def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:/g'
# ensure that pre-commit is applied without errors
./lint.sh
+3 -16
View File
@@ -1,14 +1,9 @@
import asyncio
import grpc
from hatchet_sdk.clients.admin import AdminClient
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
from hatchet_sdk.clients.events import EventClient, new_event
from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.connection import new_conn
from hatchet_sdk.features.cron import CronClient
from hatchet_sdk.features.logs import LogsClient
from hatchet_sdk.features.metrics import MetricsClient
@@ -29,21 +24,13 @@ class Client:
workflow_listener: PooledWorkflowRunListener | None = None,
debug: bool = False,
):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
conn: grpc.Channel = new_conn(config, False)
self.config = config
self.admin = admin_client or AdminClient(config)
self.dispatcher = dispatcher_client or DispatcherClient(config)
self.event = event_client or new_event(conn, config)
self.event = event_client or EventClient(config)
self.listener = RunEventListenerClient(config)
self.workflow_listener = workflow_listener
self.logInterceptor = config.logger
self.log_interceptor = config.logger
self.debug = debug
self.cron = CronClient(self.config)
+7 -32
View File
@@ -8,8 +8,6 @@ from google.protobuf import timestamp_pb2
from pydantic import BaseModel, ConfigDict, Field, field_validator
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
@@ -64,14 +62,11 @@ class AdminClient:
def __init__(self, config: ClientConfig):
conn = new_conn(config, False)
self.config = config
self.client = AdminServiceStub(conn) # type: ignore[no-untyped-call]
self.v0_client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call]
self.client = AdminServiceStub(conn)
self.v0_client = WorkflowServiceStub(conn)
self.token = config.token
self.listener_client = RunEventListenerClient(config=config)
self.namespace = config.namespace
self.pooled_workflow_listener: PooledWorkflowRunListener | None = None
class TriggerWorkflowRequest(BaseModel):
model_config = ConfigDict(extra="ignore")
@@ -307,9 +302,6 @@ class AdminClient:
) -> WorkflowRunRef:
request = self._create_workflow_run_request(workflow_name, input, options)
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
try:
resp = cast(
v0_workflow_protos.TriggerWorkflowResponse,
@@ -325,8 +317,7 @@ class AdminClient:
return WorkflowRunRef(
workflow_run_id=resp.workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
config=self.config,
)
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -343,9 +334,6 @@ class AdminClient:
async with spawn_index_lock:
request = self._create_workflow_run_request(workflow_name, input, options)
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
try:
resp = cast(
v0_workflow_protos.TriggerWorkflowResponse,
@@ -362,8 +350,7 @@ class AdminClient:
return WorkflowRunRef(
workflow_run_id=resp.workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
config=self.config,
)
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -372,9 +359,6 @@ class AdminClient:
self,
workflows: list[WorkflowRunTriggerConfig],
) -> list[WorkflowRunRef]:
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
workflows=[
self._create_workflow_run_request(
@@ -395,8 +379,7 @@ class AdminClient:
return [
WorkflowRunRef(
workflow_run_id=workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
config=self.config,
)
for workflow_run_id in resp.workflow_run_ids
]
@@ -409,9 +392,6 @@ class AdminClient:
## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
async with spawn_index_lock:
bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
workflows=[
@@ -433,18 +413,13 @@ class AdminClient:
return [
WorkflowRunRef(
workflow_run_id=workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
config=self.config,
)
for workflow_run_id in resp.workflow_run_ids
]
def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
return WorkflowRunRef(
workflow_run_id=workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
config=self.config,
)
@@ -152,7 +152,7 @@ class ActionListener:
self.config = config
self.worker_id = worker_id
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
self.aio_client = DispatcherStub(new_conn(self.config, True))
self.token = self.config.token
self.retries = 0
@@ -232,14 +232,8 @@ class ActionListener:
if self.heartbeat_task is not None:
return
try:
loop = asyncio.get_event_loop()
except RuntimeError as e:
if str(e).startswith("There is no current event loop in thread"):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
else:
raise e
loop = asyncio.get_event_loop()
self.heartbeat_task = loop.create_task(self.heartbeat())
def __aiter__(self) -> AsyncGenerator[Action | None, None]:
@@ -386,7 +380,7 @@ class ActionListener:
f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
)
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
self.aio_client = DispatcherStub(new_conn(self.config, True))
if self.listen_strategy == "v2":
# we should await for the listener to be established before
@@ -34,20 +34,23 @@ DEFAULT_REGISTER_TIMEOUT = 30
class DispatcherClient:
config: ClientConfig
def __init__(self, config: ClientConfig):
conn = new_conn(config, False)
self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
self.client = DispatcherStub(conn)
aio_conn = new_conn(config, True)
self.aio_client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
self.token = config.token
self.config = config
## IMPORTANT: This needs to be created lazily so we don't require
## an event loop to instantiate the client.
self.aio_client: DispatcherStub | None = None
async def get_action_listener(
self, req: GetActionListenerRequest
) -> ActionListener:
if not self.aio_client:
aio_conn = new_conn(self.config, True)
self.aio_client = DispatcherStub(aio_conn)
# Override labels with the preset labels
preset_labels = self.config.worker_preset_labels
@@ -95,6 +98,10 @@ class DispatcherClient:
async def _try_send_step_action_event(
self, action: Action, event_type: StepActionEventType, payload: str
) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]:
if not self.aio_client:
aio_conn = new_conn(self.config, True)
self.aio_client = DispatcherStub(aio_conn)
event_timestamp = Timestamp()
event_timestamp.GetCurrentTime()
@@ -122,6 +129,10 @@ class DispatcherClient:
async def send_group_key_action_event(
self, action: Action, event_type: GroupKeyActionEventType, payload: str
) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]:
if not self.aio_client:
aio_conn = new_conn(self.config, True)
self.aio_client = DispatcherStub(aio_conn)
event_timestamp = Timestamp()
event_timestamp.GetCurrentTime()
@@ -191,6 +202,10 @@ class DispatcherClient:
worker_id: str | None,
labels: dict[str, str | int],
) -> None:
if not self.aio_client:
aio_conn = new_conn(self.config, True)
self.aio_client = DispatcherStub(aio_conn)
worker_labels = {}
for key, value in labels.items():
@@ -84,14 +84,6 @@ class RegisterDurableEventRequest(BaseModel):
class DurableEventListener:
def __init__(self, config: ClientConfig):
try:
asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
conn = new_conn(config, True)
self.client = V1DispatcherStub(conn) # type: ignore[no-untyped-call]
self.token = config.token
self.config = config
@@ -129,11 +121,14 @@ class DurableEventListener:
self.interrupt.set()
async def _init_producer(self) -> None:
conn = new_conn(self.config, True)
client = V1DispatcherStub(conn)
try:
if not self.listener:
while True:
try:
self.listener = await self._retry_subscribe()
self.listener = await self._retry_subscribe(client)
logger.debug("Workflow run listener connected.")
@@ -282,6 +277,7 @@ class DurableEventListener:
async def _retry_subscribe(
self,
client: V1DispatcherStub,
) -> grpc.aio.UnaryStreamCall[ListenForDurableEventRequest, DurableEvent]:
retries = 0
@@ -298,8 +294,8 @@ class DurableEventListener:
grpc.aio.UnaryStreamCall[
ListenForDurableEventRequest, DurableEvent
],
self.client.ListenForDurableEvent(
self._request(),
client.ListenForDurableEvent(
self._request(), # type: ignore[arg-type]
metadata=get_metadata(self.token),
),
)
@@ -315,7 +311,10 @@ class DurableEventListener:
def register_durable_event(
self, request: RegisterDurableEventRequest
) -> Literal[True]:
self.client.RegisterDurableEvent(
conn = new_conn(self.config, True)
client = V1DispatcherStub(conn)
client.RegisterDurableEvent(
request.to_proto(),
timeout=5,
metadata=get_metadata(self.token),
+11 -15
View File
@@ -3,15 +3,16 @@ import datetime
import json
from typing import List, cast
import grpc
from google.protobuf import timestamp_pb2
from pydantic import BaseModel, Field
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.events_pb2 import (
BulkPushEventRequest,
Event,
Events,
PushEventRequest,
PutLogRequest,
PutStreamEventRequest,
@@ -21,13 +22,6 @@ from hatchet_sdk.metadata import get_metadata
from hatchet_sdk.utils.typing import JSONSerializableMapping
def new_event(conn: grpc.Channel, config: ClientConfig) -> "EventClient":
return EventClient(
client=EventsServiceStub(conn), # type: ignore[no-untyped-call]
config=config,
)
def proto_timestamp_now() -> timestamp_pb2.Timestamp:
t = datetime.datetime.now().timestamp()
seconds = int(t)
@@ -52,8 +46,10 @@ class BulkPushEventWithMetadata(BaseModel):
class EventClient:
def __init__(self, client: EventsServiceStub, config: ClientConfig):
self.client = client
def __init__(self, config: ClientConfig):
conn = new_conn(config, False)
self.client = EventsServiceStub(conn)
self.token = config.token
self.namespace = config.namespace
@@ -146,11 +142,11 @@ class EventClient:
]
)
response = self.client.BulkPush(bulk_request, metadata=get_metadata(self.token))
return cast(
list[Event],
response.events,
return list(
cast(
Events,
self.client.BulkPush(bulk_request, metadata=get_metadata(self.token)),
).events
)
def log(self, message: str, step_run_id: str) -> None:
@@ -29,8 +29,10 @@ class TenantResource(str, Enum):
allowed enum values
"""
WORKER = "WORKER"
WORKER_SLOT = "WORKER_SLOT"
EVENT = "EVENT"
WORKFLOW_RUN = "WORKFLOW_RUN"
TASK_RUN = "TASK_RUN"
CRON = "CRON"
SCHEDULE = "SCHEDULE"
@@ -22,17 +22,13 @@ from typing import Any, ClassVar, Dict, List, Optional, Set
from pydantic import BaseModel, ConfigDict
from typing_extensions import Self
from hatchet_sdk.clients.rest.models.workflow_runs_metrics_counts import (
WorkflowRunsMetricsCounts,
)
class WorkflowRunsMetrics(BaseModel):
"""
WorkflowRunsMetrics
""" # noqa: E501
counts: Optional[WorkflowRunsMetricsCounts] = None
counts: Optional[Dict[str, Any]] = None
__properties: ClassVar[List[str]] = ["counts"]
model_config = ConfigDict(
@@ -1,6 +1,8 @@
import asyncio
from enum import Enum
from typing import Any, AsyncGenerator, Callable, Generator, cast
from queue import Empty, Queue
from threading import Thread
from typing import Any, AsyncGenerator, Callable, Generator, Literal, TypeVar, cast
import grpc
from pydantic import BaseModel
@@ -55,6 +57,8 @@ workflow_run_event_type_mapping = {
ResourceEventType.RESOURCE_EVENT_TYPE_TIMED_OUT: WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_TIMED_OUT,
}
T = TypeVar("T")
class StepRunEvent(BaseModel):
type: StepRunEventType
@@ -64,18 +68,20 @@ class StepRunEvent(BaseModel):
class RunEventListener:
def __init__(
self,
client: DispatcherStub,
token: str,
config: ClientConfig,
workflow_run_id: str | None = None,
additional_meta_kv: tuple[str, str] | None = None,
):
self.client = client
self.config = config
self.stop_signal = False
self.token = token
self.workflow_run_id = workflow_run_id
self.additional_meta_kv = additional_meta_kv
## IMPORTANT: This needs to be created lazily so we don't require
## an event loop to instantiate the client.
self.client: DispatcherStub | None = None
def abort(self) -> None:
self.stop_signal = True
@@ -85,27 +91,46 @@ class RunEventListener:
async def __anext__(self) -> StepRunEvent:
return await self._generator().__anext__()
def __iter__(self) -> Generator[StepRunEvent, None, None]:
try:
loop = asyncio.get_event_loop()
except RuntimeError as e:
if str(e).startswith("There is no current event loop in thread"):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
else:
raise e
def async_to_sync_thread(
self, async_iter: AsyncGenerator[T, None]
) -> Generator[T, None, None]:
q = Queue[T | Literal["DONE"]]()
done_sentinel: Literal["DONE"] = "DONE"
async_iter = self.__aiter__()
def runner() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def consume() -> None:
try:
async for item in async_iter:
q.put(item)
finally:
q.put(done_sentinel)
try:
loop.run_until_complete(consume())
finally:
loop.stop()
loop.close()
thread = Thread(target=runner)
thread.start()
while True:
try:
future = asyncio.ensure_future(async_iter.__anext__())
yield loop.run_until_complete(future)
except StopAsyncIteration:
break
except Exception as e:
print(f"Error in synchronous iterator: {e}")
break
item = q.get(timeout=1)
if item == "DONE":
break
yield item
except Empty:
continue
thread.join()
def __iter__(self) -> Generator[StepRunEvent, None, None]:
for item in self.async_to_sync_thread(self.__aiter__()):
yield item
async def _generator(self) -> AsyncGenerator[StepRunEvent, None]:
while True:
@@ -172,6 +197,10 @@ class RunEventListener:
async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]:
retries = 0
if self.client is None:
aio_conn = new_conn(self.config, True)
self.client = DispatcherStub(aio_conn)
while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT:
try:
if retries > 0:
@@ -184,7 +213,7 @@ class RunEventListener:
SubscribeToWorkflowEventsRequest(
workflowRunId=self.workflow_run_id,
),
metadata=get_metadata(self.token),
metadata=get_metadata(self.config.token),
),
)
elif self.additional_meta_kv is not None:
@@ -195,7 +224,7 @@ class RunEventListener:
additionalMetaKey=self.additional_meta_kv[0],
additionalMetaValue=self.additional_meta_kv[1],
),
metadata=get_metadata(self.token),
metadata=get_metadata(self.config.token),
),
)
else:
@@ -212,30 +241,16 @@ class RunEventListener:
class RunEventListenerClient:
def __init__(self, config: ClientConfig):
self.token = config.token
self.config = config
self.client: DispatcherStub | None = None
def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener:
return self.stream(workflow_run_id)
def stream(self, workflow_run_id: str) -> RunEventListener:
if not self.client:
aio_conn = new_conn(self.config, True)
self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
return RunEventListener(
client=self.client, token=self.token, workflow_run_id=workflow_run_id
)
return RunEventListener(config=self.config, workflow_run_id=workflow_run_id)
def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener:
if not self.client:
aio_conn = new_conn(self.config, True)
self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
return RunEventListener(
client=self.client, token=self.token, additional_meta_kv=(key, value)
)
return RunEventListener(config=self.config, additional_meta_kv=(key, value))
async def on(
self, workflow_run_id: str, handler: Callable[[StepRunEvent], Any] | None = None
@@ -1,6 +1,4 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar
from typing import AsyncContextManager, ParamSpec, TypeVar
from hatchet_sdk.clients.rest.api_client import ApiClient
from hatchet_sdk.clients.rest.configuration import Configuration
@@ -44,38 +42,3 @@ class BaseRestClient:
def client(self) -> AsyncContextManager[ApiClient]:
return ApiClient(self.api_config)
def _run_async_function_do_not_use_directly(
self,
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(async_func(*args, **kwargs))
finally:
loop.close()
def _run_async_from_sync(
self,
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(async_func(*args, **kwargs))
else:
with ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: self._run_async_function_do_not_use_directly(
async_func, *args, **kwargs
)
)
return future.result()
@@ -54,14 +54,6 @@ class _Subscription:
class PooledWorkflowRunListener:
def __init__(self, config: ClientConfig):
try:
asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
conn = new_conn(config, True)
self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
self.token = config.token
self.config = config
@@ -91,6 +83,10 @@ class PooledWorkflowRunListener:
self.interrupter: asyncio.Task[None] | None = None
## IMPORTANT: This needs to be created lazily so we don't require
## an event loop to instantiate the client.
self.client: DispatcherStub | None = None
async def _interrupter(self) -> None:
"""
_interrupter runs in a separate thread and interrupts the listener according to a configurable duration.
@@ -239,7 +235,7 @@ class PooledWorkflowRunListener:
if subscription_id:
self.cleanup_subscription(subscription_id)
async def result(self, workflow_run_id: str) -> dict[str, Any]:
async def aio_result(self, workflow_run_id: str) -> dict[str, Any]:
from hatchet_sdk.clients.admin import DedupeViolationErr
event = await self.subscribe(workflow_run_id)
@@ -261,6 +257,9 @@ class PooledWorkflowRunListener:
self,
) -> grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent]:
retries = 0
if self.client is None:
conn = new_conn(self.config, True)
self.client = DispatcherStub(conn)
while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT:
try:
@@ -276,7 +275,7 @@ class PooledWorkflowRunListener:
SubscribeToWorkflowRunsRequest, WorkflowRunEvent
],
self.client.SubscribeToWorkflowRuns(
self._request(),
self._request(), # type: ignore[arg-type]
metadata=get_metadata(self.token),
),
)
@@ -33,7 +33,7 @@ if _version_not_supported:
class DispatcherStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
"""Constructor.
Args:
@@ -33,7 +33,7 @@ if _version_not_supported:
class EventsServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
"""Constructor.
Args:
@@ -33,7 +33,7 @@ if _version_not_supported:
class V1DispatcherStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
"""Constructor.
Args:
@@ -34,7 +34,7 @@ class AdminServiceStub(object):
"""AdminService represents a set of RPCs for admin management of tasks, workflows, etc.
"""
def __init__(self, channel):
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
"""Constructor.
Args:
@@ -34,7 +34,7 @@ class WorkflowServiceStub(object):
"""WorkflowService represents a set of RPCs for managing workflows.
"""
def __init__(self, channel):
def __init__(self, channel: grpc.Channel | grpc.aio.Channel) -> None:
"""Constructor.
Args:
+5 -4
View File
@@ -18,6 +18,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -121,7 +122,7 @@ class CronClient(BaseRestClient):
input: JSONSerializableMapping,
additional_metadata: JSONSerializableMapping,
) -> CronWorkflows:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_create,
workflow_name,
cron_name,
@@ -143,7 +144,7 @@ class CronClient(BaseRestClient):
)
def delete(self, cron_id: str) -> None:
return self._run_async_from_sync(self.aio_delete, cron_id)
return run_async_from_sync(self.aio_delete, cron_id)
async def aio_list(
self,
@@ -204,7 +205,7 @@ class CronClient(BaseRestClient):
Returns:
CronWorkflowsList: A list of cron workflows.
"""
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_list,
offset=offset,
limit=limit,
@@ -239,4 +240,4 @@ class CronClient(BaseRestClient):
Returns:
CronWorkflows: The requested cron workflow instance.
"""
return self._run_async_from_sync(self.aio_get, cron_id)
return run_async_from_sync(self.aio_get, cron_id)
+2 -1
View File
@@ -2,6 +2,7 @@ from hatchet_sdk.clients.rest.api.log_api import LogApi
from hatchet_sdk.clients.rest.api_client import ApiClient
from hatchet_sdk.clients.rest.models.v1_log_line_list import V1LogLineList
from hatchet_sdk.clients.v1.api_client import BaseRestClient
from hatchet_sdk.utils.aio import run_async_from_sync
class LogsClient(BaseRestClient):
@@ -13,4 +14,4 @@ class LogsClient(BaseRestClient):
return await self._la(client).v1_log_line_list(task=task_run_id)
def list(self, task_run_id: str) -> V1LogLineList:
return self._run_async_from_sync(self.aio_list, task_run_id)
return run_async_from_sync(self.aio_list, task_run_id)
+4 -3
View File
@@ -11,6 +11,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -38,7 +39,7 @@ class MetricsClient(BaseRestClient):
status: WorkflowRunStatus | None = None,
group_key: str | None = None,
) -> WorkflowMetrics:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_get_workflow_metrics, workflow_id, status, group_key
)
@@ -61,7 +62,7 @@ class MetricsClient(BaseRestClient):
workflow_ids: list[str] | None = None,
additional_metadata: JSONSerializableMapping | None = None,
) -> TenantQueueMetrics:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_get_queue_metrics, workflow_ids, additional_metadata
)
@@ -72,4 +73,4 @@ class MetricsClient(BaseRestClient):
)
def get_task_metrics(self) -> TenantStepRunQueueMetrics:
return self._run_async_from_sync(self.aio_get_task_metrics)
return run_async_from_sync(self.aio_get_task_metrics)
@@ -24,7 +24,7 @@ class RateLimitsClient(BaseRestClient):
)
conn = new_conn(self.client_config, False)
client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call]
client = WorkflowServiceStub(conn)
client.PutRateLimit(
v0_workflow_protos.PutRateLimitRequest(
+8 -7
View File
@@ -19,6 +19,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -93,7 +94,7 @@ class RunsClient(BaseRestClient):
return await self._wra(client).v1_workflow_run_get(str(workflow_run_id))
def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
return self._run_async_from_sync(self.aio_get, workflow_run_id)
return run_async_from_sync(self.aio_get, workflow_run_id)
async def aio_list(
self,
@@ -138,7 +139,7 @@ class RunsClient(BaseRestClient):
worker_id: str | None = None,
parent_task_external_id: str | None = None,
) -> V1TaskSummaryList:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_list,
since=since,
only_tasks=only_tasks,
@@ -174,7 +175,7 @@ class RunsClient(BaseRestClient):
input: JSONSerializableMapping,
additional_metadata: JSONSerializableMapping = {},
) -> V1WorkflowRunDetails:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_create, workflow_name, input, additional_metadata
)
@@ -182,7 +183,7 @@ class RunsClient(BaseRestClient):
await self.aio_bulk_replay(opts=BulkCancelReplayOpts(ids=[run_id]))
def replay(self, run_id: str) -> None:
return self._run_async_from_sync(self.aio_replay, run_id)
return run_async_from_sync(self.aio_replay, run_id)
async def aio_bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
async with self.client() as client:
@@ -192,13 +193,13 @@ class RunsClient(BaseRestClient):
)
def bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
return self._run_async_from_sync(self.aio_bulk_replay, opts)
return run_async_from_sync(self.aio_bulk_replay, opts)
async def aio_cancel(self, run_id: str) -> None:
await self.aio_bulk_cancel(opts=BulkCancelReplayOpts(ids=[run_id]))
def cancel(self, run_id: str) -> None:
return self._run_async_from_sync(self.aio_cancel, run_id)
return run_async_from_sync(self.aio_cancel, run_id)
async def aio_bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
async with self.client() as client:
@@ -208,7 +209,7 @@ class RunsClient(BaseRestClient):
)
def bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
return self._run_async_from_sync(self.aio_bulk_cancel, opts)
return run_async_from_sync(self.aio_bulk_cancel, opts)
async def aio_get_result(self, run_id: str) -> JSONSerializableMapping:
details = await self.aio_get(run_id)
@@ -22,6 +22,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -82,7 +83,7 @@ class ScheduledClient(BaseRestClient):
ScheduledWorkflows: The created scheduled workflow instance.
"""
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_create,
workflow_name,
trigger_at,
@@ -104,7 +105,7 @@ class ScheduledClient(BaseRestClient):
)
def delete(self, scheduled_id: str) -> None:
self._run_async_from_sync(self.aio_delete, scheduled_id)
run_async_from_sync(self.aio_delete, scheduled_id)
async def aio_list(
self,
@@ -175,7 +176,7 @@ class ScheduledClient(BaseRestClient):
Returns:
List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria.
"""
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_list,
offset=offset,
limit=limit,
@@ -214,4 +215,4 @@ class ScheduledClient(BaseRestClient):
Returns:
ScheduledWorkflows: The requested scheduled workflow instance.
"""
return self._run_async_from_sync(self.aio_get, scheduled_id)
return run_async_from_sync(self.aio_get, scheduled_id)
+4 -3
View File
@@ -4,6 +4,7 @@ from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRe
from hatchet_sdk.clients.rest.models.worker import Worker
from hatchet_sdk.clients.rest.models.worker_list import WorkerList
from hatchet_sdk.clients.v1.api_client import BaseRestClient
from hatchet_sdk.utils.aio import run_async_from_sync
class WorkersClient(BaseRestClient):
@@ -15,7 +16,7 @@ class WorkersClient(BaseRestClient):
return await self._wa(client).worker_get(worker_id)
def get(self, worker_id: str) -> Worker:
return self._run_async_from_sync(self.aio_get, worker_id)
return run_async_from_sync(self.aio_get, worker_id)
async def aio_list(
self,
@@ -28,7 +29,7 @@ class WorkersClient(BaseRestClient):
def list(
self,
) -> WorkerList:
return self._run_async_from_sync(self.aio_list)
return run_async_from_sync(self.aio_list)
async def aio_update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
async with self.client() as client:
@@ -38,4 +39,4 @@ class WorkersClient(BaseRestClient):
)
def update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
return self._run_async_from_sync(self.aio_update, worker_id, opts)
return run_async_from_sync(self.aio_update, worker_id, opts)
@@ -5,6 +5,7 @@ from hatchet_sdk.clients.rest.models.workflow import Workflow
from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList
from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion
from hatchet_sdk.clients.v1.api_client import BaseRestClient
from hatchet_sdk.utils.aio import run_async_from_sync
class WorkflowsClient(BaseRestClient):
@@ -19,7 +20,7 @@ class WorkflowsClient(BaseRestClient):
return await self._wa(client).workflow_get(workflow_id)
def get(self, workflow_id: str) -> Workflow:
return self._run_async_from_sync(self.aio_get, workflow_id)
return run_async_from_sync(self.aio_get, workflow_id)
async def aio_list(
self,
@@ -41,7 +42,7 @@ class WorkflowsClient(BaseRestClient):
limit: int | None = None,
offset: int | None = None,
) -> WorkflowList:
return self._run_async_from_sync(self.aio_list, workflow_name, limit, offset)
return run_async_from_sync(self.aio_list, workflow_name, limit, offset)
async def aio_get_version(
self, workflow_id: str, version: str | None = None
@@ -52,4 +53,4 @@ class WorkflowsClient(BaseRestClient):
def get_version(
self, workflow_id: str, version: str | None = None
) -> WorkflowVersion:
return self._run_async_from_sync(self.aio_get_version, workflow_id, version)
return run_async_from_sync(self.aio_get_version, workflow_id, version)
+2 -2
View File
@@ -1,2 +1,2 @@
def get_metadata(token: str) -> list[tuple[str, str]]:
return [("authorization", "bearer " + token)]
def get_metadata(token: str) -> tuple[tuple[str, str], ...]:
    """Build gRPC call metadata carrying the bearer auth token.

    Returns an immutable tuple of ``(key, value)`` pairs, the shape gRPC
    expects for its ``metadata`` argument. Annotated as variable-length
    (``...``) rather than a fixed one-tuple so the signature stays valid
    if more headers are added later.
    """
    return (("authorization", "bearer " + token),)
@@ -1,4 +1,3 @@
import asyncio
from datetime import datetime
from typing import Any, Generic, cast, get_type_hints
@@ -12,7 +11,7 @@ from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import EmptyModel, R, TWorkflowInput
from hatchet_sdk.runnables.workflow import BaseWorkflow, Workflow
from hatchet_sdk.utils.aio_utils import get_active_event_loop
from hatchet_sdk.utils.aio import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping, is_basemodel_subclass
from hatchet_sdk.workflow_run import WorkflowRunRef
@@ -27,25 +26,11 @@ class TaskRunRef(Generic[TWorkflowInput, R]):
self._wrr = workflow_run_ref
async def aio_result(self) -> R:
result = await self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
result = await self._wrr.workflow_listener.aio_result(self._wrr.workflow_run_id)
return self._s._extract_result(result)
def result(self) -> R:
coro = self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
loop = get_active_event_loop()
if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(coro)
finally:
asyncio.set_event_loop(None)
else:
result = loop.run_until_complete(coro)
return self._s._extract_result(result)
return run_async_from_sync(self.aio_result)
class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
+43
View File
@@ -0,0 +1,43 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Coroutine, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
Y = TypeVar("Y")
S = TypeVar("S")
def _run_async_function_do_not_use_directly(
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(async_func(*args, **kwargs))
finally:
loop.close()
def run_async_from_sync(
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(async_func(*args, **kwargs))
else:
with ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: _run_async_function_do_not_use_directly(
async_func, *args, **kwargs
)
)
return future.result()
@@ -1,18 +0,0 @@
import asyncio
def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
"""
Get the active event loop.
Returns:
asyncio.AbstractEventLoop: The active event loop, or None if there is no active
event loop in the current thread.
"""
try:
return asyncio.get_event_loop()
except RuntimeError as e:
if str(e).startswith("There is no current event loop in thread"):
return None
else:
raise e
@@ -60,7 +60,7 @@ class WorkerActionRunLoopManager:
async def aio_start(self, retry_count: int = 1) -> None:
await capture_logs(
self.client.logInterceptor,
self.client.log_interceptor,
self.client.event,
self._async_start,
)(retry_count=retry_count)
+7 -20
View File
@@ -1,4 +1,3 @@
import asyncio
from typing import Any
from hatchet_sdk.clients.run_event_listener import (
@@ -6,19 +5,19 @@ from hatchet_sdk.clients.run_event_listener import (
RunEventListenerClient,
)
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.utils.aio_utils import get_active_event_loop
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.utils.aio import run_async_from_sync
class WorkflowRunRef:
def __init__(
self,
workflow_run_id: str,
workflow_listener: PooledWorkflowRunListener,
workflow_run_event_listener: RunEventListenerClient,
config: ClientConfig,
):
self.workflow_run_id = workflow_run_id
self.workflow_listener = workflow_listener
self.workflow_run_event_listener = workflow_run_event_listener
self.workflow_listener = PooledWorkflowRunListener(config)
self.workflow_run_event_listener = RunEventListenerClient(config=config)
def __str__(self) -> str:
return self.workflow_run_id
@@ -27,19 +26,7 @@ class WorkflowRunRef:
return self.workflow_run_event_listener.stream(self.workflow_run_id)
async def aio_result(self) -> dict[str, Any]:
return await self.workflow_listener.result(self.workflow_run_id)
return await self.workflow_listener.aio_result(self.workflow_run_id)
def result(self) -> dict[str, Any]:
coro = self.workflow_listener.result(self.workflow_run_id)
loop = get_active_event_loop()
if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
asyncio.set_event_loop(None)
else:
return loop.run_until_complete(coro)
return run_async_from_sync(self.aio_result)
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "1.0.3"
version = "1.1.0"
description = ""
authors = ["Alexander Belanger <alexander@hatchet.run>"]
readme = "README.md"
+3 -3
View File
@@ -6,7 +6,7 @@ from examples.dag.worker import dag_workflow
from hatchet_sdk import Hatchet
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_list_runs(hatchet: Hatchet) -> None:
dag_result = await dag_workflow.aio_run()
@@ -19,7 +19,7 @@ async def test_list_runs(hatchet: Hatchet) -> None:
assert v in [r.output for r in runs.rows]
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_get_run(hatchet: Hatchet) -> None:
dag_ref = await dag_workflow.aio_run_no_wait()
@@ -33,7 +33,7 @@ async def test_get_run(hatchet: Hatchet) -> None:
assert {t.name for t in dag_workflow.tasks} == {t.task_name for t in run.shape}
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio()
async def test_list_workflows(hatchet: Hatchet) -> None:
workflows = await hatchet.workflows.aio_list(
workflow_name=dag_workflow.config.name, limit=1, offset=0