diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index e6f0cfaaa3..8086532943 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -20,6 +20,7 @@ from invokeai.app.services.batch_manager import ( from invokeai.app.services.graph import ( Graph, GraphExecutionState, + GraphInvocation, LibraryGraph, ) import pytest @@ -69,6 +70,47 @@ def simple_batch(): ) +@pytest.fixture +def graph_with_subgraph(): + g = Graph() + g.add_node(GraphInvocation(id="1", graph=simple_graph())) + return g + + +@pytest.fixture +def batch_with_subgraph(): + return Batch( + data=[ + [ + BatchData( + node_path="1.1", + field_name="prompt", + items=[ + "Tomato sushi", + "Strawberry sushi", + "Broccoli sushi", + "Asparagus sushi", + "Tea sushi", + ], + ) + ], + [ + BatchData( + node_path="1.2", + field_name="prompt", + items=[ + "Ume sushi", + "Ichigo sushi", + "Momo sushi", + "Mikan sushi", + "Cha sushi", + ], + ) + ], + ] + ) + + # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. @@ -173,6 +215,26 @@ def test_handles_errors(mock_invoker: Invoker): assert all((i in g.errors for i in g.source_prepared_mapping["1"])) +def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph): + batch_process_res = mock_invoker.services.batch_manager.create_batch_process( + batch=batch_with_subgraph, + graph=graph_with_subgraph, + ) + assert batch_process_res.batch_id + assert len(batch_process_res.session_ids) == 25 + # TODO: without the mock events service emitting the `graph_execution_state` events, + # the batch sessions do not know when they have finished, so this logic will fail + + # mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id) + + # def has_executed_all_batches(batch_id: str): + # batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id) + # print(batch_sessions) + # return all((s.state == "completed" for s in batch_sessions)) + + # wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1) + + def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch): batch_process_res = mock_invoker.services.batch_manager.create_batch_process( batch=simple_batch,