InvokeAI/tests/nodes/test_invoker.py

185 lines
6.2 KiB
Python
Raw Normal View History

2023-06-29 06:01:17 +00:00
from .test_nodes import (
TestEventService,
ErrorInvocation,
TextToImageTestInvocation,
PromptTestInvocation,
create_edge,
wait_until,
)
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker
2023-06-29 06:01:17 +00:00
from invokeai.app.services.invocation_services import InvocationServices
2023-08-02 22:31:10 +00:00
from invokeai.app.services.invocation_stats import InvocationStatsService
2023-08-16 19:33:15 +00:00
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
2023-08-14 14:57:18 +00:00
from invokeai.app.services.batch_manager import (
Batch,
BatchManager,
)
2023-06-29 06:01:17 +00:00
from invokeai.app.services.graph import (
Graph,
GraphExecutionState,
LibraryGraph,
)
import pytest
2023-08-16 18:35:49 +00:00
import sqlite3
@pytest.fixture
def simple_graph():
    """Return a two-node test graph.

    Node "1" is a ``PromptTestInvocation`` with the prompt "Banana sushi" and
    node "2" is a ``TextToImageTestInvocation``. The two are linked with
    ``create_edge("1", "prompt", "2", "prompt")`` (presumably node 1's
    ``prompt`` field feeding node 2's ``prompt`` field -- see ``create_edge``
    in ``test_nodes``).
    """
    graph = Graph()

    prompt_node = PromptTestInvocation(id="1", prompt="Banana sushi")
    graph.add_node(prompt_node)

    image_node = TextToImageTestInvocation(id="2")
    graph.add_node(image_node)

    prompt_edge = create_edge("1", "prompt", "2", "prompt")
    graph.add_edge(prompt_edge)
    return graph
2023-06-29 06:01:17 +00:00
2023-08-14 14:57:18 +00:00
@pytest.fixture
def simple_batch():
    """Return a ``Batch`` with two data groups of five prompt values each.

    The first group targets the ``prompt`` field of node "1" and the second
    group targets the ``prompt`` field of node "2".
    """
    node_1_prompts = BatchData(
        node_id="1",
        field_name="prompt",
        items=[
            "Tomato sushi",
            "Strawberry sushi",
            "Broccoli sushi",
            "Asparagus sushi",
            "Tea sushi",
        ],
    )
    node_2_prompts = BatchData(
        node_id="2",
        field_name="prompt",
        items=[
            "Ume sushi",
            "Ichigo sushi",
            "Momo sushi",
            "Mikan sushi",
            "Cha sushi",
        ],
    )
    # Each inner list is its own group of batch data.
    return Batch(data=[[node_1_prompts], [node_2_prompts]])
2023-08-14 14:57:18 +00:00
2023-06-29 06:01:17 +00:00
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
    """Assemble an ``InvocationServices`` bundle for the invoker tests.

    Every sqlite-backed service (graph executions, batch storage, graph
    library) shares a single connection opened on ``sqlite_memory``. Services
    the tests do not need are passed as ``None``.
    """
    # NOTE: none of these are actually called by the test invocations
    # check_same_thread=False lets the connection be used from threads other
    # than the one that created it.
    connection = sqlite3.connect(sqlite_memory, check_same_thread=False)
    execution_store = SqliteItemStorage[GraphExecutionState](conn=connection, table_name="graph_executions")
    batch_store = SqliteBatchProcessStorage(conn=connection)

    # Build the remaining services up front, in the same order the original
    # keyword arguments were evaluated.
    event_service = TestEventService()
    batch_manager = BatchManager(batch_store)
    invocation_queue = MemoryInvocationQueue()
    library_store = SqliteItemStorage[LibraryGraph](conn=connection, table_name="graphs")
    invocation_processor = DefaultInvocationProcessor()
    stats_service = InvocationStatsService(execution_store)

    return InvocationServices(
        model_manager=None,  # type: ignore
        events=event_service,
        logger=None,  # type: ignore
        images=None,  # type: ignore
        latents=None,  # type: ignore
        batch_manager=batch_manager,
        boards=None,  # type: ignore
        board_images=None,  # type: ignore
        queue=invocation_queue,
        graph_library=library_store,
        graph_execution_manager=execution_store,
        processor=invocation_processor,
        performance_statistics=stats_service,
        configuration=None,  # type: ignore
    )
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
    """Return an ``Invoker`` wired to the ``mock_services`` fixture."""
    invoker = Invoker(services=mock_services)
    return invoker
def test_can_create_graph_state(mock_invoker: Invoker):
    """``create_execution_state()`` with no graph returns a ``GraphExecutionState``."""
    state = mock_invoker.create_execution_state()
    # Stop the invoker before asserting, so a failed assertion cannot skip it.
    mock_invoker.stop()

    assert state is not None
    assert isinstance(state, GraphExecutionState)
2023-06-29 06:01:17 +00:00
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
    """``create_execution_state(graph=...)`` returns a state holding that graph."""
    state = mock_invoker.create_execution_state(graph=simple_graph)
    # Stop the invoker before asserting, so a failed assertion cannot skip it.
    mock_invoker.stop()

    assert state is not None
    assert isinstance(state, GraphExecutionState)
    assert state.graph == simple_graph
2023-06-29 06:01:17 +00:00
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
    """Invoking a session for ``simple_graph`` executes at least one node.

    The session is re-read from ``graph_execution_manager`` on every poll and
    again for the final assertion, so the test checks the stored state rather
    than the local object.

    The invoker is stopped in a ``finally`` block so that a failure in
    ``invoke``, the id assertion or ``wait_until`` cannot skip
    ``mock_invoker.stop()``.
    """
    g = mock_invoker.create_execution_state(graph=simple_graph)
    session_id = g.id
    execution_store = mock_invoker.services.graph_execution_manager

    def has_executed_any() -> bool:
        # Fetch the stored copy of the session on each poll.
        return len(execution_store.get(session_id).executed) > 0

    try:
        invocation_id = mock_invoker.invoke(g)
        assert invocation_id is not None
        wait_until(has_executed_any, timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    g = execution_store.get(session_id)
    assert len(g.executed) > 0
2023-06-29 06:01:17 +00:00
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
    """Invoking with ``invoke_all=True`` runs the session to completion.

    The session is re-read from ``graph_execution_manager`` on every poll and
    again for the final assertion, so the test checks the stored state rather
    than the local object.

    The invoker is stopped in a ``finally`` block so that a failure in
    ``invoke``, the id assertion or ``wait_until`` cannot skip
    ``mock_invoker.stop()``.
    """
    g = mock_invoker.create_execution_state(graph=simple_graph)
    session_id = g.id
    execution_store = mock_invoker.services.graph_execution_manager

    def has_executed_all() -> bool:
        # Fetch the stored copy of the session on each poll.
        return execution_store.get(session_id).is_complete()

    try:
        invocation_id = mock_invoker.invoke(g, invoke_all=True)
        assert invocation_id is not None
        wait_until(has_executed_all, timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    g = execution_store.get(session_id)
    assert g.is_complete()
2023-06-29 06:01:17 +00:00
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker):
    """A graph containing an ``ErrorInvocation`` completes and records the error.

    After the run, the stored session must report ``has_error()`` and
    ``is_complete()``, and every id in ``source_prepared_mapping["1"]`` must
    appear in ``errors``.

    The invoker is stopped in a ``finally`` block so that a failure in
    ``invoke`` or ``wait_until`` cannot skip ``mock_invoker.stop()``.
    """
    g = mock_invoker.create_execution_state()
    g.graph.add_node(ErrorInvocation(id="1"))
    session_id = g.id
    execution_store = mock_invoker.services.graph_execution_manager

    def has_executed_all() -> bool:
        # Fetch the stored copy of the session on each poll.
        return execution_store.get(session_id).is_complete()

    try:
        mock_invoker.invoke(g, invoke_all=True)
        wait_until(has_executed_all, timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    g = execution_store.get(session_id)
    assert g.has_error()
    assert g.is_complete()

    # Every id mapped from source node "1" must have a recorded error.
    assert all(i in g.errors for i in g.source_prepared_mapping["1"])
2023-08-14 14:57:18 +00:00
2023-08-14 15:01:31 +00:00
2023-08-16 19:33:15 +00:00
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
    """Creating a batch process yields a batch id and 25 stored sessions.

    ``simple_batch`` has two data groups of five items each; the test expects
    25 session ids (5 x 5) and checks that each one can be fetched from
    ``graph_execution_manager``.

    Unlike the other tests in this module, the original never called
    ``mock_invoker.stop()``; it is now called in a ``finally`` block so the
    invoker is stopped whether or not the assertions pass.
    """
    services = mock_invoker.services
    try:
        batch_process_res = services.batch_manager.create_batch_process(
            batch=simple_batch,
            graph=simple_graph,
        )
        assert batch_process_res.batch_id
        assert len(batch_process_res.session_ids) == 25
        for session in batch_process_res.session_ids:
            assert services.graph_execution_manager.get(session)
    finally:
        mock_invoker.stop()