mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(tests): add test for batch with subgraph [WIP]
The tests still don't work due to the test events service not emitting events the batch mgr can listen for.
This commit is contained in:
parent
e8a4a654ac
commit
2185c85287
@ -20,6 +20,7 @@ from invokeai.app.services.batch_manager import (
|
||||
from invokeai.app.services.graph import (
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
GraphInvocation,
|
||||
LibraryGraph,
|
||||
)
|
||||
import pytest
|
||||
@ -69,6 +70,47 @@ def simple_batch():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_with_subgraph():
|
||||
g = Graph()
|
||||
g.add_node(GraphInvocation(id="1", graph=simple_graph()))
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_with_subgraph():
|
||||
return Batch(
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_path="1.1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
"Strawberry sushi",
|
||||
"Broccoli sushi",
|
||||
"Asparagus sushi",
|
||||
"Tea sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_path="1.2",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
"Ichigo sushi",
|
||||
"Momo sushi",
|
||||
"Mikan sushi",
|
||||
"Cha sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@ -173,6 +215,26 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||
|
||||
|
||||
def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batch=batch_with_subgraph,
|
||||
graph=graph_with_subgraph,
|
||||
)
|
||||
assert batch_process_res.batch_id
|
||||
assert len(batch_process_res.session_ids) == 25
|
||||
# TODO: without the mock events service emitting the `graph_execution_state` events,
|
||||
# the batch sessions do not know when they have finished, so this logic will fail
|
||||
|
||||
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||
|
||||
# def has_executed_all_batches(batch_id: str):
|
||||
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
|
||||
# print(batch_sessions)
|
||||
# return all((s.state == "completed" for s in batch_sessions))
|
||||
|
||||
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||
|
||||
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batch=simple_batch,
|
||||
|
Loading…
Reference in New Issue
Block a user