Feat: Python task unit tests (#1990)

* feat: add mock run methods for tasks

* feat: docs

* feat: first pass at unit tests

* cleanup: split out tests

* feat: pass lifespan through

* fix: rm comment

* drive by: retry on 404 to help with races

* chore: changelog

* chore: ver

* feat: improve logging everywhere

* chore: changelog

* fix: rm print cruft

* feat: print statement linter

* feat: helper for getting result of a standalone

* feat: docs for mock run

* feat: add task run getter

* feat: propagate additional metadata properly

* chore: gen

* fix: date

* chore: gen

* feat: return exceptions

* chore: gen

* chore: changelog

* feat: tests + gen again

* fix: rm print cruft
This commit is contained in:
Matt Kaye
2025-07-17 13:54:40 -04:00
committed by GitHub
parent b4544f170e
commit f1f276f6dc
88 changed files with 1809 additions and 413 deletions
@@ -77,7 +77,7 @@ async def test_bulk_replay(hatchet: Hatchet) -> None:
)
)
await asyncio.sleep(5)
await asyncio.sleep(10)
runs = await hatchet.runs.aio_list(
workflow_ids=workflow_ids,
@@ -1,41 +0,0 @@
import pytest
from examples.concurrency_limit.worker import WorkflowInput, concurrency_limit_workflow
from hatchet_sdk.workflow_run import WorkflowRunRef
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.skip(reason="The timing for this test is not reliable")
async def test_run() -> None:
    """Launch several concurrency-limited runs and count successes vs. cancellations."""
    total_runs = 6
    refs: list[WorkflowRunRef] = []

    # Kick off every run without waiting for any of them to finish.
    for idx in range(1, total_runs + 1):
        refs.append(
            concurrency_limit_workflow.run_no_wait(
                WorkflowInput(run=idx, group_key=str(idx))
            )
        )

    succeeded = []
    cancelled = []

    # Await each result, sorting runs into succeeded vs. cancelled-by-limit.
    for idx, ref in enumerate(refs, start=1):
        try:
            succeeded.append((idx, await ref.aio_result()))
        except Exception as exc:
            if "CANCELLED_BY_CONCURRENCY_LIMIT" not in str(exc):
                raise  # Re-raise if it's an unexpected error
            cancelled.append((idx, str(exc)))

    # Check that we have the correct number of successful and cancelled runs
    assert (
        len(succeeded) == 5
    ), f"Expected 5 successful runs, got {len(succeeded)}"
    assert (
        len(cancelled) == 1
    ), f"Expected 1 cancelled run, got {len(cancelled)}"
+42 -2
View File
@@ -1,10 +1,50 @@
import asyncio
from uuid import uuid4
import pytest
from examples.fanout.worker import ParentInput, parent_wf
from hatchet_sdk import Hatchet, TriggerWorkflowOptions
@pytest.mark.asyncio(loop_scope="session")
async def test_run() -> None:
result = await parent_wf.aio_run(ParentInput(n=2))
async def test_run(hatchet: Hatchet) -> None:
    """Fan out a parent workflow with n=2 and verify both child results come back."""
    run_ref = await parent_wf.aio_run_no_wait(ParentInput(n=2))
    output = await run_ref.aio_result()

    assert len(output["spawn"]["results"]) == 2
@pytest.mark.asyncio(loop_scope="session")
async def test_additional_metadata_propagation(hatchet: Hatchet) -> None:
    """Verify additional_metadata set on a parent run propagates to its child runs."""
    test_run_id = uuid4().hex

    ref = await parent_wf.aio_run_no_wait(
        ParentInput(n=2),
        options=TriggerWorkflowOptions(
            additional_metadata={"test_run_id": test_run_id}
        ),
    )

    await ref.aio_result()
    # Give the backend a moment to index the finished runs before listing them.
    await asyncio.sleep(1)

    runs = await hatchet.runs.aio_list(
        parent_task_external_id=ref.workflow_run_id,
        additional_metadata={"test_run_id": test_run_id},
    )

    assert runs.rows

    # Assert that the additional metadata is propagated to the child runs.
    # (Was a bare string literal — a no-op statement — now a real comment.)
    for run in runs.rows:
        assert run.additional_metadata
        assert run.additional_metadata["test_run_id"] == test_run_id

        assert run.children

        for child in run.children:
            assert child.additional_metadata
            assert child.additional_metadata["test_run_id"] == test_run_id
+3 -3
View File
@@ -34,7 +34,7 @@ async def spawn(input: ParentInput, ctx: Context) -> dict[str, Any]:
),
)
for i in range(input.n)
]
],
)
print(f"results {result}")
@@ -47,13 +47,13 @@ async def spawn(input: ParentInput, ctx: Context) -> dict[str, Any]:
# > FanoutChild
@child_wf.task()
def process(input: ChildInput, ctx: Context) -> dict[str, str]:
async def process(input: ChildInput, ctx: Context) -> dict[str, str]:
    """Child task: echo the input value back as its status."""
    status = input.a
    print(f"child process {status}")
    return {"status": status}
@child_wf.task(parents=[process])
def process2(input: ChildInput, ctx: Context) -> dict[str, str]:
async def process2(input: ChildInput, ctx: Context) -> dict[str, str]:
process_output = ctx.task_output(process)
a = process_output["status"]
@@ -1,4 +1,10 @@
import asyncio
from uuid import uuid4
import pytest
from examples.fanout_sync.worker import ParentInput, sync_fanout_parent
from hatchet_sdk import Hatchet, TriggerWorkflowOptions
def test_run() -> None:
@@ -7,3 +13,35 @@ def test_run() -> None:
result = sync_fanout_parent.run(ParentInput(n=N))
assert len(result["spawn"]["results"]) == N
@pytest.mark.asyncio(loop_scope="session")
async def test_additional_metadata_propagation_sync(hatchet: Hatchet) -> None:
    """Verify additional_metadata propagates to child runs of the sync fanout parent."""
    test_run_id = uuid4().hex

    ref = await sync_fanout_parent.aio_run_no_wait(
        ParentInput(n=2),
        options=TriggerWorkflowOptions(
            additional_metadata={"test_run_id": test_run_id}
        ),
    )

    await ref.aio_result()
    # Give the backend a moment to index the finished runs before listing them.
    await asyncio.sleep(1)

    runs = await hatchet.runs.aio_list(
        parent_task_external_id=ref.workflow_run_id,
        additional_metadata={"test_run_id": test_run_id},
    )

    assert runs.rows

    # Assert that the additional metadata is propagated to the child runs.
    # (Was a bare string literal — a no-op statement — now a real comment.)
    for run in runs.rows:
        assert run.additional_metadata
        assert run.additional_metadata["test_run_id"] == test_run_id

        assert run.children

        for child in run.children:
            assert child.additional_metadata
            assert child.additional_metadata["test_run_id"] == test_run_id
@@ -47,6 +47,14 @@ def process(input: ChildInput, ctx: Context) -> dict[str, str]:
return {"status": "success " + input.a}
@sync_fanout_child.task(parents=[process])
def process2(input: ChildInput, ctx: Context) -> dict[str, str]:
    """Second child task: append '2' to the status produced by `process`."""
    upstream = ctx.task_output(process)
    return {"status2": upstream["status"] + "2"}
def main() -> None:
worker = hatchet.worker(
"sync-fanout-worker",
@@ -1,27 +0,0 @@
import asyncio
import time
import pytest
from examples.rate_limit.worker import rate_limit_workflow
@pytest.mark.skip(reason="The timing for this test is not reliable")
@pytest.mark.asyncio(loop_scope="session")
async def test_run() -> None:
    """Three rate-limited runs should take a bit over a second end to end."""
    refs = [rate_limit_workflow.run_no_wait() for _ in range(3)]

    started = time.time()
    await asyncio.gather(*(ref.aio_result() for ref in refs))
    elapsed = time.time() - started

    assert (
        1 <= elapsed <= 5
    ), f"Expected runtime to be a bit more than 1 seconds, but it took {elapsed:.2f} seconds"
@@ -0,0 +1,40 @@
import asyncio
import pytest
from examples.return_exceptions.worker import Input, return_exceptions_task
@pytest.mark.asyncio(loop_scope="session")
async def test_return_exceptions_async() -> None:
    """With return_exceptions=True, failures come back in-place as Exception objects."""
    outcomes = await return_exceptions_task.aio_run_many(
        [
            return_exceptions_task.create_bulk_run_item(input=Input(index=i))
            for i in range(10)
        ],
        return_exceptions=True,
    )

    # The worker raises on even indices and succeeds on odd ones.
    for index, outcome in enumerate(outcomes):
        if index % 2 != 0:
            assert outcome == {"message": "this is a successful task."}
        else:
            assert isinstance(outcome, Exception)
            assert f"error in task with index {index}" in str(outcome)
def test_return_exceptions_sync() -> None:
    """Sync variant: run_many with return_exceptions=True returns errors in-place."""
    outcomes = return_exceptions_task.run_many(
        [
            return_exceptions_task.create_bulk_run_item(input=Input(index=i))
            for i in range(10)
        ],
        return_exceptions=True,
    )

    # The worker raises on even indices and succeeds on odd ones.
    for index, outcome in enumerate(outcomes):
        if index % 2 != 0:
            assert outcome == {"message": "this is a successful task."}
        else:
            assert isinstance(outcome, Exception)
            assert f"error in task with index {index}" in str(outcome)
@@ -0,0 +1,17 @@
from pydantic import BaseModel
from hatchet_sdk import Context, EmptyModel, Hatchet
# Module-level Hatchet client; its task decorator registers the worker task below.
hatchet = Hatchet()
class Input(EmptyModel):
    """Input for return_exceptions_task."""

    # Even values make the task raise; odd values succeed.
    index: int
@hatchet.task(input_validator=Input)
async def return_exceptions_task(input: Input, ctx: Context) -> dict[str, str]:
    """Raise on even indices, succeed on odd ones (exercises return_exceptions)."""
    if input.index % 2 != 0:
        return {"message": "this is a successful task."}

    raise ValueError(f"error in task with index {input.index}")
+2 -3
View File
@@ -16,10 +16,10 @@ timeout_wf = hatchet.workflow(
# > ExecutionTimeout
# 👀 Specify an execution timeout on a task
@timeout_wf.task(
execution_timeout=timedelta(seconds=4), schedule_timeout=timedelta(minutes=10)
execution_timeout=timedelta(seconds=5), schedule_timeout=timedelta(minutes=10)
)
def timeout_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
time.sleep(5)
time.sleep(30)
return {"status": "success"}
@@ -31,7 +31,6 @@ refresh_timeout_wf = hatchet.workflow(name="RefreshTimeoutWorkflow")
# > RefreshTimeout
@refresh_timeout_wf.task(execution_timeout=timedelta(seconds=4))
def refresh_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
ctx.refresh_timeout(timedelta(seconds=10))
time.sleep(5)
@@ -0,0 +1,96 @@
import pytest
from examples.unit_testing.workflows import (
Lifespan,
UnitTestInput,
UnitTestOutput,
async_complex_workflow,
async_simple_workflow,
async_standalone,
durable_async_complex_workflow,
durable_async_simple_workflow,
durable_async_standalone,
durable_sync_complex_workflow,
durable_sync_simple_workflow,
durable_sync_standalone,
start,
sync_complex_workflow,
sync_simple_workflow,
sync_standalone,
)
from hatchet_sdk import Task
@pytest.mark.parametrize(
    "func",
    [
        sync_standalone,
        durable_sync_standalone,
        sync_simple_workflow,
        durable_sync_simple_workflow,
        sync_complex_workflow,
        durable_sync_complex_workflow,
    ],
)
def test_simple_unit_sync(func: Task[UnitTestInput, UnitTestOutput]) -> None:
    """mock_run on each sync task should echo the input plus the mocked context."""
    test_input = UnitTestInput(key="test_key", number=42)
    meta = {"meta_key": "meta_value"}
    mock_lifespan = Lifespan(mock_db_url="sqlite:///:memory:")
    retries = 1

    expected = UnitTestOutput(
        key=test_input.key,
        number=test_input.number,
        additional_metadata=meta,
        retry_count=retries,
        mock_db_url=mock_lifespan.mock_db_url,
    )

    actual = func.mock_run(
        input=test_input,
        additional_metadata=meta,
        lifespan=mock_lifespan,
        retry_count=retries,
        # Parent output for the complex-workflow tasks; unused by the others.
        parent_outputs={start.name: expected.model_dump()},
    )

    assert actual == expected
@pytest.mark.parametrize(
    "func",
    [
        async_standalone,
        durable_async_standalone,
        async_simple_workflow,
        durable_async_simple_workflow,
        async_complex_workflow,
        durable_async_complex_workflow,
    ],
)
@pytest.mark.asyncio(loop_scope="session")
async def test_simple_unit_async(func: Task[UnitTestInput, UnitTestOutput]) -> None:
    """aio_mock_run on each async task should echo the input plus the mocked context."""
    test_input = UnitTestInput(key="test_key", number=42)
    meta = {"meta_key": "meta_value"}
    mock_lifespan = Lifespan(mock_db_url="sqlite:///:memory:")
    retries = 1

    expected = UnitTestOutput(
        key=test_input.key,
        number=test_input.number,
        additional_metadata=meta,
        retry_count=retries,
        mock_db_url=mock_lifespan.mock_db_url,
    )

    actual = await func.aio_mock_run(
        input=test_input,
        additional_metadata=meta,
        lifespan=mock_lifespan,
        retry_count=retries,
        # Parent output for the complex-workflow tasks; unused by the others.
        parent_outputs={start.name: expected.model_dump()},
    )

    assert actual == expected
@@ -0,0 +1,171 @@
from typing import cast
from pydantic import BaseModel
from hatchet_sdk import Context, DurableContext, EmptyModel, Hatchet
class UnitTestInput(BaseModel):
    """Input payload consumed by every unit-test task in this module."""

    key: str
    number: int
class Lifespan(BaseModel):
    """Mocked worker lifespan state made available to tasks via ctx.lifespan."""

    mock_db_url: str
class UnitTestOutput(UnitTestInput, Lifespan):
    """Echo of the input plus run context (metadata, retry count, lifespan URL)."""

    additional_metadata: dict[str, str]
    retry_count: int
# Module-level Hatchet client; its decorators register the tasks/workflows below.
hatchet = Hatchet()
@hatchet.task(input_validator=UnitTestInput)
def sync_standalone(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Sync standalone task: echo the input plus run context as a UnitTestOutput."""
    run_context = {
        "additional_metadata": ctx.additional_metadata,
        "retry_count": ctx.retry_count,
        "mock_db_url": cast(Lifespan, ctx.lifespan).mock_db_url,
    }
    return UnitTestOutput(key=input.key, number=input.number, **run_context)
@hatchet.task(input_validator=UnitTestInput)
async def async_standalone(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Async standalone task: echo the input plus run context as a UnitTestOutput."""
    lifespan_state = cast(Lifespan, ctx.lifespan)
    return UnitTestOutput(
        key=input.key,
        number=input.number,
        additional_metadata=ctx.additional_metadata,
        retry_count=ctx.retry_count,
        mock_db_url=lifespan_state.mock_db_url,
    )
@hatchet.durable_task(input_validator=UnitTestInput)
def durable_sync_standalone(input: UnitTestInput, ctx: DurableContext) -> UnitTestOutput:
    """Durable sync standalone task: echo the input plus run context."""
    return UnitTestOutput(
        **input.model_dump(),
        additional_metadata=ctx.additional_metadata,
        retry_count=ctx.retry_count,
        mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
    )
@hatchet.durable_task(input_validator=UnitTestInput)
async def durable_async_standalone(input: UnitTestInput, ctx: DurableContext) -> UnitTestOutput:
    """Durable async standalone task: echo the input plus run context."""
    run_context = {
        "additional_metadata": ctx.additional_metadata,
        "retry_count": ctx.retry_count,
        "mock_db_url": cast(Lifespan, ctx.lifespan).mock_db_url,
    }
    return UnitTestOutput(key=input.key, number=input.number, **run_context)
# Workflow whose tasks have no parents: each reads only its input and context.
simple_workflow = hatchet.workflow(
    name="simple-unit-test-workflow", input_validator=UnitTestInput
)
@simple_workflow.task()
def sync_simple_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Sync workflow task: echo the input plus run context as a UnitTestOutput."""
    lifespan_state = cast(Lifespan, ctx.lifespan)
    return UnitTestOutput(
        key=input.key,
        number=input.number,
        additional_metadata=ctx.additional_metadata,
        retry_count=ctx.retry_count,
        mock_db_url=lifespan_state.mock_db_url,
    )
@simple_workflow.task()
async def async_simple_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Async workflow task: echo the input plus run context as a UnitTestOutput."""
    run_context = {
        "additional_metadata": ctx.additional_metadata,
        "retry_count": ctx.retry_count,
        "mock_db_url": cast(Lifespan, ctx.lifespan).mock_db_url,
    }
    return UnitTestOutput(key=input.key, number=input.number, **run_context)
@simple_workflow.durable_task()
def durable_sync_simple_workflow(input: UnitTestInput, ctx: DurableContext) -> UnitTestOutput:
    """Durable sync workflow task: echo the input plus run context."""
    return UnitTestOutput(
        **input.model_dump(),
        additional_metadata=ctx.additional_metadata,
        retry_count=ctx.retry_count,
        mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
    )
@simple_workflow.durable_task()
async def durable_async_simple_workflow(input: UnitTestInput, ctx: DurableContext) -> UnitTestOutput:
    """Durable async workflow task: echo the input plus run context."""
    lifespan_state = cast(Lifespan, ctx.lifespan)
    return UnitTestOutput(
        key=input.key,
        number=input.number,
        additional_metadata=ctx.additional_metadata,
        retry_count=ctx.retry_count,
        mock_db_url=lifespan_state.mock_db_url,
    )
# Workflow with a `start` parent task whose output downstream tasks pass through.
complex_workflow = hatchet.workflow(
    name="complex-unit-test-workflow", input_validator=UnitTestInput
)
@complex_workflow.task()
async def start(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Root task of the complex workflow: echo the input plus run context."""
    run_context = {
        "additional_metadata": ctx.additional_metadata,
        "retry_count": ctx.retry_count,
        "mock_db_url": cast(Lifespan, ctx.lifespan).mock_db_url,
    }
    return UnitTestOutput(key=input.key, number=input.number, **run_context)
@complex_workflow.task(parents=[start])
def sync_complex_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Pass through the output of the `start` parent task unchanged."""
    parent_result = ctx.task_output(start)
    return parent_result
@complex_workflow.task(parents=[start])
async def async_complex_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
    """Pass through the output of the `start` parent task unchanged (async)."""
    parent_result = ctx.task_output(start)
    return parent_result
@complex_workflow.durable_task(parents=[start])
def durable_sync_complex_workflow(input: UnitTestInput, ctx: DurableContext) -> UnitTestOutput:
    """Pass through the output of the `start` parent task unchanged (durable sync)."""
    parent_result = ctx.task_output(start)
    return parent_result
@complex_workflow.durable_task(parents=[start])
async def durable_async_complex_workflow(input: UnitTestInput, ctx: DurableContext) -> UnitTestOutput:
    """Pass through the output of the `start` parent task unchanged (durable async)."""
    parent_result = ctx.task_output(start)
    return parent_result
+2
View File
@@ -23,6 +23,7 @@ from examples.lifespans.simple import lifespan, lifespan_task
from examples.logger.workflow import logging_workflow
from examples.non_retryable.worker import non_retryable_workflow
from examples.on_failure.worker import on_failure_wf, on_failure_wf_with_details
from examples.return_exceptions.worker import return_exceptions_task
from examples.simple.worker import simple, simple_durable
from examples.timeout.worker import refresh_timeout_wf, timeout_wf
from hatchet_sdk import Hatchet
@@ -65,6 +66,7 @@ def main() -> None:
bulk_replay_test_1,
bulk_replay_test_2,
bulk_replay_test_3,
return_exceptions_task,
],
lifespan=lifespan,
)