mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-03-19 03:00:46 -05:00
* feat: adds support for multiple slot types, primarily motivated by durable slots --------- Co-authored-by: mrkaye97 <mrkaye97@gmail.com>
128 lines
3.5 KiB
Python
import os
|
|
from typing import Any, cast
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
from hatchet_sdk.config import ClientConfig
|
|
from hatchet_sdk.runnables.workflow import BaseWorkflow
|
|
from hatchet_sdk.utils.slots import resolve_worker_slot_config
|
|
from hatchet_sdk.worker.slot_types import SlotType
|
|
|
|
|
|
def test_client_initialization_from_defaults() -> None:
    """A ClientConfig can be constructed without any explicit arguments."""
    config = ClientConfig()

    assert isinstance(config, ClientConfig)
|
|
|
|
|
|
def test_client_host_port_overrides() -> None:
    """An explicit ``host_port`` argument overrides only that field.

    The override must show up on the configured instance and must not change
    its ``server_url``. A default-constructed config must report neither the
    overridden value as its ``host_port`` nor as its ``server_url``.
    """
    host_port = "foo:8080"

    with_host_port = ClientConfig(host_port=host_port)
    # Build the default baseline once and compare everything against it,
    # instead of constructing a fresh default config for every assertion.
    default_config = ClientConfig()

    assert with_host_port.host_port == host_port
    assert with_host_port.server_url == default_config.server_url

    assert default_config.host_port != host_port
    assert default_config.server_url != host_port
|
|
|
|
|
|
def test_client_host_port_override_when_env_var() -> None:
    """HATCHET_CLIENT_HOST_PORT in the environment sets ``host_port`` only."""
    expected_host_port = "foo:8080"
    patched_env = {"HATCHET_CLIENT_HOST_PORT": expected_host_port}

    with mock.patch.dict(os.environ, patched_env):
        from_env = ClientConfig()

        assert from_env.host_port == expected_host_port
        # server_url must be unaffected: it should match a second config
        # built under the same patched environment.
        assert from_env.server_url == ClientConfig().server_url
|
|
|
|
|
|
def test_client_server_url_override_when_env_var() -> None:
    """HATCHET_CLIENT_SERVER_URL in the environment sets ``server_url`` only."""
    expected_server_url = "foobaz:8080"
    patched_env = {"HATCHET_CLIENT_SERVER_URL": expected_server_url}

    with mock.patch.dict(os.environ, patched_env):
        from_env = ClientConfig()

        assert from_env.server_url == expected_server_url
        # host_port must be unaffected: it should match a second config
        # built under the same patched environment.
        assert from_env.host_port == ClientConfig().host_port
|
|
|
|
|
|
def test_resolve_slot_config_no_durable() -> None:
    """With nothing configured and no workflows, expect only 100 default slots."""
    expected = {SlotType.DEFAULT: 100}

    resolved = resolve_worker_slot_config(
        workflows=None,
        durable_slots=None,
        slots=None,
        slot_config=None,
    )

    assert resolved == expected
|
|
|
|
|
|
def test_resolve_slot_config_only_durable() -> None:
    """A workflow whose sole task is durable yields only 1000 durable slots."""

    class DurableTaskStub:
        is_durable = True
        slot_requests: dict[str, int] = {"durable": 1}

    class WorkflowStub:
        tasks = [DurableTaskStub()]

    # The stubs expose only ``tasks``, ``is_durable`` and ``slot_requests``;
    # cast so the type checker accepts them where real workflows are expected.
    workflows = cast(list[BaseWorkflow[Any]], [WorkflowStub()])
    expected = {SlotType.DURABLE: 1000}

    resolved = resolve_worker_slot_config(
        workflows=workflows,
        durable_slots=None,
        slots=None,
        slot_config=None,
    )

    assert resolved == expected
|
|
|
|
|
|
def test_resolve_slot_config_mixed() -> None:
    """Mixing default and durable tasks provisions both slot pools."""

    class DefaultTaskStub:
        is_durable = False
        slot_requests: dict[str, int] = {"default": 1}

    class DurableTaskStub:
        is_durable = True
        slot_requests: dict[str, int] = {"durable": 1}

    class WorkflowStub:
        tasks = [DefaultTaskStub(), DurableTaskStub()]

    # Duck-typed stand-in for a real workflow; cast for the type checker only.
    workflows = cast(list[BaseWorkflow[Any]], [WorkflowStub()])
    expected = {SlotType.DEFAULT: 100, SlotType.DURABLE: 1000}

    resolved = resolve_worker_slot_config(
        workflows=workflows,
        durable_slots=None,
        slots=None,
        slot_config=None,
    )

    assert resolved == expected
|
|
|
|
|
|
def test_resolve_slot_config_custom_type_raises_when_missing() -> None:
    """A task requesting an unconfigured custom slot type must be rejected.

    The ValueError message is expected to mention the missing type ("gpu").
    """

    class GpuTask:
        is_durable = False
        slot_requests: dict[str, int] = {"gpu": 1}

    class WorkflowStub:
        tasks = [GpuTask()]

    workflows = cast(list[BaseWorkflow[Any]], [WorkflowStub()])

    with pytest.raises(ValueError, match="gpu"):
        resolve_worker_slot_config(
            workflows=workflows,
            durable_slots=None,
            slots=None,
            slot_config=None,
        )
|
|
|
|
|
|
def test_resolve_slot_config_custom_type_passes_when_configured() -> None:
    """An explicitly configured custom slot type is returned as given."""

    class GpuTaskStub:
        is_durable = False
        slot_requests: dict[str, int] = {"gpu": 1}

    class WorkflowStub:
        tasks = [GpuTaskStub()]

    workflows = cast(list[BaseWorkflow[Any]], [WorkflowStub()])

    resolved = resolve_worker_slot_config(
        workflows=workflows,
        durable_slots=None,
        slots=None,
        slot_config={"gpu": 4},
    )

    # Compare against a separate literal rather than the dict passed in, so the
    # assertion cannot pass trivially if the same object is returned.
    assert resolved == {"gpu": 4}
|