mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
271 lines
8.9 KiB
Python
271 lines
8.9 KiB
Python
from .test_nodes import (
|
|
TestEventService,
|
|
ErrorInvocation,
|
|
TextToImageTestInvocation,
|
|
PromptTestInvocation,
|
|
create_edge,
|
|
wait_until,
|
|
)
|
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|
from invokeai.app.services.invoker import Invoker
|
|
from invokeai.app.services.invocation_services import InvocationServices
|
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
|
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
|
from invokeai.app.services.batch_manager import (
|
|
Batch,
|
|
BatchManager,
|
|
)
|
|
from invokeai.app.services.graph import (
|
|
Graph,
|
|
GraphExecutionState,
|
|
LibraryGraph,
|
|
)
|
|
import pytest
|
|
import sqlite3
|
|
|
|
|
|
@pytest.fixture
def simple_graph():
    """Build a minimal two-node graph used by the invoker tests.

    Node "1" is a ``PromptTestInvocation`` and node "2" is a
    ``TextToImageTestInvocation``; a single edge is created between the
    ``prompt`` field of node "1" and the ``prompt`` field of node "2".
    """
    graph = Graph()
    nodes = (
        PromptTestInvocation(id="1", prompt="Banana sushi"),
        TextToImageTestInvocation(id="2"),
    )
    for node in nodes:
        graph.add_node(node)

    prompt_edge = create_edge("1", "prompt", "2", "prompt")
    graph.add_edge(prompt_edge)
    return graph
|
|
|
|
|
|
@pytest.fixture
def simple_batch():
    """Build a ``Batch`` made of two groups, each holding one ``BatchData``.

    The first group targets the ``prompt`` field of node "1", the second the
    ``prompt`` field of node "2". Both carry five string items.
    """
    node_1_prompts = BatchData(
        node_id="1",
        field_name="prompt",
        items=[
            "Tomato sushi",
            "Strawberry sushi",
            "Broccoli sushi",
            "Asparagus sushi",
            "Tea sushi",
        ],
    )
    node_2_prompts = BatchData(
        node_id="2",
        field_name="prompt",
        items=[
            "Ume sushi",
            "Ichigo sushi",
            "Momo sushi",
            "Mikan sushi",
            "Cha sushi",
        ],
    )
    # Each inner list is one group of batch data.
    return Batch(data=[[node_1_prompts], [node_2_prompts]])
|
|
|
|
|
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
|
# the test invocations.
|
|
@pytest.fixture
def mock_services() -> InvocationServices:
    """Assemble an ``InvocationServices`` bundle backed by one SQLite connection.

    Services the test invocations never touch are passed as ``None``. The
    storage-backed services all share a single connection opened on
    ``sqlite_memory`` with ``check_same_thread=False``.
    """
    # NOTE: none of these are actually called by the test invocations
    connection = sqlite3.connect(sqlite_memory, check_same_thread=False)
    execution_store = SqliteItemStorage[GraphExecutionState](conn=connection, table_name="graph_executions")
    batch_store = SqliteBatchProcessStorage(conn=connection)

    # Built in the same order as the keyword arguments below were originally evaluated.
    events = TestEventService()
    batch_manager = BatchManager(batch_store)
    queue = MemoryInvocationQueue()
    graph_library = SqliteItemStorage[LibraryGraph](conn=connection, table_name="graphs")
    processor = DefaultInvocationProcessor()
    statistics = InvocationStatsService(execution_store)

    return InvocationServices(
        model_manager=None,  # type: ignore
        events=events,
        logger=None,  # type: ignore
        images=None,  # type: ignore
        latents=None,  # type: ignore
        batch_manager=batch_manager,
        boards=None,  # type: ignore
        board_images=None,  # type: ignore
        queue=queue,
        graph_library=graph_library,
        graph_execution_manager=execution_store,
        processor=processor,
        performance_statistics=statistics,
        configuration=None,  # type: ignore
    )
|
|
|
|
|
|
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
    """Return an ``Invoker`` wired to the ``mock_services`` fixture."""
    invoker = Invoker(services=mock_services)
    return invoker
|
|
|
|
|
|
def test_can_create_graph_state(mock_invoker: Invoker):
    """An execution state can be created without supplying a graph."""
    state = mock_invoker.create_execution_state()
    # Stop the invoker before asserting so it is shut down even if an assert fails.
    mock_invoker.stop()

    assert state is not None
    assert isinstance(state, GraphExecutionState)
|
|
|
|
|
|
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
    """An execution state created from a graph keeps that graph."""
    state = mock_invoker.create_execution_state(graph=simple_graph)
    # Stop the invoker before asserting so it is shut down even if an assert fails.
    mock_invoker.stop()

    assert state is not None
    assert isinstance(state, GraphExecutionState)
    assert state.graph == simple_graph
|
|
|
|
|
|
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
|
def test_can_invoke(mock_invoker: Invoker, simple_graph):
    """Invoking a graph state eventually records at least one executed node.

    The invoker is stopped in a ``finally`` block so that a failed assertion,
    or an error raised while waiting, does not leave it running for the
    remainder of the test session.
    """
    try:
        g = mock_invoker.create_execution_state(graph=simple_graph)
        invocation_id = mock_invoker.invoke(g)
        assert invocation_id is not None

        session_id = g.id

        def has_executed_any() -> bool:
            # Re-read the state from storage on every poll rather than
            # inspecting the local object.
            stored = mock_invoker.services.graph_execution_manager.get(session_id)
            return len(stored.executed) > 0

        # NOTE(review): wait_until comes from .test_nodes; it is assumed to poll
        # the condition and raise on timeout - confirm against its definition.
        wait_until(has_executed_any, timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    final_state = mock_invoker.services.graph_execution_manager.get(session_id)
    assert len(final_state.executed) > 0
|
|
|
|
|
|
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
|
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
    """Invoking with ``invoke_all=True`` runs the graph state to completion.

    The invoker is stopped in a ``finally`` block so that a failed assertion,
    or an error raised while waiting, does not leave it running for the
    remainder of the test session.
    """
    try:
        g = mock_invoker.create_execution_state(graph=simple_graph)
        invocation_id = mock_invoker.invoke(g, invoke_all=True)
        assert invocation_id is not None

        session_id = g.id

        def has_executed_all() -> bool:
            # Re-read the state from storage on every poll rather than
            # inspecting the local object.
            stored = mock_invoker.services.graph_execution_manager.get(session_id)
            return stored.is_complete()

        # NOTE(review): wait_until comes from .test_nodes; it is assumed to poll
        # the condition and raise on timeout - confirm against its definition.
        wait_until(has_executed_all, timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    final_state = mock_invoker.services.graph_execution_manager.get(session_id)
    assert final_state.is_complete()
|
|
|
|
|
|
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
|
def test_handles_errors(mock_invoker: Invoker):
    """A graph containing an ``ErrorInvocation`` completes and reports the error.

    After the run, the stored state must report an error, be complete, and
    list every prepared node derived from source node "1" in ``errors``.

    The invoker is stopped in a ``finally`` block so that a failure while
    invoking or waiting does not leave it running for the remainder of the
    test session.
    """
    try:
        g = mock_invoker.create_execution_state()
        g.graph.add_node(ErrorInvocation(id="1"))

        mock_invoker.invoke(g, invoke_all=True)

        session_id = g.id

        def has_executed_all() -> bool:
            # Re-read the state from storage on every poll rather than
            # inspecting the local object.
            stored = mock_invoker.services.graph_execution_manager.get(session_id)
            return stored.is_complete()

        # NOTE(review): wait_until comes from .test_nodes; it is assumed to poll
        # the condition and raise on timeout - confirm against its definition.
        wait_until(has_executed_all, timeout=5, interval=1)
    finally:
        mock_invoker.stop()

    final_state = mock_invoker.services.graph_execution_manager.get(session_id)
    assert final_state.has_error()
    assert final_state.is_complete()

    assert all(i in final_state.errors for i in final_state.source_prepared_mapping["1"])
|
|
|
|
|
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
    """Creating a batch process yields a batch id and 25 session ids.

    ``simple_batch`` holds two groups of five items each; 25 is presumably the
    5 x 5 combination of those groups (confirm against the batch manager).

    The invoker is stopped in a ``finally`` block, matching the other tests in
    this module that use ``mock_invoker``, so it does not keep running after
    the test whether or not the assertions pass.
    """
    try:
        batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
            batch=simple_batch,
            graph=simple_graph,
        )
        assert batch_process_res.batch_id
        assert len(batch_process_res.session_ids) == 25
        # TODO: without the mock events service emitting the `graph_execution_state` events,
        # the batch sessions do not know when they have finished, so this logic will fail

        # mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)

        # def has_executed_all_batches(batch_id: str):
        #     batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
        #     print(batch_sessions)
        #     return all((s.state == "completed" for s in batch_sessions))

        # wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
    finally:
        mock_invoker.stop()
|
|
|
|
|
|
def test_can_create_bad_batches():
    """``Batch`` construction must reject each kind of malformed batch data.

    Every malformed input is checked in its own ``pytest.raises`` block, so a
    case that is unexpectedly accepted fails the test at that exact spot
    instead of being detected indirectly afterwards.
    """
    # NOTE(review): the concrete exception type raised by Batch validation is
    # not visible from this module, so the broad ``Exception`` is kept; narrow
    # it once the type is confirmed.

    # This batch has a duplicate node_id|fieldname combo
    with pytest.raises(Exception):
        Batch(
            data=[
                [
                    BatchData(
                        node_id="1",
                        field_name="prompt",
                        items=[
                            "Tomato sushi",
                        ],
                    )
                ],
                [
                    BatchData(
                        node_id="1",
                        field_name="prompt",
                        items=[
                            "Ume sushi",
                        ],
                    )
                ],
            ]
        )

    # This batch has different item list lengths in the same group
    # NOTE(review): it also repeats the node_id|fieldname combo "1"/"prompt",
    # so it may be rejected for that reason rather than the length mismatch.
    with pytest.raises(Exception):
        Batch(
            data=[
                [
                    BatchData(
                        node_id="1",
                        field_name="prompt",
                        items=[
                            "Tomato sushi",
                        ],
                    ),
                    BatchData(
                        node_id="1",
                        field_name="prompt",
                        items=[
                            "Tomato sushi",
                            "Courgette sushi",
                        ],
                    ),
                ],
                [
                    BatchData(
                        node_id="1",
                        field_name="prompt",
                        items=[
                            "Ume sushi",
                        ],
                    )
                ],
            ]
        )

    # This batch has a type mismatch in single items list
    with pytest.raises(Exception):
        Batch(
            data=[
                [
                    BatchData(
                        node_id="1",
                        field_name="prompt",
                        items=["Tomato sushi", 5],
                    ),
                ],
            ]
        )
|