feat: test for dict inputs

This commit is contained in:
mrkaye97
2025-12-15 17:10:12 -05:00
parent 5b844565f6
commit 44f4c97a93
4 changed files with 74 additions and 1 deletions

View File

@@ -0,0 +1,55 @@
import pytest
from examples.dict_input.worker import Output, say_hello_unsafely
@pytest.mark.asyncio(loop_scope="session")
async def test_dict_input() -> None:
input = {"name": "Hatchet"}
x1 = say_hello_unsafely.run(input)
x2 = await say_hello_unsafely.aio_run(input)
x3 = say_hello_unsafely.run_many([say_hello_unsafely.create_bulk_run_item(input)])[
0
]
x4 = (
await say_hello_unsafely.aio_run_many(
[say_hello_unsafely.create_bulk_run_item(input)]
)
)[0]
x5 = say_hello_unsafely.run_no_wait(input).result()
x6 = (await say_hello_unsafely.aio_run_no_wait(input)).result()
x7 = [
x.result()
for x in say_hello_unsafely.run_many_no_wait(
[say_hello_unsafely.create_bulk_run_item(input)]
)
][0]
x8 = [
x.result()
for x in await say_hello_unsafely.aio_run_many_no_wait(
[say_hello_unsafely.create_bulk_run_item(input)]
)
][0]
x9 = await say_hello_unsafely.run_no_wait(input).aio_result()
x10 = await (await say_hello_unsafely.aio_run_no_wait(input)).aio_result()
x11 = [
await x.aio_result()
for x in say_hello_unsafely.run_many_no_wait(
[say_hello_unsafely.create_bulk_run_item(input)]
)
][0]
x12 = [
await x.aio_result()
for x in await say_hello_unsafely.aio_run_many_no_wait(
[say_hello_unsafely.create_bulk_run_item(input)]
)
][0]
assert all(
x == Output(message="Hello, Hatchet!")
for x in [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12]
)

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel
from hatchet_sdk import Context, Hatchet
class Output(BaseModel):
message: str
hatchet = Hatchet(debug=True)
@hatchet.task(input_validator=dict)
def say_hello_unsafely(input: dict[str, str], _c: Context) -> Output:
name = input["name"] # untyped
return Output(message=f"Hello, {name}!")

View File

@@ -30,6 +30,7 @@ from examples.dependency_injection.worker import (
durable_sync_task_with_dependencies,
sync_task_with_dependencies,
)
from examples.dict_input.worker import say_hello_unsafely
from examples.durable.worker import durable_workflow, wait_for_sleep_twice
from examples.events.worker import event_workflow
from examples.fanout.worker import child_wf, parent_wf
@@ -94,6 +95,7 @@ def main() -> None:
durable_async_task_with_dependencies,
durable_sync_task_with_dependencies,
say_hello,
say_hello_unsafely,
],
lifespan=lifespan,
)

View File

@@ -61,7 +61,7 @@ class ConcurrencyExpression(BaseModel):
TWorkflowInput = TypeVar(
"TWorkflowInput", bound=BaseModel | DataclassInstance | dict[str, Any] | None
"TWorkflowInput", bound=BaseModel | DataclassInstance | dict[str, Any]
)