mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-04-23 02:34:48 -05:00
Feat: Python task unit tests (#1990)
* feat: add mock run methods for tasks * feat: docs * feat: first pass at unit tests * cleanup: split out tests * feat: pass lifespan through * fix: rm comment * drive by: retry on 404 to help with races * chore: changelog * chore: ver * feat: improve logging everywhere * chore: changelog * fix: rm print cruft * feat: print statement linter * feat: helper for getting result of a standalone * feat: docs for mock run * feat: add task run getter * feat: propagate additional metadata properly * chore: gen * fix: date * chore: gen * feat: return exceptions * chore: gen * chore: changelog * feat: tests + gen again * fix: rm print cruft
This commit is contained in:
@@ -77,7 +77,7 @@ async def test_bulk_replay(hatchet: Hatchet) -> None:
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(10)
|
||||
|
||||
runs = await hatchet.runs.aio_list(
|
||||
workflow_ids=workflow_ids,
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from examples.concurrency_limit.worker import WorkflowInput, concurrency_limit_workflow
|
||||
from hatchet_sdk.workflow_run import WorkflowRunRef
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.skip(reason="The timing for this test is not reliable")
|
||||
async def test_run() -> None:
|
||||
num_runs = 6
|
||||
runs: list[WorkflowRunRef] = []
|
||||
|
||||
# Start all runs
|
||||
for i in range(1, num_runs + 1):
|
||||
run = concurrency_limit_workflow.run_no_wait(
|
||||
WorkflowInput(run=i, group_key=str(i))
|
||||
)
|
||||
runs.append(run)
|
||||
|
||||
# Wait for all results
|
||||
successful_runs = []
|
||||
cancelled_runs = []
|
||||
|
||||
# Process each run individually
|
||||
for i, run in enumerate(runs, start=1):
|
||||
try:
|
||||
result = await run.aio_result()
|
||||
successful_runs.append((i, result))
|
||||
except Exception as e:
|
||||
if "CANCELLED_BY_CONCURRENCY_LIMIT" in str(e):
|
||||
cancelled_runs.append((i, str(e)))
|
||||
else:
|
||||
raise # Re-raise if it's an unexpected error
|
||||
|
||||
# Check that we have the correct number of successful and cancelled runs
|
||||
assert (
|
||||
len(successful_runs) == 5
|
||||
), f"Expected 5 successful runs, got {len(successful_runs)}"
|
||||
assert (
|
||||
len(cancelled_runs) == 1
|
||||
), f"Expected 1 cancelled run, got {len(cancelled_runs)}"
|
||||
@@ -1,10 +1,50 @@
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from examples.fanout.worker import ParentInput, parent_wf
|
||||
from hatchet_sdk import Hatchet, TriggerWorkflowOptions
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run() -> None:
|
||||
result = await parent_wf.aio_run(ParentInput(n=2))
|
||||
async def test_run(hatchet: Hatchet) -> None:
|
||||
ref = await parent_wf.aio_run_no_wait(
|
||||
ParentInput(n=2),
|
||||
)
|
||||
|
||||
result = await ref.aio_result()
|
||||
|
||||
assert len(result["spawn"]["results"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_additional_metadata_propagation(hatchet: Hatchet) -> None:
|
||||
test_run_id = uuid4().hex
|
||||
|
||||
ref = await parent_wf.aio_run_no_wait(
|
||||
ParentInput(n=2),
|
||||
options=TriggerWorkflowOptions(
|
||||
additional_metadata={"test_run_id": test_run_id}
|
||||
),
|
||||
)
|
||||
|
||||
await ref.aio_result()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
runs = await hatchet.runs.aio_list(
|
||||
parent_task_external_id=ref.workflow_run_id,
|
||||
additional_metadata={"test_run_id": test_run_id},
|
||||
)
|
||||
|
||||
assert runs.rows
|
||||
|
||||
"""Assert that the additional metadata is propagated to the child runs."""
|
||||
for run in runs.rows:
|
||||
assert run.additional_metadata
|
||||
assert run.additional_metadata["test_run_id"] == test_run_id
|
||||
|
||||
assert run.children
|
||||
for child in run.children:
|
||||
assert child.additional_metadata
|
||||
assert child.additional_metadata["test_run_id"] == test_run_id
|
||||
|
||||
@@ -34,7 +34,7 @@ async def spawn(input: ParentInput, ctx: Context) -> dict[str, Any]:
|
||||
),
|
||||
)
|
||||
for i in range(input.n)
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
print(f"results {result}")
|
||||
@@ -47,13 +47,13 @@ async def spawn(input: ParentInput, ctx: Context) -> dict[str, Any]:
|
||||
|
||||
# > FanoutChild
|
||||
@child_wf.task()
|
||||
def process(input: ChildInput, ctx: Context) -> dict[str, str]:
|
||||
async def process(input: ChildInput, ctx: Context) -> dict[str, str]:
|
||||
print(f"child process {input.a}")
|
||||
return {"status": input.a}
|
||||
|
||||
|
||||
@child_wf.task(parents=[process])
|
||||
def process2(input: ChildInput, ctx: Context) -> dict[str, str]:
|
||||
async def process2(input: ChildInput, ctx: Context) -> dict[str, str]:
|
||||
process_output = ctx.task_output(process)
|
||||
a = process_output["status"]
|
||||
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from examples.fanout_sync.worker import ParentInput, sync_fanout_parent
|
||||
from hatchet_sdk import Hatchet, TriggerWorkflowOptions
|
||||
|
||||
|
||||
def test_run() -> None:
|
||||
@@ -7,3 +13,35 @@ def test_run() -> None:
|
||||
result = sync_fanout_parent.run(ParentInput(n=N))
|
||||
|
||||
assert len(result["spawn"]["results"]) == N
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_additional_metadata_propagation_sync(hatchet: Hatchet) -> None:
|
||||
test_run_id = uuid4().hex
|
||||
|
||||
ref = await sync_fanout_parent.aio_run_no_wait(
|
||||
ParentInput(n=2),
|
||||
options=TriggerWorkflowOptions(
|
||||
additional_metadata={"test_run_id": test_run_id}
|
||||
),
|
||||
)
|
||||
|
||||
await ref.aio_result()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
runs = await hatchet.runs.aio_list(
|
||||
parent_task_external_id=ref.workflow_run_id,
|
||||
additional_metadata={"test_run_id": test_run_id},
|
||||
)
|
||||
|
||||
assert runs.rows
|
||||
|
||||
"""Assert that the additional metadata is propagated to the child runs."""
|
||||
for run in runs.rows:
|
||||
assert run.additional_metadata
|
||||
assert run.additional_metadata["test_run_id"] == test_run_id
|
||||
|
||||
assert run.children
|
||||
for child in run.children:
|
||||
assert child.additional_metadata
|
||||
assert child.additional_metadata["test_run_id"] == test_run_id
|
||||
|
||||
@@ -47,6 +47,14 @@ def process(input: ChildInput, ctx: Context) -> dict[str, str]:
|
||||
return {"status": "success " + input.a}
|
||||
|
||||
|
||||
@sync_fanout_child.task(parents=[process])
|
||||
def process2(input: ChildInput, ctx: Context) -> dict[str, str]:
|
||||
process_output = ctx.task_output(process)
|
||||
a = process_output["status"]
|
||||
|
||||
return {"status2": a + "2"}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
worker = hatchet.worker(
|
||||
"sync-fanout-worker",
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from examples.rate_limit.worker import rate_limit_workflow
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The timing for this test is not reliable")
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run() -> None:
|
||||
|
||||
run1 = rate_limit_workflow.run_no_wait()
|
||||
run2 = rate_limit_workflow.run_no_wait()
|
||||
run3 = rate_limit_workflow.run_no_wait()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await asyncio.gather(run1.aio_result(), run2.aio_result(), run3.aio_result())
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
total_time = end_time - start_time
|
||||
|
||||
assert (
|
||||
1 <= total_time <= 5
|
||||
), f"Expected runtime to be a bit more than 1 seconds, but it took {total_time:.2f} seconds"
|
||||
@@ -0,0 +1,40 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from examples.return_exceptions.worker import Input, return_exceptions_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_return_exceptions_async() -> None:
|
||||
results = await return_exceptions_task.aio_run_many(
|
||||
[
|
||||
return_exceptions_task.create_bulk_run_item(input=Input(index=i))
|
||||
for i in range(10)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if i % 2 == 0:
|
||||
assert isinstance(result, Exception)
|
||||
assert f"error in task with index {i}" in str(result)
|
||||
else:
|
||||
assert result == {"message": "this is a successful task."}
|
||||
|
||||
|
||||
def test_return_exceptions_sync() -> None:
|
||||
results = return_exceptions_task.run_many(
|
||||
[
|
||||
return_exceptions_task.create_bulk_run_item(input=Input(index=i))
|
||||
for i in range(10)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if i % 2 == 0:
|
||||
assert isinstance(result, Exception)
|
||||
assert f"error in task with index {i}" in str(result)
|
||||
else:
|
||||
assert result == {"message": "this is a successful task."}
|
||||
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from hatchet_sdk import Context, EmptyModel, Hatchet
|
||||
|
||||
hatchet = Hatchet()
|
||||
|
||||
|
||||
class Input(EmptyModel):
|
||||
index: int
|
||||
|
||||
|
||||
@hatchet.task(input_validator=Input)
|
||||
async def return_exceptions_task(input: Input, ctx: Context) -> dict[str, str]:
|
||||
if input.index % 2 == 0:
|
||||
raise ValueError(f"error in task with index {input.index}")
|
||||
|
||||
return {"message": "this is a successful task."}
|
||||
@@ -16,10 +16,10 @@ timeout_wf = hatchet.workflow(
|
||||
# > ExecutionTimeout
|
||||
# 👀 Specify an execution timeout on a task
|
||||
@timeout_wf.task(
|
||||
execution_timeout=timedelta(seconds=4), schedule_timeout=timedelta(minutes=10)
|
||||
execution_timeout=timedelta(seconds=5), schedule_timeout=timedelta(minutes=10)
|
||||
)
|
||||
def timeout_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
time.sleep(5)
|
||||
time.sleep(30)
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ refresh_timeout_wf = hatchet.workflow(name="RefreshTimeoutWorkflow")
|
||||
# > RefreshTimeout
|
||||
@refresh_timeout_wf.task(execution_timeout=timedelta(seconds=4))
|
||||
def refresh_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
|
||||
|
||||
ctx.refresh_timeout(timedelta(seconds=10))
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
|
||||
from examples.unit_testing.workflows import (
|
||||
Lifespan,
|
||||
UnitTestInput,
|
||||
UnitTestOutput,
|
||||
async_complex_workflow,
|
||||
async_simple_workflow,
|
||||
async_standalone,
|
||||
durable_async_complex_workflow,
|
||||
durable_async_simple_workflow,
|
||||
durable_async_standalone,
|
||||
durable_sync_complex_workflow,
|
||||
durable_sync_simple_workflow,
|
||||
durable_sync_standalone,
|
||||
start,
|
||||
sync_complex_workflow,
|
||||
sync_simple_workflow,
|
||||
sync_standalone,
|
||||
)
|
||||
from hatchet_sdk import Task
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
sync_standalone,
|
||||
durable_sync_standalone,
|
||||
sync_simple_workflow,
|
||||
durable_sync_simple_workflow,
|
||||
sync_complex_workflow,
|
||||
durable_sync_complex_workflow,
|
||||
],
|
||||
)
|
||||
def test_simple_unit_sync(func: Task[UnitTestInput, UnitTestOutput]) -> None:
|
||||
input = UnitTestInput(key="test_key", number=42)
|
||||
additional_metadata = {"meta_key": "meta_value"}
|
||||
lifespan = Lifespan(mock_db_url="sqlite:///:memory:")
|
||||
retry_count = 1
|
||||
|
||||
expected_output = UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=additional_metadata,
|
||||
retry_count=retry_count,
|
||||
mock_db_url=lifespan.mock_db_url,
|
||||
)
|
||||
|
||||
assert (
|
||||
func.mock_run(
|
||||
input=input,
|
||||
additional_metadata=additional_metadata,
|
||||
lifespan=lifespan,
|
||||
retry_count=retry_count,
|
||||
parent_outputs={start.name: expected_output.model_dump()},
|
||||
)
|
||||
== expected_output
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
async_standalone,
|
||||
durable_async_standalone,
|
||||
async_simple_workflow,
|
||||
durable_async_simple_workflow,
|
||||
async_complex_workflow,
|
||||
durable_async_complex_workflow,
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_simple_unit_async(func: Task[UnitTestInput, UnitTestOutput]) -> None:
|
||||
input = UnitTestInput(key="test_key", number=42)
|
||||
additional_metadata = {"meta_key": "meta_value"}
|
||||
lifespan = Lifespan(mock_db_url="sqlite:///:memory:")
|
||||
retry_count = 1
|
||||
|
||||
expected_output = UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=additional_metadata,
|
||||
retry_count=retry_count,
|
||||
mock_db_url=lifespan.mock_db_url,
|
||||
)
|
||||
|
||||
assert (
|
||||
await func.aio_mock_run(
|
||||
input=input,
|
||||
additional_metadata=additional_metadata,
|
||||
lifespan=lifespan,
|
||||
retry_count=retry_count,
|
||||
parent_outputs={start.name: expected_output.model_dump()},
|
||||
)
|
||||
== expected_output
|
||||
)
|
||||
@@ -0,0 +1,171 @@
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from hatchet_sdk import Context, DurableContext, EmptyModel, Hatchet
|
||||
|
||||
|
||||
class UnitTestInput(BaseModel):
|
||||
key: str
|
||||
number: int
|
||||
|
||||
|
||||
class Lifespan(BaseModel):
|
||||
mock_db_url: str
|
||||
|
||||
|
||||
class UnitTestOutput(UnitTestInput, Lifespan):
|
||||
additional_metadata: dict[str, str]
|
||||
retry_count: int
|
||||
|
||||
|
||||
hatchet = Hatchet()
|
||||
|
||||
|
||||
@hatchet.task(input_validator=UnitTestInput)
|
||||
def sync_standalone(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@hatchet.task(input_validator=UnitTestInput)
|
||||
async def async_standalone(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@hatchet.durable_task(input_validator=UnitTestInput)
|
||||
def durable_sync_standalone(
|
||||
input: UnitTestInput, ctx: DurableContext
|
||||
) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@hatchet.durable_task(input_validator=UnitTestInput)
|
||||
async def durable_async_standalone(
|
||||
input: UnitTestInput, ctx: DurableContext
|
||||
) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
simple_workflow = hatchet.workflow(
|
||||
name="simple-unit-test-workflow", input_validator=UnitTestInput
|
||||
)
|
||||
|
||||
|
||||
@simple_workflow.task()
|
||||
def sync_simple_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@simple_workflow.task()
|
||||
async def async_simple_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@simple_workflow.durable_task()
|
||||
def durable_sync_simple_workflow(
|
||||
input: UnitTestInput, ctx: DurableContext
|
||||
) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@simple_workflow.durable_task()
|
||||
async def durable_async_simple_workflow(
|
||||
input: UnitTestInput, ctx: DurableContext
|
||||
) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
complex_workflow = hatchet.workflow(
|
||||
name="complex-unit-test-workflow", input_validator=UnitTestInput
|
||||
)
|
||||
|
||||
|
||||
@complex_workflow.task()
|
||||
async def start(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return UnitTestOutput(
|
||||
key=input.key,
|
||||
number=input.number,
|
||||
additional_metadata=ctx.additional_metadata,
|
||||
retry_count=ctx.retry_count,
|
||||
mock_db_url=cast(Lifespan, ctx.lifespan).mock_db_url,
|
||||
)
|
||||
|
||||
|
||||
@complex_workflow.task(
|
||||
parents=[start],
|
||||
)
|
||||
def sync_complex_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return ctx.task_output(start)
|
||||
|
||||
|
||||
@complex_workflow.task(
|
||||
parents=[start],
|
||||
)
|
||||
async def async_complex_workflow(input: UnitTestInput, ctx: Context) -> UnitTestOutput:
|
||||
return ctx.task_output(start)
|
||||
|
||||
|
||||
@complex_workflow.durable_task(
|
||||
parents=[start],
|
||||
)
|
||||
def durable_sync_complex_workflow(
|
||||
input: UnitTestInput, ctx: DurableContext
|
||||
) -> UnitTestOutput:
|
||||
return ctx.task_output(start)
|
||||
|
||||
|
||||
@complex_workflow.durable_task(
|
||||
parents=[start],
|
||||
)
|
||||
async def durable_async_complex_workflow(
|
||||
input: UnitTestInput, ctx: DurableContext
|
||||
) -> UnitTestOutput:
|
||||
return ctx.task_output(start)
|
||||
@@ -23,6 +23,7 @@ from examples.lifespans.simple import lifespan, lifespan_task
|
||||
from examples.logger.workflow import logging_workflow
|
||||
from examples.non_retryable.worker import non_retryable_workflow
|
||||
from examples.on_failure.worker import on_failure_wf, on_failure_wf_with_details
|
||||
from examples.return_exceptions.worker import return_exceptions_task
|
||||
from examples.simple.worker import simple, simple_durable
|
||||
from examples.timeout.worker import refresh_timeout_wf, timeout_wf
|
||||
from hatchet_sdk import Hatchet
|
||||
@@ -65,6 +66,7 @@ def main() -> None:
|
||||
bulk_replay_test_1,
|
||||
bulk_replay_test_2,
|
||||
bulk_replay_test_3,
|
||||
return_exceptions_task,
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user