# Reconstructed from git patch d54168b8 ("feat(nodes): add tests for depth-first
# execution", psychedelicious <4822129+psychedelicious@users.noreply.github.com>,
# 2023-06-09). The patch appended these tests to
# tests/nodes/test_graph_execution_state.py.


def test_graph_state_prepares_eagerly(mock_services):
    """Tests that all preparable nodes are prepared on the first `next()` call.

    Builds two independent subgraphs:
      * collection -> iterate -> prompt: nodes downstream of the iterator
        cannot be prepared until the iterator has produced its items, and
      * a static three-node prompt chain that is fully preparable immediately.
    """
    graph = Graph()

    # Iterated branch: only the collection source is preparable up front.
    test_prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))

    # Separate, fully-preparable chain of nodes.
    graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
    graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
    graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
    graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
    graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))

    g = GraphExecutionState(graph=graph)
    g.next()

    # The static chain and the collection source are prepared eagerly...
    assert "prompt_collection" in g.source_prepared_mapping
    assert "prompt_chain_1" in g.source_prepared_mapping
    assert "prompt_chain_2" in g.source_prepared_mapping
    assert "prompt_chain_3" in g.source_prepared_mapping
    # ...but the iterator and its consumers must wait for iteration results.
    assert "iterate" not in g.source_prepared_mapping
    assert "prompt_iterated" not in g.source_prepared_mapping


def test_graph_executes_depth_first(mock_services):
    """Tests that the graph executes depth-first, executing a branch as far as
    possible before moving to the next branch."""
    graph = Graph()

    test_prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_node(PromptTestInvocation(id="prompt_successor"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
    graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))

    def get_completed_count(state, source_id):
        """Count how many prepared copies of source node `source_id` have executed."""
        prepared_ids = state.source_prepared_mapping[source_id]
        return sum(1 for executed_id in state.executed if executed_id in prepared_ids)

    g = GraphExecutionState(graph=graph)

    # Because prepared-node ordering is not guaranteed, we cannot compare
    # results directly; instead we count executions per source node.
    # First four invocations: collection, iterator, and the start of the
    # iterated branches. Exactly one iteration branch should have reached
    # "prompt_iterated", and none should have reached its successor yet.
    for _ in range(4):
        invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 1
    assert get_completed_count(g, "prompt_successor") == 0

    # Depth-first: the first branch's successor runs before the second
    # branch's prompt node...
    invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 1
    assert get_completed_count(g, "prompt_successor") == 1

    # ...then the second branch executes its prompt...
    invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 2
    assert get_completed_count(g, "prompt_successor") == 1

    # ...and finally the second branch's successor completes.
    invoke_next(g, mock_services)

    assert get_completed_count(g, "prompt_iterated") == 2
    assert get_completed_count(g, "prompt_successor") == 2