feat(tests): add test for batch with subgraph [WIP]

The tests still don't work due to the test events service not emitting events the batch mgr can listen for.
This commit is contained in:
psychedelicious 2023-09-05 17:31:55 +10:00
parent e8a4a654ac
commit 2185c85287

View File

@ -20,6 +20,7 @@ from invokeai.app.services.batch_manager import (
from invokeai.app.services.graph import (
Graph,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
)
import pytest
@ -69,6 +70,47 @@ def simple_batch():
)
@pytest.fixture
def graph_with_subgraph():
g = Graph()
g.add_node(GraphInvocation(id="1", graph=simple_graph()))
return g
@pytest.fixture
def batch_with_subgraph():
return Batch(
data=[
[
BatchData(
node_path="1.1",
field_name="prompt",
items=[
"Tomato sushi",
"Strawberry sushi",
"Broccoli sushi",
"Asparagus sushi",
"Tea sushi",
],
)
],
[
BatchData(
node_path="1.2",
field_name="prompt",
items=[
"Ume sushi",
"Ichigo sushi",
"Momo sushi",
"Mikan sushi",
"Cha sushi",
],
)
],
]
)
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@ -173,6 +215,26 @@ def test_handles_errors(mock_invoker: Invoker):
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batch=batch_with_subgraph,
graph=graph_with_subgraph,
)
assert batch_process_res.batch_id
assert len(batch_process_res.session_ids) == 25
# TODO: without the mock events service emitting the `graph_execution_state` events,
# the batch sessions do not know when they have finished, so this logic will fail
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
# def has_executed_all_batches(batch_id: str):
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
# print(batch_sessions)
# return all((s.state == "completed" for s in batch_sessions))
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batch=simple_batch,