mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
94c00cee4c
Refactor services folder/module structure. **Motivation** While working on our services I've repeatedly encountered circular imports and a general lack of clarity regarding where to put things. The structure introduced goes a long way towards resolving those issues, setting us up for a clean structure going forward. **Services** Services are now in their own folder with a few files: - `services/{service_name}/__init__.py`: init as needed, mostly empty now - `services/{service_name}/{service_name}_base.py`: the base class for the service - `services/{service_name}/{service_name}_{impl_type}.py`: the default concrete implementation of the service - typically one of `sqlite`, `default`, or `memory` - `services/{service_name}/{service_name}_common.py`: any common items - models, exceptions, utilities, etc Though it's a bit verbose to have the service name both as the folder name and the prefix for files, I found it is _extremely_ confusing to have all of the base classes just be named `base.py`. So, at the cost of some verbosity when importing things, I've included the service name in the filename. There are some minor logic changes. For example, in `InvocationProcessor`, instead of assigning the model manager service to a variable to be used later in the file, the service is used directly via the `Invoker`. **Shared** Things that are used across disparate services are in `services/shared/`: - `default_graphs.py`: previously in `services/` - `graphs.py`: previously in `services/` - `paginatation`: generic pagination models used in a few services - `sqlite`: the `SqliteDatabase` class, other sqlite-specific things
252 lines
9.8 KiB
Python
252 lines
9.8 KiB
Python
import logging
|
|
|
|
import pytest
|
|
|
|
# This import must happen before other invoke imports or test in other files(!!) break
|
|
from .test_nodes import ( # isort: split
|
|
PromptCollectionTestInvocation,
|
|
PromptTestInvocation,
|
|
TestEventService,
|
|
TextToImageTestInvocation,
|
|
)
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
|
from invokeai.app.invocations.collections import RangeInvocation
|
|
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
|
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
|
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
|
from invokeai.app.services.invocation_services import InvocationServices
|
|
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
|
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
|
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
|
from invokeai.app.services.shared.graph import (
|
|
CollectInvocation,
|
|
Graph,
|
|
GraphExecutionState,
|
|
IterateInvocation,
|
|
LibraryGraph,
|
|
)
|
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
|
|
from .test_invoker import create_edge
|
|
|
|
|
|
@pytest.fixture
|
|
def simple_graph():
|
|
g = Graph()
|
|
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
|
g.add_node(TextToImageTestInvocation(id="2"))
|
|
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
|
return g
|
|
|
|
|
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
|
# the test invocations.
|
|
@pytest.fixture
|
|
def mock_services() -> InvocationServices:
|
|
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
|
db = SqliteDatabase(configuration, InvokeAILogger.get_logger())
|
|
# NOTE: none of these are actually called by the test invocations
|
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
|
return InvocationServices(
|
|
board_image_records=None, # type: ignore
|
|
board_images=None, # type: ignore
|
|
board_records=None, # type: ignore
|
|
boards=None, # type: ignore
|
|
configuration=configuration,
|
|
events=TestEventService(),
|
|
graph_execution_manager=graph_execution_manager,
|
|
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
|
|
image_files=None, # type: ignore
|
|
image_records=None, # type: ignore
|
|
images=None, # type: ignore
|
|
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
|
latents=None, # type: ignore
|
|
logger=logging, # type: ignore
|
|
model_manager=None, # type: ignore
|
|
names=None, # type: ignore
|
|
performance_statistics=InvocationStatsService(),
|
|
processor=DefaultInvocationProcessor(),
|
|
queue=MemoryInvocationQueue(),
|
|
session_processor=None, # type: ignore
|
|
session_queue=None, # type: ignore
|
|
urls=None, # type: ignore
|
|
)
|
|
|
|
|
|
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
|
n = g.next()
|
|
if n is None:
|
|
return (None, None)
|
|
|
|
print(f"invoking {n.id}: {type(n)}")
|
|
o = n.invoke(
|
|
InvocationContext(
|
|
queue_batch_id="1",
|
|
queue_item_id=1,
|
|
queue_id=DEFAULT_QUEUE_ID,
|
|
services=services,
|
|
graph_execution_state_id="1",
|
|
)
|
|
)
|
|
g.complete(n.id, o)
|
|
|
|
return (n, o)
|
|
|
|
|
|
def test_graph_state_executes_in_order(simple_graph, mock_services):
|
|
g = GraphExecutionState(graph=simple_graph)
|
|
|
|
n1 = invoke_next(g, mock_services)
|
|
n2 = invoke_next(g, mock_services)
|
|
n3 = g.next()
|
|
|
|
assert g.prepared_source_mapping[n1[0].id] == "1"
|
|
assert g.prepared_source_mapping[n2[0].id] == "2"
|
|
assert n3 is None
|
|
assert g.results[n1[0].id].prompt == n1[0].prompt
|
|
assert n2[0].prompt == n1[0].prompt
|
|
|
|
|
|
def test_graph_is_complete(simple_graph, mock_services):
|
|
g = GraphExecutionState(graph=simple_graph)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = g.next()
|
|
|
|
assert g.is_complete()
|
|
|
|
|
|
def test_graph_is_not_complete(simple_graph, mock_services):
|
|
g = GraphExecutionState(graph=simple_graph)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = g.next()
|
|
|
|
assert not g.is_complete()
|
|
|
|
|
|
# TODO: test completion with iterators/subgraphs
|
|
|
|
|
|
def test_graph_state_expands_iterator(mock_services):
|
|
graph = Graph()
|
|
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
|
|
graph.add_node(IterateInvocation(id="1"))
|
|
graph.add_node(MultiplyInvocation(id="2", b=10))
|
|
graph.add_node(AddInvocation(id="3", b=1))
|
|
graph.add_edge(create_edge("0", "collection", "1", "collection"))
|
|
graph.add_edge(create_edge("1", "item", "2", "a"))
|
|
graph.add_edge(create_edge("2", "value", "3", "a"))
|
|
|
|
g = GraphExecutionState(graph=graph)
|
|
while not g.is_complete():
|
|
invoke_next(g, mock_services)
|
|
|
|
prepared_add_nodes = g.source_prepared_mapping["3"]
|
|
results = set([g.results[n].value for n in prepared_add_nodes])
|
|
expected = set([1, 11, 21])
|
|
assert results == expected
|
|
|
|
|
|
def test_graph_state_collects(mock_services):
|
|
graph = Graph()
|
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
|
graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
|
|
graph.add_node(IterateInvocation(id="2"))
|
|
graph.add_node(PromptTestInvocation(id="3"))
|
|
graph.add_node(CollectInvocation(id="4"))
|
|
graph.add_edge(create_edge("1", "collection", "2", "collection"))
|
|
graph.add_edge(create_edge("2", "item", "3", "prompt"))
|
|
graph.add_edge(create_edge("3", "prompt", "4", "item"))
|
|
|
|
g = GraphExecutionState(graph=graph)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
n6 = invoke_next(g, mock_services)
|
|
|
|
assert isinstance(n6[0], CollectInvocation)
|
|
|
|
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
|
|
|
|
|
def test_graph_state_prepares_eagerly(mock_services):
|
|
"""Tests that all prepareable nodes are prepared"""
|
|
graph = Graph()
|
|
|
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
|
graph.add_node(IterateInvocation(id="iterate"))
|
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
|
|
|
# separated, fully-preparable chain of nodes
|
|
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
|
|
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
|
|
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
|
|
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
|
|
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
|
|
|
|
g = GraphExecutionState(graph=graph)
|
|
g.next()
|
|
|
|
assert "prompt_collection" in g.source_prepared_mapping
|
|
assert "prompt_chain_1" in g.source_prepared_mapping
|
|
assert "prompt_chain_2" in g.source_prepared_mapping
|
|
assert "prompt_chain_3" in g.source_prepared_mapping
|
|
assert "iterate" not in g.source_prepared_mapping
|
|
assert "prompt_iterated" not in g.source_prepared_mapping
|
|
|
|
|
|
def test_graph_executes_depth_first(mock_services):
|
|
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
|
|
graph = Graph()
|
|
|
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
|
graph.add_node(IterateInvocation(id="iterate"))
|
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
|
graph.add_node(PromptTestInvocation(id="prompt_successor"))
|
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
|
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
|
|
|
g = GraphExecutionState(graph=graph)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
_ = invoke_next(g, mock_services)
|
|
|
|
# Because ordering is not guaranteed, we cannot compare results directly.
|
|
# Instead, we must count the number of results.
|
|
def get_completed_count(g, id):
|
|
ids = [i for i in g.source_prepared_mapping[id]]
|
|
completed_ids = [i for i in g.executed if i in ids]
|
|
return len(completed_ids)
|
|
|
|
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
|
|
assert get_completed_count(g, "prompt_iterated") == 1
|
|
assert get_completed_count(g, "prompt_successor") == 0
|
|
|
|
_ = invoke_next(g, mock_services)
|
|
|
|
assert get_completed_count(g, "prompt_iterated") == 1
|
|
assert get_completed_count(g, "prompt_successor") == 1
|
|
|
|
_ = invoke_next(g, mock_services)
|
|
|
|
assert get_completed_count(g, "prompt_iterated") == 2
|
|
assert get_completed_count(g, "prompt_successor") == 1
|
|
|
|
_ = invoke_next(g, mock_services)
|
|
|
|
assert get_completed_count(g, "prompt_iterated") == 2
|
|
assert get_completed_count(g, "prompt_successor") == 2
|