mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-01-06 08:49:53 -06:00
feat: factor out run_async_from_sync
This commit is contained in:
@@ -1,7 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
import grpc
|
||||
|
||||
from hatchet_sdk.clients.admin import AdminClient
|
||||
from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
|
||||
from hatchet_sdk.clients.events import EventClient, new_event
|
||||
@@ -29,13 +25,7 @@ class Client:
|
||||
workflow_listener: PooledWorkflowRunListener | None | None = None,
|
||||
debug: bool = False,
|
||||
):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
conn: grpc.Channel = new_conn(config, False)
|
||||
conn = new_conn(config, False)
|
||||
|
||||
self.config = config
|
||||
self.admin = admin_client or AdminClient(config)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar
|
||||
from typing import AsyncContextManager, ParamSpec, TypeVar
|
||||
|
||||
from hatchet_sdk.clients.rest.api_client import ApiClient
|
||||
from hatchet_sdk.clients.rest.configuration import Configuration
|
||||
@@ -44,38 +42,3 @@ class BaseRestClient:
|
||||
|
||||
def client(self) -> AsyncContextManager[ApiClient]:
|
||||
return ApiClient(self.api_config)
|
||||
|
||||
def _run_async_function_do_not_use_directly(
|
||||
self,
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def _run_async_from_sync(
|
||||
self,
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
else:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
lambda: self._run_async_function_do_not_use_directly(
|
||||
async_func, *args, **kwargs
|
||||
)
|
||||
)
|
||||
return future.result()
|
||||
|
||||
@@ -18,6 +18,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -121,7 +122,7 @@ class CronClient(BaseRestClient):
|
||||
input: JSONSerializableMapping,
|
||||
additional_metadata: JSONSerializableMapping,
|
||||
) -> CronWorkflows:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_create,
|
||||
workflow_name,
|
||||
cron_name,
|
||||
@@ -143,7 +144,7 @@ class CronClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def delete(self, cron_id: str) -> None:
|
||||
return self._run_async_from_sync(self.aio_delete, cron_id)
|
||||
return run_async_from_sync(self.aio_delete, cron_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -204,7 +205,7 @@ class CronClient(BaseRestClient):
|
||||
Returns:
|
||||
CronWorkflowsList: A list of cron workflows.
|
||||
"""
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_list,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -239,4 +240,4 @@ class CronClient(BaseRestClient):
|
||||
Returns:
|
||||
CronWorkflows: The requested cron workflow instance.
|
||||
"""
|
||||
return self._run_async_from_sync(self.aio_get, cron_id)
|
||||
return run_async_from_sync(self.aio_get, cron_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ from hatchet_sdk.clients.rest.api.log_api import LogApi
|
||||
from hatchet_sdk.clients.rest.api_client import ApiClient
|
||||
from hatchet_sdk.clients.rest.models.v1_log_line_list import V1LogLineList
|
||||
from hatchet_sdk.clients.v1.api_client import BaseRestClient
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
|
||||
|
||||
class LogsClient(BaseRestClient):
|
||||
@@ -13,4 +14,4 @@ class LogsClient(BaseRestClient):
|
||||
return await self._la(client).v1_log_line_list(task=task_run_id)
|
||||
|
||||
def list(self, task_run_id: str) -> V1LogLineList:
|
||||
return self._run_async_from_sync(self.aio_list, task_run_id)
|
||||
return run_async_from_sync(self.aio_list, task_run_id)
|
||||
|
||||
@@ -11,6 +11,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ class MetricsClient(BaseRestClient):
|
||||
status: WorkflowRunStatus | None = None,
|
||||
group_key: str | None = None,
|
||||
) -> WorkflowMetrics:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_get_workflow_metrics, workflow_id, status, group_key
|
||||
)
|
||||
|
||||
@@ -61,7 +62,7 @@ class MetricsClient(BaseRestClient):
|
||||
workflow_ids: list[str] | None = None,
|
||||
additional_metadata: JSONSerializableMapping | None = None,
|
||||
) -> TenantQueueMetrics:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_get_queue_metrics, workflow_ids, additional_metadata
|
||||
)
|
||||
|
||||
@@ -72,4 +73,4 @@ class MetricsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def get_task_metrics(self) -> TenantStepRunQueueMetrics:
|
||||
return self._run_async_from_sync(self.aio_get_task_metrics)
|
||||
return run_async_from_sync(self.aio_get_task_metrics)
|
||||
|
||||
@@ -19,6 +19,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -93,7 +94,7 @@ class RunsClient(BaseRestClient):
|
||||
return await self._wra(client).v1_workflow_run_get(str(workflow_run_id))
|
||||
|
||||
def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
|
||||
return self._run_async_from_sync(self.aio_get, workflow_run_id)
|
||||
return run_async_from_sync(self.aio_get, workflow_run_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -138,7 +139,7 @@ class RunsClient(BaseRestClient):
|
||||
worker_id: str | None = None,
|
||||
parent_task_external_id: str | None = None,
|
||||
) -> V1TaskSummaryList:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_list,
|
||||
since=since,
|
||||
only_tasks=only_tasks,
|
||||
@@ -174,7 +175,7 @@ class RunsClient(BaseRestClient):
|
||||
input: JSONSerializableMapping,
|
||||
additional_metadata: JSONSerializableMapping = {},
|
||||
) -> V1WorkflowRunDetails:
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_create, workflow_name, input, additional_metadata
|
||||
)
|
||||
|
||||
@@ -182,7 +183,7 @@ class RunsClient(BaseRestClient):
|
||||
await self.aio_bulk_replay(opts=BulkCancelReplayOpts(ids=[run_id]))
|
||||
|
||||
def replay(self, run_id: str) -> None:
|
||||
return self._run_async_from_sync(self.aio_replay, run_id)
|
||||
return run_async_from_sync(self.aio_replay, run_id)
|
||||
|
||||
async def aio_bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
|
||||
async with self.client() as client:
|
||||
@@ -192,13 +193,13 @@ class RunsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
|
||||
return self._run_async_from_sync(self.aio_bulk_replay, opts)
|
||||
return run_async_from_sync(self.aio_bulk_replay, opts)
|
||||
|
||||
async def aio_cancel(self, run_id: str) -> None:
|
||||
await self.aio_bulk_cancel(opts=BulkCancelReplayOpts(ids=[run_id]))
|
||||
|
||||
def cancel(self, run_id: str) -> None:
|
||||
return self._run_async_from_sync(self.aio_cancel, run_id)
|
||||
return run_async_from_sync(self.aio_cancel, run_id)
|
||||
|
||||
async def aio_bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
|
||||
async with self.client() as client:
|
||||
@@ -208,7 +209,7 @@ class RunsClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
|
||||
return self._run_async_from_sync(self.aio_bulk_cancel, opts)
|
||||
return run_async_from_sync(self.aio_bulk_cancel, opts)
|
||||
|
||||
async def aio_get_result(self, run_id: str) -> JSONSerializableMapping:
|
||||
details = await self.aio_get(run_id)
|
||||
|
||||
@@ -22,6 +22,7 @@ from hatchet_sdk.clients.v1.api_client import (
|
||||
BaseRestClient,
|
||||
maybe_additional_metadata_to_kv,
|
||||
)
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
from hatchet_sdk.utils.typing import JSONSerializableMapping
|
||||
|
||||
|
||||
@@ -82,7 +83,7 @@ class ScheduledClient(BaseRestClient):
|
||||
ScheduledWorkflows: The created scheduled workflow instance.
|
||||
"""
|
||||
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_create,
|
||||
workflow_name,
|
||||
trigger_at,
|
||||
@@ -104,7 +105,7 @@ class ScheduledClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def delete(self, scheduled_id: str) -> None:
|
||||
self._run_async_from_sync(self.aio_delete, scheduled_id)
|
||||
run_async_from_sync(self.aio_delete, scheduled_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -175,7 +176,7 @@ class ScheduledClient(BaseRestClient):
|
||||
Returns:
|
||||
List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria.
|
||||
"""
|
||||
return self._run_async_from_sync(
|
||||
return run_async_from_sync(
|
||||
self.aio_list,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -214,4 +215,4 @@ class ScheduledClient(BaseRestClient):
|
||||
Returns:
|
||||
ScheduledWorkflows: The requested scheduled workflow instance.
|
||||
"""
|
||||
return self._run_async_from_sync(self.aio_get, scheduled_id)
|
||||
return run_async_from_sync(self.aio_get, scheduled_id)
|
||||
|
||||
@@ -4,6 +4,7 @@ from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRe
|
||||
from hatchet_sdk.clients.rest.models.worker import Worker
|
||||
from hatchet_sdk.clients.rest.models.worker_list import WorkerList
|
||||
from hatchet_sdk.clients.v1.api_client import BaseRestClient
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
|
||||
|
||||
class WorkersClient(BaseRestClient):
|
||||
@@ -15,7 +16,7 @@ class WorkersClient(BaseRestClient):
|
||||
return await self._wa(client).worker_get(worker_id)
|
||||
|
||||
def get(self, worker_id: str) -> Worker:
|
||||
return self._run_async_from_sync(self.aio_get, worker_id)
|
||||
return run_async_from_sync(self.aio_get, worker_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -28,7 +29,7 @@ class WorkersClient(BaseRestClient):
|
||||
def list(
|
||||
self,
|
||||
) -> WorkerList:
|
||||
return self._run_async_from_sync(self.aio_list)
|
||||
return run_async_from_sync(self.aio_list)
|
||||
|
||||
async def aio_update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
|
||||
async with self.client() as client:
|
||||
@@ -38,4 +39,4 @@ class WorkersClient(BaseRestClient):
|
||||
)
|
||||
|
||||
def update(self, worker_id: str, opts: UpdateWorkerRequest) -> Worker:
|
||||
return self._run_async_from_sync(self.aio_update, worker_id, opts)
|
||||
return run_async_from_sync(self.aio_update, worker_id, opts)
|
||||
|
||||
@@ -5,6 +5,7 @@ from hatchet_sdk.clients.rest.models.workflow import Workflow
|
||||
from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList
|
||||
from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion
|
||||
from hatchet_sdk.clients.v1.api_client import BaseRestClient
|
||||
from hatchet_sdk.utils.aio_utils import run_async_from_sync
|
||||
|
||||
|
||||
class WorkflowsClient(BaseRestClient):
|
||||
@@ -19,7 +20,7 @@ class WorkflowsClient(BaseRestClient):
|
||||
return await self._wa(client).workflow_get(workflow_id)
|
||||
|
||||
def get(self, workflow_id: str) -> Workflow:
|
||||
return self._run_async_from_sync(self.aio_get, workflow_id)
|
||||
return run_async_from_sync(self.aio_get, workflow_id)
|
||||
|
||||
async def aio_list(
|
||||
self,
|
||||
@@ -41,7 +42,7 @@ class WorkflowsClient(BaseRestClient):
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> WorkflowList:
|
||||
return self._run_async_from_sync(self.aio_list, workflow_name, limit, offset)
|
||||
return run_async_from_sync(self.aio_list, workflow_name, limit, offset)
|
||||
|
||||
async def aio_get_version(
|
||||
self, workflow_id: str, version: str | None = None
|
||||
@@ -52,4 +53,4 @@ class WorkflowsClient(BaseRestClient):
|
||||
def get_version(
|
||||
self, workflow_id: str, version: str | None = None
|
||||
) -> WorkflowVersion:
|
||||
return self._run_async_from_sync(self.aio_get_version, workflow_id, version)
|
||||
return run_async_from_sync(self.aio_get_version, workflow_id, version)
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Coroutine, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
Y = TypeVar("Y")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
|
||||
@@ -16,3 +23,38 @@ def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
|
||||
return None
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def _run_async_function_do_not_use_directly(
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_async_from_sync(
|
||||
async_func: Callable[P, Coroutine[Y, S, R]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
else:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
lambda: _run_async_function_do_not_use_directly(
|
||||
async_func, *args, **kwargs
|
||||
)
|
||||
)
|
||||
return future.result()
|
||||
|
||||
Reference in New Issue
Block a user