mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests: add test for iterate execution order
This commit is contained in:
parent
83b3828b55
commit
b4b1dbdd34
@ -195,3 +195,30 @@ def test_graph_executes_depth_first():
|
|||||||
|
|
||||||
assert get_completed_count(g, "prompt_iterated") == 2
|
assert get_completed_count(g, "prompt_iterated") == 2
|
||||||
assert get_completed_count(g, "prompt_successor") == 2
|
assert get_completed_count(g, "prompt_successor") == 2
|
||||||
|
|
||||||
|
|
||||||
|
# Because this tests deterministic ordering, we run it multiple times
|
||||||
|
@pytest.mark.parametrize("execution_number", range(5))
|
||||||
|
def test_graph_iterate_execution_order(execution_number: int):
|
||||||
|
"""Tests that iterate nodes execution is ordered by the order of the collection"""
|
||||||
|
|
||||||
|
graph = Graph()
|
||||||
|
|
||||||
|
test_prompts = ["Banana sushi", "Cat sushi", "Strawberry Sushi", "Dinosaur Sushi"]
|
||||||
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||||
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||||
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
|
|
||||||
|
g = GraphExecutionState(graph=graph)
|
||||||
|
_ = invoke_next(g)
|
||||||
|
_ = invoke_next(g)
|
||||||
|
assert _[1].item == "Banana sushi"
|
||||||
|
_ = invoke_next(g)
|
||||||
|
assert _[1].item == "Cat sushi"
|
||||||
|
_ = invoke_next(g)
|
||||||
|
assert _[1].item == "Strawberry Sushi"
|
||||||
|
_ = invoke_next(g)
|
||||||
|
assert _[1].item == "Dinosaur Sushi"
|
||||||
|
_ = invoke_next(g)
|
||||||
|
Loading…
Reference in New Issue
Block a user