From fb37395913801b75892f6d09f300245352dbe8d3 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Sun, 30 Mar 2025 12:35:46 -0400 Subject: [PATCH] feat: factor out run_async_from_sync --- sdks/python/hatchet_sdk/client.py | 12 +----- .../hatchet_sdk/clients/v1/api_client.py | 39 +---------------- sdks/python/hatchet_sdk/features/cron.py | 9 ++-- sdks/python/hatchet_sdk/features/logs.py | 3 +- sdks/python/hatchet_sdk/features/metrics.py | 7 ++-- sdks/python/hatchet_sdk/features/runs.py | 15 +++---- sdks/python/hatchet_sdk/features/scheduled.py | 9 ++-- sdks/python/hatchet_sdk/features/workers.py | 7 ++-- sdks/python/hatchet_sdk/features/workflows.py | 7 ++-- sdks/python/hatchet_sdk/utils/aio_utils.py | 42 +++++++++++++++++++ 10 files changed, 76 insertions(+), 74 deletions(-) diff --git a/sdks/python/hatchet_sdk/client.py b/sdks/python/hatchet_sdk/client.py index ca1f29845..398c125b6 100644 --- a/sdks/python/hatchet_sdk/client.py +++ b/sdks/python/hatchet_sdk/client.py @@ -1,7 +1,3 @@ -import asyncio - -import grpc - from hatchet_sdk.clients.admin import AdminClient from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient from hatchet_sdk.clients.events import EventClient, new_event @@ -29,13 +25,7 @@ class Client: workflow_listener: PooledWorkflowRunListener | None | None = None, debug: bool = False, ): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - conn: grpc.Channel = new_conn(config, False) + conn = new_conn(config, False) self.config = config self.admin = admin_client or AdminClient(config) diff --git a/sdks/python/hatchet_sdk/clients/v1/api_client.py b/sdks/python/hatchet_sdk/clients/v1/api_client.py index 4487b6311..ac962a539 100644 --- a/sdks/python/hatchet_sdk/clients/v1/api_client.py +++ b/sdks/python/hatchet_sdk/clients/v1/api_client.py @@ -1,6 +1,4 @@ -import asyncio -from concurrent.futures import ThreadPoolExecutor -from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar +from typing import AsyncContextManager, ParamSpec, TypeVar from hatchet_sdk.clients.rest.api_client import ApiClient from hatchet_sdk.clients.rest.configuration import Configuration @@ -44,38 +42,3 @@ class BaseRestClient: def client(self) -> AsyncContextManager[ApiClient]: return ApiClient(self.api_config) - - def _run_async_function_do_not_use_directly( - self, - async_func: Callable[P, Coroutine[Y, S, R]], - *args: P.args, - **kwargs: P.kwargs, - ) -> R: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(async_func(*args, **kwargs)) - finally: - loop.close() - - def _run_async_from_sync( - self, - async_func: Callable[P, Coroutine[Y, S, R]], - *args: P.args, - **kwargs: P.kwargs, - ) -> R: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - return loop.run_until_complete(async_func(*args, **kwargs)) - else: - with ThreadPoolExecutor() as executor: - future = executor.submit( - lambda: self._run_async_function_do_not_use_directly( - async_func, *args, **kwargs - ) - ) - return future.result() diff --git a/sdks/python/hatchet_sdk/features/cron.py b/sdks/python/hatchet_sdk/features/cron.py index 8183f4595..05b6516d0 100644 --- a/sdks/python/hatchet_sdk/features/cron.py +++ b/sdks/python/hatchet_sdk/features/cron.py @@ -18,6 +18,7 @@ from hatchet_sdk.clients.v1.api_client import ( BaseRestClient, maybe_additional_metadata_to_kv, ) +from hatchet_sdk.utils.aio_utils import run_async_from_sync from hatchet_sdk.utils.typing import JSONSerializableMapping @@ -121,7 +122,7 @@ class CronClient(BaseRestClient): input: JSONSerializableMapping, additional_metadata: JSONSerializableMapping, ) -> CronWorkflows: - return self._run_async_from_sync( + return run_async_from_sync( self.aio_create, workflow_name, cron_name, @@ -143,7 +144,7 @@ class CronClient(BaseRestClient): ) def delete(self, cron_id: str) -> None: - return self._run_async_from_sync(self.aio_delete, cron_id) + return run_async_from_sync(self.aio_delete, cron_id) async def aio_list( self, @@ -204,7 +205,7 @@ class CronClient(BaseRestClient): Returns: CronWorkflowsList: A list of cron workflows. """ - return self._run_async_from_sync( + return run_async_from_sync( self.aio_list, offset=offset, limit=limit, @@ -239,4 +240,4 @@ class CronClient(BaseRestClient): Returns: CronWorkflows: The requested cron workflow instance. """ - return self._run_async_from_sync(self.aio_get, cron_id) + return run_async_from_sync(self.aio_get, cron_id) diff --git a/sdks/python/hatchet_sdk/features/logs.py b/sdks/python/hatchet_sdk/features/logs.py index 873458058..62821cfdf 100644 --- a/sdks/python/hatchet_sdk/features/logs.py +++ b/sdks/python/hatchet_sdk/features/logs.py @@ -2,6 +2,7 @@ from hatchet_sdk.clients.rest.api.log_api import LogApi from hatchet_sdk.clients.rest.api_client import ApiClient from hatchet_sdk.clients.rest.models.v1_log_line_list import V1LogLineList from hatchet_sdk.clients.v1.api_client import BaseRestClient +from hatchet_sdk.utils.aio_utils import run_async_from_sync class LogsClient(BaseRestClient): @@ -13,4 +14,4 @@ class LogsClient(BaseRestClient): return await self._la(client).v1_log_line_list(task=task_run_id) def list(self, task_run_id: str) -> V1LogLineList: - return self._run_async_from_sync(self.aio_list, task_run_id) + return run_async_from_sync(self.aio_list, task_run_id) diff --git a/sdks/python/hatchet_sdk/features/metrics.py b/sdks/python/hatchet_sdk/features/metrics.py index d6a03ab3f..0189b3f47 100644 --- a/sdks/python/hatchet_sdk/features/metrics.py +++ b/sdks/python/hatchet_sdk/features/metrics.py @@ -11,6 +11,7 @@ from hatchet_sdk.clients.v1.api_client import ( BaseRestClient, maybe_additional_metadata_to_kv, ) +from hatchet_sdk.utils.aio_utils import run_async_from_sync from hatchet_sdk.utils.typing import JSONSerializableMapping @@ -38,7 +39,7 @@ class MetricsClient(BaseRestClient): status: WorkflowRunStatus | None = None, group_key: str | None = None, ) -> WorkflowMetrics: - return self._run_async_from_sync( + return run_async_from_sync( self.aio_get_workflow_metrics, workflow_id, status, group_key ) @@ -61,7 +62,7 @@ class MetricsClient(BaseRestClient): workflow_ids: list[str] | None = None, additional_metadata: JSONSerializableMapping | None = None, ) -> TenantQueueMetrics: - return self._run_async_from_sync( + return run_async_from_sync( self.aio_get_queue_metrics, workflow_ids, additional_metadata ) @@ -72,4 +73,4 @@ class MetricsClient(BaseRestClient): ) def get_task_metrics(self) -> TenantStepRunQueueMetrics: - return self._run_async_from_sync(self.aio_get_task_metrics) + return run_async_from_sync(self.aio_get_task_metrics) diff --git a/sdks/python/hatchet_sdk/features/runs.py b/sdks/python/hatchet_sdk/features/runs.py index d316f8884..240c94255 100644 --- a/sdks/python/hatchet_sdk/features/runs.py +++ b/sdks/python/hatchet_sdk/features/runs.py @@ -19,6 +19,7 @@ from hatchet_sdk.clients.v1.api_client import ( BaseRestClient, maybe_additional_metadata_to_kv, ) +from hatchet_sdk.utils.aio_utils import run_async_from_sync from hatchet_sdk.utils.typing import JSONSerializableMapping @@ -93,7 +94,7 @@ class RunsClient(BaseRestClient): return await self._wra(client).v1_workflow_run_get(str(workflow_run_id)) def get(self, workflow_run_id: str) -> V1WorkflowRunDetails: - return self._run_async_from_sync(self.aio_get, workflow_run_id) + return run_async_from_sync(self.aio_get, workflow_run_id) async def aio_list( self, @@ -138,7 +139,7 @@ class RunsClient(BaseRestClient): worker_id: str | None = None, parent_task_external_id: str | None = None, ) -> V1TaskSummaryList: - return self._run_async_from_sync( + return run_async_from_sync( self.aio_list, since=since, only_tasks=only_tasks, @@ -174,7 +175,7 @@ class RunsClient(BaseRestClient): input: JSONSerializableMapping, additional_metadata: JSONSerializableMapping = {}, ) -> V1WorkflowRunDetails: - return self._run_async_from_sync( + return run_async_from_sync( self.aio_create, workflow_name, input, additional_metadata ) @@ -182,7 +183,7 @@ class RunsClient(BaseRestClient): await self.aio_bulk_replay(opts=BulkCancelReplayOpts(ids=[run_id])) def replay(self, run_id: str) -> None: - return self._run_async_from_sync(self.aio_replay, run_id) + return run_async_from_sync(self.aio_replay, run_id) async def aio_bulk_replay(self, opts: BulkCancelReplayOpts) -> None: async with self.client() as client: @@ -192,13 +193,13 @@ class RunsClient(BaseRestClient): ) def bulk_replay(self, opts: BulkCancelReplayOpts) -> None: - return self._run_async_from_sync(self.aio_bulk_replay, opts) + return run_async_from_sync(self.aio_bulk_replay, opts) async def aio_cancel(self, run_id: str) -> None: await self.aio_bulk_cancel(opts=BulkCancelReplayOpts(ids=[run_id])) def cancel(self, run_id: str) -> None: - return self._run_async_from_sync(self.aio_cancel, run_id) + return run_async_from_sync(self.aio_cancel, run_id) async def aio_bulk_cancel(self, opts: BulkCancelReplayOpts) -> None: async with self.client() as client: @@ -208,7 +209,7 @@ class RunsClient(BaseRestClient): ) def bulk_cancel(self, opts: BulkCancelReplayOpts) -> None: - return self._run_async_from_sync(self.aio_bulk_cancel, opts) + return run_async_from_sync(self.aio_bulk_cancel, opts) async def aio_get_result(self, run_id: str) -> JSONSerializableMapping: details = await self.aio_get(run_id) diff --git a/sdks/python/hatchet_sdk/features/scheduled.py b/sdks/python/hatchet_sdk/features/scheduled.py index c51ca04b7..ee6b2dba8 100644 --- a/sdks/python/hatchet_sdk/features/scheduled.py +++ b/sdks/python/hatchet_sdk/features/scheduled.py @@ -22,6 +22,7 @@ from hatchet_sdk.clients.v1.api_client import ( BaseRestClient, maybe_additional_metadata_to_kv, ) +from hatchet_sdk.utils.aio_utils import run_async_from_sync from hatchet_sdk.utils.typing import JSONSerializableMapping @@ -82,7 +83,7 @@ class ScheduledClient(BaseRestClient): ScheduledWorkflows: The created scheduled workflow instance. """ - return self._run_async_from_sync( + return run_async_from_sync( self.aio_create, workflow_name, trigger_at, @@ -104,7 +105,7 @@ class ScheduledClient(BaseRestClient): ) def delete(self, scheduled_id: str) -> None: - self._run_async_from_sync(self.aio_delete, scheduled_id) + run_async_from_sync(self.aio_delete, scheduled_id) async def aio_list( self, @@ -175,7 +176,7 @@ class ScheduledClient(BaseRestClient): Returns: List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria. """ - return self._run_async_from_sync( + return run_async_from_sync( self.aio_list, offset=offset, limit=limit, @@ -214,4 +215,4 @@ class ScheduledClient(BaseRestClient): Returns: ScheduledWorkflows: The requested scheduled workflow instance. """ - return self._run_async_from_sync(self.aio_get, scheduled_id) + return run_async_from_sync(self.aio_get, scheduled_id) diff --git a/sdks/python/hatchet_sdk/features/workers.py b/sdks/python/hatchet_sdk/features/workers.py index 4738272be..ac5ad8cae 100644 --- a/sdks/python/hatchet_sdk/features/workers.py +++ b/sdks/python/hatchet_sdk/features/workers.py @@ -4,6 +4,7 @@ from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRe from hatchet_sdk.clients.rest.models.worker import Worker from hatchet_sdk.clients.rest.models.worker_list import WorkerList from hatchet_sdk.clients.v1.api_client import BaseRestClient +from hatchet_sdk.utils.aio_utils import run_async_from_sync class WorkersClient(BaseRestClient): @@ -15,7 +16,7 @@ class WorkersClient(BaseRestClient): return await self._wa(client).worker_get(worker_id) def get(self, worker_id: str) -> Worker: - return self._run_async_from_sync(self.aio_get, worker_id) + return run_async_from_sync(self.aio_get, worker_id) async def aio_list( self, @@ -28,7 +29,7 @@ class WorkersClient(BaseRestClient): def list( self, ) -> WorkerList: - return self._run_async_from_sync(self.aio_list) + return run_async_from_sync(self.aio_list) async def aio_update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker: async with self.client() as client: @@ -38,4 +39,4 @@ class WorkersClient(BaseRestClient): ) def update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker: - return self._run_async_from_sync(self.aio_update, worker_id, opts) + return run_async_from_sync(self.aio_update, worker_id, opts) diff --git a/sdks/python/hatchet_sdk/features/workflows.py b/sdks/python/hatchet_sdk/features/workflows.py index 87cb87696..234f32b30 100644 --- a/sdks/python/hatchet_sdk/features/workflows.py +++ b/sdks/python/hatchet_sdk/features/workflows.py @@ -5,6 +5,7 @@ from hatchet_sdk.clients.rest.models.workflow import Workflow from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion from hatchet_sdk.clients.v1.api_client import BaseRestClient +from hatchet_sdk.utils.aio_utils import run_async_from_sync class WorkflowsClient(BaseRestClient): @@ -19,7 +20,7 @@ class WorkflowsClient(BaseRestClient): return await self._wa(client).workflow_get(workflow_id) def get(self, workflow_id: str) -> Workflow: - return self._run_async_from_sync(self.aio_get, workflow_id) + return run_async_from_sync(self.aio_get, workflow_id) async def aio_list( self, @@ -41,7 +42,7 @@ class WorkflowsClient(BaseRestClient): limit: int | None = None, offset: int | None = None, ) -> WorkflowList: - return self._run_async_from_sync(self.aio_list, workflow_name, limit, offset) + return run_async_from_sync(self.aio_list, workflow_name, limit, offset) async def aio_get_version( self, workflow_id: str, version: str | None = None @@ -52,4 +53,4 @@ class WorkflowsClient(BaseRestClient): def get_version( self, workflow_id: str, version: str | None = None ) -> WorkflowVersion: - return self._run_async_from_sync(self.aio_get_version, workflow_id, version) + return run_async_from_sync(self.aio_get_version, workflow_id, version) diff --git a/sdks/python/hatchet_sdk/utils/aio_utils.py b/sdks/python/hatchet_sdk/utils/aio_utils.py index a7d346b9b..c0d279298 100644 --- a/sdks/python/hatchet_sdk/utils/aio_utils.py +++ b/sdks/python/hatchet_sdk/utils/aio_utils.py @@ -1,4 +1,11 @@ import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Coroutine, ParamSpec, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") +Y = TypeVar("Y") +S = TypeVar("S") def get_active_event_loop() -> asyncio.AbstractEventLoop | None: @@ -16,3 +23,38 @@ def get_active_event_loop() -> asyncio.AbstractEventLoop | None: return None else: raise e + + +def _run_async_function_do_not_use_directly( + async_func: Callable[P, Coroutine[Y, S, R]], + *args: P.args, + **kwargs: P.kwargs, +) -> R: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(async_func(*args, **kwargs)) + finally: + loop.close() + + +def run_async_from_sync( + async_func: Callable[P, Coroutine[Y, S, R]], + *args: P.args, + **kwargs: P.kwargs, +) -> R: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + return loop.run_until_complete(async_func(*args, **kwargs)) + else: + with ThreadPoolExecutor() as executor: + future = executor.submit( + lambda: _run_async_function_do_not_use_directly( + async_func, *args, **kwargs + ) + ) + return future.result()