feat: factor out run_async_from_sync

This commit is contained in:
mrkaye97
2025-03-30 12:35:46 -04:00
parent 46fa6b13a0
commit fb37395913
10 changed files with 76 additions and 74 deletions

View File

@@ -1,7 +1,3 @@
import asyncio
import grpc
from hatchet_sdk.clients.admin import AdminClient
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
from hatchet_sdk.clients.events import EventClient, new_event
@@ -29,13 +25,7 @@ class Client:
workflow_listener: PooledWorkflowRunListener | None | None = None,
debug: bool = False,
):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
conn: grpc.Channel = new_conn(config, False)
conn = new_conn(config, False)
self.config = config
self.admin = admin_client or AdminClient(config)

View File

@@ -1,6 +1,4 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar
from typing import AsyncContextManager, ParamSpec, TypeVar
from hatchet_sdk.clients.rest.api_client import ApiClient
from hatchet_sdk.clients.rest.configuration import Configuration
@@ -44,38 +42,3 @@ class BaseRestClient:
def client(self) -> AsyncContextManager[ApiClient]:
return ApiClient(self.api_config)
def _run_async_function_do_not_use_directly(
self,
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(async_func(*args, **kwargs))
finally:
loop.close()
def _run_async_from_sync(
self,
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(async_func(*args, **kwargs))
else:
with ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: self._run_async_function_do_not_use_directly(
async_func, *args, **kwargs
)
)
return future.result()

View File

@@ -18,6 +18,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio_utils import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -121,7 +122,7 @@ class CronClient(BaseRestClient):
input: JSONSerializableMapping,
additional_metadata: JSONSerializableMapping,
) -> CronWorkflows:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_create,
workflow_name,
cron_name,
@@ -143,7 +144,7 @@ class CronClient(BaseRestClient):
)
def delete(self, cron_id: str) -> None:
return self._run_async_from_sync(self.aio_delete, cron_id)
return run_async_from_sync(self.aio_delete, cron_id)
async def aio_list(
self,
@@ -204,7 +205,7 @@ class CronClient(BaseRestClient):
Returns:
CronWorkflowsList: A list of cron workflows.
"""
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_list,
offset=offset,
limit=limit,
@@ -239,4 +240,4 @@ class CronClient(BaseRestClient):
Returns:
CronWorkflows: The requested cron workflow instance.
"""
return self._run_async_from_sync(self.aio_get, cron_id)
return run_async_from_sync(self.aio_get, cron_id)

View File

@@ -2,6 +2,7 @@ from hatchet_sdk.clients.rest.api.log_api import LogApi
from hatchet_sdk.clients.rest.api_client import ApiClient
from hatchet_sdk.clients.rest.models.v1_log_line_list import V1LogLineList
from hatchet_sdk.clients.v1.api_client import BaseRestClient
from hatchet_sdk.utils.aio_utils import run_async_from_sync
class LogsClient(BaseRestClient):
@@ -13,4 +14,4 @@ class LogsClient(BaseRestClient):
return await self._la(client).v1_log_line_list(task=task_run_id)
def list(self, task_run_id: str) -> V1LogLineList:
return self._run_async_from_sync(self.aio_list, task_run_id)
return run_async_from_sync(self.aio_list, task_run_id)

View File

@@ -11,6 +11,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio_utils import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -38,7 +39,7 @@ class MetricsClient(BaseRestClient):
status: WorkflowRunStatus | None = None,
group_key: str | None = None,
) -> WorkflowMetrics:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_get_workflow_metrics, workflow_id, status, group_key
)
@@ -61,7 +62,7 @@ class MetricsClient(BaseRestClient):
workflow_ids: list[str] | None = None,
additional_metadata: JSONSerializableMapping | None = None,
) -> TenantQueueMetrics:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_get_queue_metrics, workflow_ids, additional_metadata
)
@@ -72,4 +73,4 @@ class MetricsClient(BaseRestClient):
)
def get_task_metrics(self) -> TenantStepRunQueueMetrics:
return self._run_async_from_sync(self.aio_get_task_metrics)
return run_async_from_sync(self.aio_get_task_metrics)

View File

@@ -19,6 +19,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio_utils import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -93,7 +94,7 @@ class RunsClient(BaseRestClient):
return await self._wra(client).v1_workflow_run_get(str(workflow_run_id))
def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
return self._run_async_from_sync(self.aio_get, workflow_run_id)
return run_async_from_sync(self.aio_get, workflow_run_id)
async def aio_list(
self,
@@ -138,7 +139,7 @@ class RunsClient(BaseRestClient):
worker_id: str | None = None,
parent_task_external_id: str | None = None,
) -> V1TaskSummaryList:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_list,
since=since,
only_tasks=only_tasks,
@@ -174,7 +175,7 @@ class RunsClient(BaseRestClient):
input: JSONSerializableMapping,
additional_metadata: JSONSerializableMapping = {},
) -> V1WorkflowRunDetails:
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_create, workflow_name, input, additional_metadata
)
@@ -182,7 +183,7 @@ class RunsClient(BaseRestClient):
await self.aio_bulk_replay(opts=BulkCancelReplayOpts(ids=[run_id]))
def replay(self, run_id: str) -> None:
return self._run_async_from_sync(self.aio_replay, run_id)
return run_async_from_sync(self.aio_replay, run_id)
async def aio_bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
async with self.client() as client:
@@ -192,13 +193,13 @@ class RunsClient(BaseRestClient):
)
def bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
return self._run_async_from_sync(self.aio_bulk_replay, opts)
return run_async_from_sync(self.aio_bulk_replay, opts)
async def aio_cancel(self, run_id: str) -> None:
await self.aio_bulk_cancel(opts=BulkCancelReplayOpts(ids=[run_id]))
def cancel(self, run_id: str) -> None:
return self._run_async_from_sync(self.aio_cancel, run_id)
return run_async_from_sync(self.aio_cancel, run_id)
async def aio_bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
async with self.client() as client:
@@ -208,7 +209,7 @@ class RunsClient(BaseRestClient):
)
def bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
return self._run_async_from_sync(self.aio_bulk_cancel, opts)
return run_async_from_sync(self.aio_bulk_cancel, opts)
async def aio_get_result(self, run_id: str) -> JSONSerializableMapping:
details = await self.aio_get(run_id)

View File

@@ -22,6 +22,7 @@ from hatchet_sdk.clients.v1.api_client import (
BaseRestClient,
maybe_additional_metadata_to_kv,
)
from hatchet_sdk.utils.aio_utils import run_async_from_sync
from hatchet_sdk.utils.typing import JSONSerializableMapping
@@ -82,7 +83,7 @@ class ScheduledClient(BaseRestClient):
ScheduledWorkflows: The created scheduled workflow instance.
"""
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_create,
workflow_name,
trigger_at,
@@ -104,7 +105,7 @@ class ScheduledClient(BaseRestClient):
)
def delete(self, scheduled_id: str) -> None:
self._run_async_from_sync(self.aio_delete, scheduled_id)
run_async_from_sync(self.aio_delete, scheduled_id)
async def aio_list(
self,
@@ -175,7 +176,7 @@ class ScheduledClient(BaseRestClient):
Returns:
List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria.
"""
return self._run_async_from_sync(
return run_async_from_sync(
self.aio_list,
offset=offset,
limit=limit,
@@ -214,4 +215,4 @@ class ScheduledClient(BaseRestClient):
Returns:
ScheduledWorkflows: The requested scheduled workflow instance.
"""
return self._run_async_from_sync(self.aio_get, scheduled_id)
return run_async_from_sync(self.aio_get, scheduled_id)

View File

@@ -4,6 +4,7 @@ from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRe
from hatchet_sdk.clients.rest.models.worker import Worker
from hatchet_sdk.clients.rest.models.worker_list import WorkerList
from hatchet_sdk.clients.v1.api_client import BaseRestClient
from hatchet_sdk.utils.aio_utils import run_async_from_sync
class WorkersClient(BaseRestClient):
@@ -15,7 +16,7 @@ class WorkersClient(BaseRestClient):
return await self._wa(client).worker_get(worker_id)
def get(self, worker_id: str) -> Worker:
return self._run_async_from_sync(self.aio_get, worker_id)
return run_async_from_sync(self.aio_get, worker_id)
async def aio_list(
self,
@@ -28,7 +29,7 @@ class WorkersClient(BaseRestClient):
def list(
self,
) -> WorkerList:
return self._run_async_from_sync(self.aio_list)
return run_async_from_sync(self.aio_list)
async def aio_update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
async with self.client() as client:
@@ -38,4 +39,4 @@ class WorkersClient(BaseRestClient):
)
def update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
return self._run_async_from_sync(self.aio_update, worker_id, opts)
return run_async_from_sync(self.aio_update, worker_id, opts)

View File

@@ -5,6 +5,7 @@ from hatchet_sdk.clients.rest.models.workflow import Workflow
from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList
from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion
from hatchet_sdk.clients.v1.api_client import BaseRestClient
from hatchet_sdk.utils.aio_utils import run_async_from_sync
class WorkflowsClient(BaseRestClient):
@@ -19,7 +20,7 @@ class WorkflowsClient(BaseRestClient):
return await self._wa(client).workflow_get(workflow_id)
def get(self, workflow_id: str) -> Workflow:
return self._run_async_from_sync(self.aio_get, workflow_id)
return run_async_from_sync(self.aio_get, workflow_id)
async def aio_list(
self,
@@ -41,7 +42,7 @@ class WorkflowsClient(BaseRestClient):
limit: int | None = None,
offset: int | None = None,
) -> WorkflowList:
return self._run_async_from_sync(self.aio_list, workflow_name, limit, offset)
return run_async_from_sync(self.aio_list, workflow_name, limit, offset)
async def aio_get_version(
self, workflow_id: str, version: str | None = None
@@ -52,4 +53,4 @@ class WorkflowsClient(BaseRestClient):
def get_version(
self, workflow_id: str, version: str | None = None
) -> WorkflowVersion:
return self._run_async_from_sync(self.aio_get_version, workflow_id, version)
return run_async_from_sync(self.aio_get_version, workflow_id, version)

View File

@@ -1,4 +1,11 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Coroutine, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
Y = TypeVar("Y")
S = TypeVar("S")
def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
@@ -16,3 +23,38 @@ def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
return None
else:
raise e
def _run_async_function_do_not_use_directly(
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(async_func(*args, **kwargs))
finally:
loop.close()
def run_async_from_sync(
async_func: Callable[P, Coroutine[Y, S, R]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(async_func(*args, **kwargs))
else:
with ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: _run_async_function_do_not_use_directly(
async_func, *args, **kwargs
)
)
return future.result()