mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: refactor services folder/module structure
Refactor services folder/module structure. **Motivation** While working on our services I've repeatedly encountered circular imports and a general lack of clarity regarding where to put things. The structure introduced goes a long way towards resolving those issues, setting us up for a clean structure going forward. **Services** Services are now in their own folder with a few files: - `services/{service_name}/__init__.py`: init as needed, mostly empty now - `services/{service_name}/{service_name}_base.py`: the base class for the service - `services/{service_name}/{service_name}_{impl_type}.py`: the default concrete implementation of the service - typically one of `sqlite`, `default`, or `memory` - `services/{service_name}/{service_name}_common.py`: any common items - models, exceptions, utilities, etc Though it's a bit verbose to have the service name both as the folder name and the prefix for files, I found it is _extremely_ confusing to have all of the base classes just be named `base.py`. So, at the cost of some verbosity when importing things, I've included the service name in the filename. There are some minor logic changes. For example, in `InvocationProcessor`, instead of assigning the model manager service to a variable to be used later in the file, the service is used directly via the `Invoker`. **Shared** Things that are used across disparate services are in `services/shared/`: - `default_graphs.py`: previously in `services/` - `graphs.py`: previously in `services/` - `paginatation`: generic pagination models used in a few services - `sqlite`: the `SqliteDatabase` class, other sqlite-specific things
This commit is contained in:
committed by
Kent Keirsey
parent
88bee96ca3
commit
402cf9b0ee
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
@ -10,20 +9,27 @@ from .test_nodes import ( # isort: split
|
||||
TestEventService,
|
||||
TextToImageTestInvocation,
|
||||
)
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.invocations.collections import RangeInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
IterateInvocation,
|
||||
LibraryGraph,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .test_invoker import create_edge
|
||||
|
||||
@ -42,29 +48,33 @@ def simple_graph():
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
lock = threading.Lock()
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
db = SqliteDatabase(configuration, InvokeAILogger.get_logger())
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
logger=logging, # type: ignore
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
|
||||
board_records=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
session_processor=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(), # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
@ -16,15 +15,16 @@ from .test_nodes import ( # isort: split
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -52,29 +52,34 @@ def graph_with_subgraph():
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
lock = threading.Lock()
|
||||
db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
logger=logging, # type: ignore
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
|
||||
board_records=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
session_processor=None, # type: ignore
|
||||
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
queue=MemoryInvocationQueue(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -11,8 +11,8 @@ from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.graph import (
|
||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
|
@ -82,8 +82,8 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# Importing these must happen after test invocations are defined or they won't register
|
||||
from invokeai.app.services.events import EventServiceBase # noqa: E402
|
||||
from invokeai.app.services.graph import Edge, EdgeConnection # noqa: E402
|
||||
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
||||
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
||||
|
||||
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
|
@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError, parse_raw_as
|
||||
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchDataCollection,
|
||||
@ -12,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
populate_graph,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from tests.nodes.test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
@ -14,8 +14,8 @@ class TestModel(BaseModel):
|
||||
|
||||
@pytest.fixture
|
||||
def db() -> SqliteItemStorage[TestModel]:
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
return SqliteItemStorage[TestModel](db_conn, table_name="test", id_field="id", lock=threading.Lock())
|
||||
sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
|
||||
return SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")
|
||||
|
||||
|
||||
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||
|
@ -2,7 +2,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType
|
||||
|
||||
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
|
Reference in New Issue
Block a user