mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add tests for depth-first execution
This commit is contained in:
parent
c91b071c47
commit
d54168b8fb
@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services):
|
|||||||
assert isinstance(n6[0], CollectInvocation)
|
assert isinstance(n6[0], CollectInvocation)
|
||||||
|
|
||||||
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_state_prepares_eagerly(mock_services):
|
||||||
|
"""Tests that all prepareable nodes are prepared"""
|
||||||
|
graph = Graph()
|
||||||
|
|
||||||
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||||
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||||
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
|
|
||||||
|
# separated, fully-preparable chain of nodes
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
|
||||||
|
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
|
||||||
|
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
|
||||||
|
|
||||||
|
g = GraphExecutionState(graph=graph)
|
||||||
|
g.next()
|
||||||
|
|
||||||
|
assert "prompt_collection" in g.source_prepared_mapping
|
||||||
|
assert "prompt_chain_1" in g.source_prepared_mapping
|
||||||
|
assert "prompt_chain_2" in g.source_prepared_mapping
|
||||||
|
assert "prompt_chain_3" in g.source_prepared_mapping
|
||||||
|
assert "iterate" not in g.source_prepared_mapping
|
||||||
|
assert "prompt_iterated" not in g.source_prepared_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_executes_depth_first(mock_services):
|
||||||
|
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
|
||||||
|
graph = Graph()
|
||||||
|
|
||||||
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||||
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_successor"))
|
||||||
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||||
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
|
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
||||||
|
|
||||||
|
g = GraphExecutionState(graph=graph)
|
||||||
|
n1 = invoke_next(g, mock_services)
|
||||||
|
n2 = invoke_next(g, mock_services)
|
||||||
|
n3 = invoke_next(g, mock_services)
|
||||||
|
n4 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||||
|
# Instead, we must count the number of results.
|
||||||
|
def get_completed_count(g, id):
|
||||||
|
ids = [i for i in g.source_prepared_mapping[id]]
|
||||||
|
completed_ids = [i for i in g.executed if i in ids]
|
||||||
|
return len(completed_ids)
|
||||||
|
|
||||||
|
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 1
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 0
|
||||||
|
|
||||||
|
n5 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 1
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 1
|
||||||
|
|
||||||
|
n6 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 2
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 1
|
||||||
|
|
||||||
|
n7 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 2
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user