diff --git a/sdks/python/examples/concurrency_multiple_keys/test_multiple_concurrency_keys.py b/sdks/python/examples/concurrency_multiple_keys/test_multiple_concurrency_keys.py index 620f4a3c5..32fd32830 100644 --- a/sdks/python/examples/concurrency_multiple_keys/test_multiple_concurrency_keys.py +++ b/sdks/python/examples/concurrency_multiple_keys/test_multiple_concurrency_keys.py @@ -52,7 +52,7 @@ class RunMetadata(BaseModel): @pytest.mark.asyncio() -async def test_priority(hatchet: Hatchet) -> None: +async def test_multi_concurrency_key(hatchet: Hatchet) -> None: test_run_id = str(uuid4()) run_refs = await concurrency_multiple_keys_workflow.aio_run_many_no_wait( @@ -60,7 +60,7 @@ async def test_priority(hatchet: Hatchet) -> None: concurrency_multiple_keys_workflow.create_bulk_run_item( WorkflowInput( name=(name := choice(characters)), - digit=(digit := choice([str(i) for i in range(6)])), + digit=(digit := choice([str(i) for i in range(3)])), ), options=TriggerWorkflowOptions( additional_metadata={ @@ -108,35 +108,34 @@ async def test_priority(hatchet: Hatchet) -> None: overlapping_groups: dict[int, list[RunMetadata]] = {} for run in sorted_runs: - has_been_assigned = False + has_group_membership = False + if not overlapping_groups: overlapping_groups[1] = [run] - has_been_assigned = True + continue + + if has_group_membership: continue for id, group in overlapping_groups.items(): - if has_been_assigned: + if all(are_overlapping(run, task) for task in group): + overlapping_groups[id].append(run) + has_group_membership = True break - for task in group: - if ( - run.started_at < task.finished_at - and run.finished_at > task.started_at - ) or ( - run.finished_at > task.started_at - and run.started_at < task.finished_at - ): - overlapping_groups[id].append(run) - has_been_assigned = True - break - - if not has_been_assigned: + if not has_group_membership: overlapping_groups[len(overlapping_groups) + 1] = [run] for id, group in overlapping_groups.items(): assert is_valid_group(group), f"Group {id} is not valid" +def are_overlapping(x: RunMetadata, y: RunMetadata) -> bool: + return (x.started_at < y.finished_at and x.finished_at > y.started_at) or ( + x.finished_at > y.started_at and x.started_at < y.finished_at + ) + + def is_valid_group(group: list[RunMetadata]) -> bool: digits = Counter[str]() names = Counter[str]()