diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 0bb15b17df..d8cbc38860 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -195,3 +195,30 @@ def test_graph_executes_depth_first(): assert get_completed_count(g, "prompt_iterated") == 2 assert get_completed_count(g, "prompt_successor") == 2 + + +# Because this tests deterministic ordering, we run it multiple times +@pytest.mark.parametrize("execution_number", range(5)) +def test_graph_iterate_execution_order(execution_number: int): + """Tests that iterate nodes execution is ordered by the order of the collection""" + + graph = Graph() + + test_prompts = ["Banana sushi", "Cat sushi", "Strawberry Sushi", "Dinosaur Sushi"] + graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node(IterateInvocation(id="iterate")) + graph.add_node(PromptTestInvocation(id="prompt_iterated")) + graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) + + g = GraphExecutionState(graph=graph) + _ = invoke_next(g) + _ = invoke_next(g) + assert _[1].item == "Banana sushi" + _ = invoke_next(g) + assert _[1].item == "Cat sushi" + _ = invoke_next(g) + assert _[1].item == "Strawberry Sushi" + _ = invoke_next(g) + assert _[1].item == "Dinosaur Sushi" + _ = invoke_next(g)