InvokeAI/tests/nodes/test_invoker.py
psychedelicious 402cf9b0ee feat: refactor services folder/module structure
Refactor services folder/module structure.

**Motivation**

While working on our services I've repeatedly encountered circular imports and a general lack of clarity regarding where to put things. The structure introduced goes a long way towards resolving those issues, setting us up for a clean structure going forward.

**Services**

Services are now in their own folder with a few files:

- `services/{service_name}/__init__.py`: init as needed, mostly empty now
- `services/{service_name}/{service_name}_base.py`: the base class for the service
- `services/{service_name}/{service_name}_{impl_type}.py`: the default concrete implementation of the service - typically one of `sqlite`, `default`, or `memory`
- `services/{service_name}/{service_name}_common.py`: any common items - models, exceptions, utilities, etc

Though it's a bit verbose to have the service name both as the folder name and the prefix for files, I found it is _extremely_ confusing to have all of the base classes just be named `base.py`. So, at the cost of some verbosity when importing things, I've included the service name in the filename.

There are some minor logic changes. For example, in `InvocationProcessor`, instead of assigning the model manager service to a variable to be used later in the file, the service is used directly via the `Invoker`.

**Shared**

Things that are used across disparate services are in `services/shared/`:

- `default_graphs.py`: previously in `services/`
- `graphs.py`: previously in `services/`
- `paginatation`: generic pagination models used in a few services
- `sqlite`: the `SqliteDatabase` class, other sqlite-specific things
2023-10-12 12:15:06 -04:00

178 lines
6.5 KiB
Python

import logging
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
ErrorInvocation,
PromptTestInvocation,
TestEventService,
TextToImageTestInvocation,
create_edge,
wait_until,
)
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.sqlite import SqliteDatabase
@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g.add_node(TextToImageTestInvocation(id="2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def graph_with_subgraph():
sub_g = Graph()
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
sub_g.add_node(TextToImageTestInvocation(id="2"))
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
g = Graph()
g.add_node(GraphInvocation(id="1", graph=sub_g))
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=None, # type: ignore
boards=None, # type: ignore
configuration=configuration,
events=TestEventService(),
graph_execution_manager=graph_execution_manager,
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
latents=None, # type: ignore
logger=logging, # type: ignore
model_manager=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),
queue=MemoryInvocationQueue(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
def test_can_create_graph_state(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
)
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
invoke_all=True,
)
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete()
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1"))
mock_invoker.invoke(
session_queue_batch_id="1",
session_queue_item_id=1,
session_queue_id=DEFAULT_QUEUE_ID,
graph_execution_state=g,
invoke_all=True,
)
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.has_error()
assert g.is_complete()
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))