feat(backend): organise service dependencies

**Service Dependencies**

Services that depend on other services now access those services via the `Invoker` object. This object is provided to the service as a kwarg to its `start()` method.

Until now, most services did not utilize this feature, and several services required their dependencies to be initialized and passed in on init.

Additionally, _all_ services are now registered as invocation services - including the low-level services. This obviates issues with inter-dependent services we would otherwise experience as we add workflow storage.

**Database Access**

Previously, we were passing in a separate sqlite connection and corresponding lock as args to services in their init. A good amount of posturing was done in each service that uses the db.

These objects, along with the sqlite startup and cleanup logic, is now abstracted into a simple `SqliteDatabase` class. This creates the shared connection and lock objects, enables foreign keys, and provides a `clean()` method to do startup db maintenance.

This is not a service as it's only used by sqlite services.
This commit is contained in:
psychedelicious 2023-09-24 15:12:51 +10:00 committed by Kent Keirsey
parent 10fac5c085
commit 2a35d93a4d
15 changed files with 255 additions and 322 deletions

View File

@ -1,19 +1,19 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import sqlite3
from logging import Logger
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
from invokeai.app.services.board_images import BoardImagesService
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.boards import BoardService
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.images import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.db import SqliteDatabase
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -29,7 +29,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_service import ModelManagerService
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.thread import lock
from .events import FastAPIEventService
@ -63,100 +62,64 @@ class ApiDependencies:
logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
output_folder = config.output_path
# TODO: build a file/path manager?
if config.use_memory_db:
db_location = ":memory:"
else:
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
db = SqliteDatabase(config, logger)
logger.info(f"Using database at {db_location}")
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
configuration = config
logger = logger
if config.log_sql:
db_conn.set_trace_callback(print)
db_conn.execute("PRAGMA foreign_keys = ON;")
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
conn=db_conn, table_name="graph_executions", lock=lock
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
board_image_records = SqliteBoardImageRecordStorage(db=db)
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService(
services=ImageServiceDependencies(
board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
)
model_manager = ModelManagerService(config, logger)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
queue = MemoryInvocationQueue()
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
services = InvocationServices(
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
boards=boards,
board_image_records=board_image_records,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
board_records=board_records,
boards=boards,
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
graph_library=graph_library,
image_files=image_files,
image_records=image_records,
images=images,
invocation_cache=invocation_cache,
latents=latents,
logger=logger,
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
session_processor=DefaultSessionProcessor(),
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
model_manager=model_manager,
names=names,
performance_statistics=performance_statistics,
processor=processor,
queue=queue,
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
try:
lock.acquire()
db_conn.execute("VACUUM;")
db_conn.commit()
logger.info("Cleaned database")
finally:
lock.release()
db.clean()
@staticmethod
def shutdown():

View File

@ -5,6 +5,7 @@ from typing import Optional, cast
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.db import SqliteDatabase
class BoardImageRecordStorageBase(ABC):
@ -57,13 +58,11 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
self._lock = lock
try:
self._lock.acquire()

View File

@ -1,12 +1,9 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
@ -46,60 +43,36 @@ class BoardImagesServiceABC(ABC):
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
__invoker: Invoker
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(image_name)
self.__invoker.services.board_image_records.remove_image_from_board(image_name)
def get_all_board_image_names_for_board(
self,
board_id: str,
) -> list[str]:
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
def get_board_for_image(
self,
image_name: str,
) -> Optional[str]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
return board_id

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
from invokeai.app.services.shared.db import SqliteDatabase
from invokeai.app.util.misc import uuid_string
@ -91,13 +92,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
self._lock = lock
try:
self._lock.acquire()

View File

@ -1,12 +1,10 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase
from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults
from invokeai.app.services.board_record_storage import BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardServiceABC(ABC):
@ -62,51 +60,27 @@ class BoardServiceABC(ABC):
pass
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardService(BoardServiceABC):
_services: BoardServiceDependencies
__invoker: Invoker
def __init__(self, services: BoardServiceDependencies):
self._services = services
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def create(
self,
board_name: str,
) -> BoardDTO:
board_record = self._services.board_records.save(board_name)
board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
board_record = self.__invoker.services.board_records.get(board_id)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
@ -114,45 +88,45 @@ class BoardService(BoardServiceABC):
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
board_record = self.__invoker.services.board_records.update(board_id, changes)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id)
self.__invoker.services.board_records.delete(board_id)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_records = self.__invoker.services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all()
board_records = self.__invoker.services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos

View File

@ -10,6 +10,7 @@ from pydantic.generics import GenericModel
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
from invokeai.app.services.shared.db import SqliteDatabase
T = TypeVar("T", bound=BaseModel)
@ -152,13 +153,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
self._lock = lock
try:
self._lock.acquire()
@ -204,6 +203,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
if "workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN workflow_id TEXT;
-- TODO: This requires a migration:
-- FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE SET NULL;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Callable, Optional
from typing import Callable, Optional
from PIL.Image import Image as PILImageType
@ -11,29 +10,21 @@ from invokeai.app.models.image import (
InvalidOriginException,
ResourceOrigin,
)
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC):
"""High-level service for image management."""
@ -150,42 +141,11 @@ class ImageServiceABC(ABC):
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.image_records = image_record_storage
self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.urls = url
self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
__invoker: Invoker
def __init__(self, services: ImageServiceDependencies):
super().__init__()
self._services = services
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def create(
self,
@ -205,24 +165,13 @@ class ImageService(ImageServiceABC):
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._services.names.create_image_name()
# TODO: Do we want to store the graph in the image at all? I don't think so...
# graph = None
# if session_id is not None:
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
# if session_raw is not None:
# try:
# graph = get_metadata_graph_from_raw_session(session_raw)
# except Exception as e:
# self._services.logger.warn(f"Failed to parse session graph: {e}")
# graph = None
image_name = self.__invoker.services.names.create_image_name()
(width, height) = image.size
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.image_records.save(
self.__invoker.services.image_records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
@ -237,20 +186,22 @@ class ImageService(ImageServiceABC):
session_id=session_id,
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
self.__invoker.services.logger.error("Failed to save image record")
raise
except ImageFileSaveException:
self._services.logger.error("Failed to save image file")
self.__invoker.services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
raise e
def update(
@ -259,101 +210,101 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.image_records.update(image_name, changes)
self.__invoker.services.image_records.update(image_name, changes)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
self.__invoker.services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
self.__invoker.services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.image_files.get(image_name)
return self.__invoker.services.image_files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
self.__invoker.services.logger.error("Failed to get image file")
raise
except Exception as e:
self._services.logger.error("Problem getting image file")
self.__invoker.services.logger.error("Problem getting image file")
raise e
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.image_records.get(image_name)
return self.__invoker.services.image_records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
self.__invoker.services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image record")
self.__invoker.services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.image_records.get(image_name)
image_record = self.__invoker.services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
self.__invoker.services.urls.get_image_url(image_name),
self.__invoker.services.urls.get_image_url(image_name, True),
self.__invoker.services.board_image_records.get_board_for_image(image_name),
)
return image_dto
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
self.__invoker.services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
metadata = self._services.image_records.get_metadata(image_name)
image_record = self.__invoker.services.image_records.get(image_name)
metadata = self.__invoker.services.image_records.get_metadata(image_name)
if not image_record.session_id:
return ImageMetadata(metadata=metadata)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
graph = None
if session_raw:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
self.__invoker.services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.image_files.get_path(image_name, thumbnail)
return self.__invoker.services.image_files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
self.__invoker.services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.image_files.validate_path(path)
return self.__invoker.services.image_files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
self.__invoker.services.logger.error("Problem validating image path")
raise e
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.urls.get_image_url(image_name, thumbnail)
return self.__invoker.services.urls.get_image_url(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
self.__invoker.services.logger.error("Problem getting image path")
raise e
def get_many(
@ -366,7 +317,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.image_records.get_many(
results = self.__invoker.services.image_records.get_many(
offset,
limit,
image_origin,
@ -379,9 +330,9 @@ class ImageService(ImageServiceABC):
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(r.image_name),
self.__invoker.services.urls.get_image_url(r.image_name),
self.__invoker.services.urls.get_image_url(r.image_name, True),
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
),
results.items,
)
@ -394,56 +345,56 @@ class ImageService(ImageServiceABC):
total=results.total,
)
except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs")
self.__invoker.services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_name: str):
try:
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
self.__invoker.services.image_files.delete(image_name)
self.__invoker.services.image_records.delete(image_name)
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record")
self.__invoker.services.logger.error("Failed to delete image record")
raise
except ImageFileDeleteException:
self._services.logger.error("Failed to delete image file")
self.__invoker.services.logger.error("Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
self.__invoker.services.logger.error("Problem deleting image record and file")
raise e
def delete_images_on_board(self, board_id: str):
try:
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
self.__invoker.services.image_files.delete(image_name)
self.__invoker.services.image_records.delete_many(image_names)
for image_name in image_names:
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")
self.__invoker.services.logger.error("Failed to delete image records")
raise
except ImageFileDeleteException:
self._services.logger.error("Failed to delete image files")
self.__invoker.services.logger.error("Failed to delete image files")
raise
except Exception as e:
self._services.logger.error("Problem deleting image records and files")
self.__invoker.services.logger.error("Problem deleting image records and files")
raise e
def delete_intermediates(self) -> int:
try:
image_names = self._services.image_records.delete_intermediates()
image_names = self.__invoker.services.image_records.delete_intermediates()
count = len(image_names)
for image_name in image_names:
self._services.image_files.delete(image_name)
self.__invoker.services.image_files.delete(image_name)
self._on_deleted(image_name)
return count
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")
self.__invoker.services.logger.error("Failed to delete image records")
raise
except ImageFileDeleteException:
self._services.logger.error("Failed to delete image files")
self.__invoker.services.logger.error("Failed to delete image files")
raise
except Exception as e:
self._services.logger.error("Problem deleting image records and files")
self.__invoker.services.logger.error("Problem deleting image records and files")
raise e

View File

@ -6,11 +6,15 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.board_record_storage import BoardRecordStorageBase
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
@ -19,8 +23,10 @@ if TYPE_CHECKING:
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.urls import UrlServiceBase
class InvocationServices:
@ -28,12 +34,16 @@ class InvocationServices:
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC"
board_image_record_storage: "BoardImageRecordStorageBase"
boards: "BoardServiceABC"
board_records: "BoardRecordStorageBase"
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
image_records: "ImageRecordStorageBase"
image_files: "ImageFileStorageBase"
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
@ -43,16 +53,22 @@ class InvocationServices:
session_queue: "SessionQueueBase"
session_processor: "SessionProcessorBase"
invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
def __init__(
self,
board_images: "BoardImagesServiceABC",
board_image_records: "BoardImageRecordStorageBase",
boards: "BoardServiceABC",
board_records: "BoardRecordStorageBase",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]",
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
@ -62,14 +78,20 @@ class InvocationServices:
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
self.boards = boards
self.board_records = board_records
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.image_files = image_files
self.image_records = image_records
self.latents = latents
self.logger = logger
self.model_manager = model_manager
@ -79,3 +101,5 @@ class InvocationServices:
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache
self.names = names
self.urls = urls

View File

@ -38,12 +38,11 @@ import psutil
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_management.model_cache import CacheStats
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from .model_manager_service import ModelManagerServiceBase
# size of GIG in bytes
GIG = 1073741824
@ -72,7 +71,6 @@ class NodeLog:
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
@ -80,10 +78,9 @@ class InvocationStatsServiceBase(ABC):
ram_changed: float
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
def __init__(self):
"""
Initialize the InvocationStatsService and reset counters to zero
:param graph_execution_manager: Graph execution manager for this session
"""
pass
@ -158,14 +155,18 @@ class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
self.graph_execution_manager = graph_execution_manager
_invoker: Invoker
def __init__(self):
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
class StatsContext:
"""Context manager for collecting statistics."""
@ -174,13 +175,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerService
model_manager: ModelManagerServiceBase
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
model_manager: ModelManagerServiceBase,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
@ -217,12 +218,11 @@ class InvocationStatsService(InvocationStatsServiceBase):
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
def reset_all_stats(self):
"""Zero all statistics"""
@ -261,7 +261,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
errored = set()
for graph_id, node_log in self._stats.items():
try:
current_graph_state = self.graph_execution_manager.get(graph_id)
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception:
errored.add(graph_id)
continue

View File

@ -8,7 +8,6 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import InvocationContext
from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem
from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker
@ -37,7 +36,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
queue_item: Optional[InvocationQueueItem] = None
while not stop_event.is_set():
@ -97,8 +95,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
@ -133,13 +130,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
source_node_id=source_node_id,
result=outputs.dict(),
)
statistics.log_stats()
self.__invoker.services.performance_statistics.log_stats()
except KeyboardInterrupt:
pass
except CanceledException:
statistics.reset_stats(graph_execution_state.id)
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
pass
except Exception as e:
@ -164,7 +161,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error_type=e.__class__.__name__,
error=error,
)
statistics.reset_stats(graph_execution_state.id)
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
pass
# Check queue to see if this is canceled, and skip if so

View File

@ -29,6 +29,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.db import SqliteDatabase
from invokeai.app.services.shared.models import CursorPaginatedResults
@ -45,13 +46,11 @@ class SqliteSessionQueue(SessionQueueBase):
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self.__conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self.__conn.row_factory = sqlite3.Row
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self.__lock = lock
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:

View File

View File

@ -0,0 +1,46 @@
import sqlite3
import threading
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
class SqliteDatabase:
conn: sqlite3.Connection
lock: threading.Lock
_logger: Logger
_config: InvokeAIAppConfig
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
if self._config.use_memory_db:
location = ":memory:"
logger.info("Using in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
location = str(db_path)
self._logger.info(f"Using database at {location}")
self.conn = sqlite3.connect(location, check_same_thread=False)
self.lock = threading.Lock()
self.conn.row_factory = sqlite3.Row
if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;")
def clean(self) -> None:
try:
self.lock.acquire()
self.conn.execute("VACUUM;")
self.conn.commit()
self._logger.info("Cleaned database")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
raise e
finally:
self.lock.release()

View File

@ -4,6 +4,8 @@ from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as
from invokeai.app.services.shared.db import SqliteDatabase
from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar("T", bound=BaseModel)
@ -18,13 +20,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
_id_field: str
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = lock
self._conn = conn
self._cursor = self._conn.cursor()
self._create_table()

View File

@ -1,3 +0,0 @@
import threading
lock = threading.Lock()