mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(tests): add test for batch with subgraph [WIP]
The tests still don't work due to the test events service not emitting events the batch mgr can listen for.
This commit is contained in:
parent
e8a4a654ac
commit
2185c85287
@ -20,6 +20,7 @@ from invokeai.app.services.batch_manager import (
|
|||||||
from invokeai.app.services.graph import (
|
from invokeai.app.services.graph import (
|
||||||
Graph,
|
Graph,
|
||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
|
GraphInvocation,
|
||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
@ -69,6 +70,47 @@ def simple_batch():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph_with_subgraph():
|
||||||
|
g = Graph()
|
||||||
|
g.add_node(GraphInvocation(id="1", graph=simple_graph()))
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def batch_with_subgraph():
|
||||||
|
return Batch(
|
||||||
|
data=[
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1.1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Tomato sushi",
|
||||||
|
"Strawberry sushi",
|
||||||
|
"Broccoli sushi",
|
||||||
|
"Asparagus sushi",
|
||||||
|
"Tea sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1.2",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Ume sushi",
|
||||||
|
"Ichigo sushi",
|
||||||
|
"Momo sushi",
|
||||||
|
"Mikan sushi",
|
||||||
|
"Cha sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||||
# the test invocations.
|
# the test invocations.
|
||||||
@ -173,6 +215,26 @@ def test_handles_errors(mock_invoker: Invoker):
|
|||||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||||
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
|
batch=batch_with_subgraph,
|
||||||
|
graph=graph_with_subgraph,
|
||||||
|
)
|
||||||
|
assert batch_process_res.batch_id
|
||||||
|
assert len(batch_process_res.session_ids) == 25
|
||||||
|
# TODO: without the mock events service emitting the `graph_execution_state` events,
|
||||||
|
# the batch sessions do not know when they have finished, so this logic will fail
|
||||||
|
|
||||||
|
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||||
|
|
||||||
|
# def has_executed_all_batches(batch_id: str):
|
||||||
|
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
|
||||||
|
# print(batch_sessions)
|
||||||
|
# return all((s.state == "completed" for s in batch_sessions))
|
||||||
|
|
||||||
|
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||||
|
|
||||||
|
|
||||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
batch=simple_batch,
|
batch=simple_batch,
|
||||||
|
Loading…
Reference in New Issue
Block a user