import asyncio
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from random import choice
from subprocess import Popen
from typing import Any, Literal
from uuid import uuid4

import pytest
import pytest_asyncio
from pydantic import BaseModel

from examples.priority.worker import DEFAULT_PRIORITY, SLEEP_TIME, priority_workflow
from hatchet_sdk import Hatchet, ScheduleTriggerWorkflowOptions, TriggerWorkflowOptions
from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus

Priority = Literal["low", "medium", "high", "default"]


class RunPriorityStartedAt(BaseModel):
    priority: Priority
    started_at: datetime
    finished_at: datetime


def priority_to_int(priority: Priority) -> int:
    match priority:
        case "high":
            return 3
        case "medium":
            return 2
        case "low":
            return 1
        case "default":
            return DEFAULT_PRIORITY
        case _:
            raise ValueError(f"Invalid priority: {priority}")


@pytest_asyncio.fixture(loop_scope="session", scope="function")
async def dummy_runs() -> None:
    priority: Priority = "high"

    await priority_workflow.aio_run_many_no_wait(
        [
            priority_workflow.create_bulk_run_item(
                options=TriggerWorkflowOptions(
                    priority=(priority_to_int(priority)),
                    additional_metadata={
                        "priority": priority,
                        "key": ix,
                        "type": "dummy",
                    },
                )
            )
            for ix in range(40)
        ]
    )

    await asyncio.sleep(3)

    return


@pytest.mark.parametrize(
    "on_demand_worker",
    [
        (
            ["poetry", "run", "python", "examples/priority/worker.py", "--slots", "1"],
            8003,
        )
    ],
    indirect=True,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_priority(
    hatchet: Hatchet, dummy_runs: None, on_demand_worker: Popen[Any]
) -> None:
    test_run_id = str(uuid4())
    choices: list[Priority] = ["low", "medium", "high", "default"]

    N = 30

    run_refs = await priority_workflow.aio_run_many_no_wait(
        [
            priority_workflow.create_bulk_run_item(
                options=TriggerWorkflowOptions(
                    priority=(priority_to_int(priority := choice(choices))),
                    additional_metadata={
                        "priority": priority,
                        "key": ix,
                        "test_run_id": test_run_id,
                    },
                )
            )
            for ix in range(N)
        ]
    )

    await asyncio.gather(*[r.aio_result() for r in run_refs])

    workflows = (
        await hatchet.workflows.aio_list(workflow_name=priority_workflow.name)
    ).rows

    assert workflows

    workflow = next((w for w in workflows if w.name == priority_workflow.name), None)

    assert workflow
    assert workflow.name == priority_workflow.name

    runs = await hatchet.runs.aio_list(
        workflow_ids=[workflow.metadata.id],
        additional_metadata={
            "test_run_id": test_run_id,
        },
        limit=1_000,
    )

    runs_ids_started_ats: list[RunPriorityStartedAt] = sorted(
        [
            RunPriorityStartedAt(
                priority=(r.additional_metadata or {}).get("priority") or "low",
                started_at=r.started_at or datetime.min,
                finished_at=r.finished_at or datetime.min,
            )
            for r in runs.rows
        ],
        key=lambda x: x.started_at,
    )

    assert len(runs_ids_started_ats) == len(run_refs)
    assert len(runs_ids_started_ats) == N

    for i in range(len(runs_ids_started_ats) - 1):
        curr = runs_ids_started_ats[i]
        nxt = runs_ids_started_ats[i + 1]

        """Run start times should be in order of priority"""
        assert priority_to_int(curr.priority) >= priority_to_int(nxt.priority)

        """Runs should proceed one at a time"""
        assert curr.finished_at <= nxt.finished_at
        assert nxt.finished_at >= nxt.started_at

        """Runs should finish after starting (this is mostly a test for engine datetime handling bugs)"""
        assert curr.finished_at >= curr.started_at


@pytest.mark.parametrize(
    "on_demand_worker",
    [
        (
            ["poetry", "run", "python", "examples/priority/worker.py", "--slots", "1"],
            8003,
        )
    ],
    indirect=True,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_priority_via_scheduling(
    hatchet: Hatchet, dummy_runs: None, on_demand_worker: Popen[Any]
) -> None:
    test_run_id = str(uuid4())
    sleep_time = 3
    n = 30
    choices: list[Priority] = ["low", "medium", "high", "default"]
    run_at = datetime.now(tz=timezone.utc) + timedelta(seconds=sleep_time)

    versions = await asyncio.gather(
        *[
            priority_workflow.aio_schedule(
                run_at=run_at,
                options=ScheduleTriggerWorkflowOptions(
                    priority=(priority_to_int(priority := choice(choices))),
                    additional_metadata={
                        "priority": priority,
                        "key": ix,
                        "test_run_id": test_run_id,
                    },
                ),
            )
            for ix in range(n)
        ]
    )

    await asyncio.sleep(sleep_time * 2)

    workflow_id = versions[0].workflow_id

    attempts = 0

    while True:
        if attempts >= SLEEP_TIME * n * 2:
            raise TimeoutError("Timed out waiting for runs to finish")

        attempts += 1
        await asyncio.sleep(1)

        runs = await hatchet.runs.aio_list(
            workflow_ids=[workflow_id],
            additional_metadata={
                "test_run_id": test_run_id,
            },
            limit=1_000,
        )

        if not runs.rows:
            continue

        if any(
            r.status in [V1TaskStatus.FAILED, V1TaskStatus.CANCELLED]
            for r in runs.rows
        ):
            raise ValueError("One or more runs failed or were cancelled")

        if all(r.status == V1TaskStatus.COMPLETED for r in runs.rows):
            break

    runs_ids_started_ats: list[RunPriorityStartedAt] = sorted(
        [
            RunPriorityStartedAt(
                priority=(r.additional_metadata or {}).get("priority") or "low",
                started_at=r.started_at or datetime.min,
                finished_at=r.finished_at or datetime.min,
            )
            for r in runs.rows
        ],
        key=lambda x: x.started_at,
    )

    assert len(runs_ids_started_ats) == len(versions)

    for i in range(len(runs_ids_started_ats) - 1):
        curr = runs_ids_started_ats[i]
        nxt = runs_ids_started_ats[i + 1]

        """Run start times should be in order of priority"""
        assert priority_to_int(curr.priority) >= priority_to_int(nxt.priority)

        """Runs should proceed one at a time"""
        assert curr.finished_at <= nxt.finished_at
        assert nxt.finished_at >= nxt.started_at

        """Runs should finish after starting (this is mostly a test for engine datetime handling bugs)"""
        assert curr.finished_at >= curr.started_at


@pytest_asyncio.fixture(loop_scope="session", scope="function")
async def crons(
    hatchet: Hatchet, dummy_runs: None
) -> AsyncGenerator[tuple[str, str, int], None]:
    test_run_id = str(uuid4())
    choices: list[Priority] = ["low", "medium", "high"]

    n = 30

    crons = await asyncio.gather(
        *[
            hatchet.cron.aio_create(
                workflow_name=priority_workflow.name,
                cron_name=f"{test_run_id}-cron-{i}",
                expression="* * * * *",
                input={},
                additional_metadata={
                    "trigger": "cron",
                    "test_run_id": test_run_id,
                    "priority": (priority := choice(choices)),
                    "key": str(i),
                },
                priority=(priority_to_int(priority)),
            )
            for i in range(n)
        ]
    )

    yield crons[0].workflow_id, test_run_id, n

    await asyncio.gather(*[hatchet.cron.aio_delete(cron.metadata.id) for cron in crons])


def time_until_next_minute() -> float:
    now = datetime.now(tz=timezone.utc)
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)

    return (next_minute - now).total_seconds()


@pytest.mark.skip(
    reason="Test is flaky because the first jobs that are picked up don't necessarily go in priority order"
)
@pytest.mark.parametrize(
    "on_demand_worker",
    [
        (
            ["poetry", "run", "python", "examples/priority/worker.py", "--slots", "1"],
            8003,
        )
    ],
    indirect=True,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_priority_via_cron(
    hatchet: Hatchet, crons: tuple[str, str, int], on_demand_worker: Popen[Any]
) -> None:
    workflow_id, test_run_id, n = crons

    await asyncio.sleep(time_until_next_minute() + 10)

    attempts = 0

    while True:
        if attempts >= SLEEP_TIME * n * 2:
            raise TimeoutError("Timed out waiting for runs to finish")

        attempts += 1
        await asyncio.sleep(1)

        runs = await hatchet.runs.aio_list(
            workflow_ids=[workflow_id],
            additional_metadata={
                "test_run_id": test_run_id,
            },
            limit=1_000,
        )

        if not runs.rows:
            continue

        if any(
            r.status in [V1TaskStatus.FAILED, V1TaskStatus.CANCELLED]
            for r in runs.rows
        ):
            raise ValueError("One or more runs failed or were cancelled")

        if all(r.status == V1TaskStatus.COMPLETED for r in runs.rows):
            break

    runs_ids_started_ats: list[RunPriorityStartedAt] = sorted(
        [
            RunPriorityStartedAt(
                priority=(r.additional_metadata or {}).get("priority") or "low",
                started_at=r.started_at or datetime.min,
                finished_at=r.finished_at or datetime.min,
            )
            for r in runs.rows
        ],
        key=lambda x: x.started_at,
    )

    assert len(runs_ids_started_ats) == n

    for i in range(len(runs_ids_started_ats) - 1):
        curr = runs_ids_started_ats[i]
        nxt = runs_ids_started_ats[i + 1]

        """Run start times should be in order of priority"""
        assert priority_to_int(curr.priority) >= priority_to_int(nxt.priority)

        """Runs should proceed one at a time"""
        assert curr.finished_at <= nxt.finished_at
        assert nxt.finished_at >= nxt.started_at

        """Runs should finish after starting (this is mostly a test for engine datetime handling bugs)"""
        assert curr.finished_at >= curr.started_at