feat(backend): organise service dependencies

**Service Dependencies**

Services that depend on other services now access those services via the `Invoker` object. This object is provided to the service as a kwarg to its `start()` method.

Until now, most services did not utilize this feature, and several services required their dependencies to be initialized and passed in on init.

Additionally, _all_ services are now registered as invocation services - including the low-level services. This obviates issues with inter-dependent services we would otherwise experience as we add workflow storage.

**Database Access**

Previously, we were passing in a separate sqlite connection and corresponding lock as args to services in their init. A good amount of posturing was done in each service that uses the db.

These objects, along with the sqlite startup and cleanup logic, is now abstracted into a simple `SqliteDatabase` class. This creates the shared connection and lock objects, enables foreign keys, and provides a `clean()` method to do startup db maintenance.

This is not a service as it's only used by sqlite services.
This commit is contained in:
psychedelicious 2023-09-24 15:12:51 +10:00 committed by Kent Keirsey
parent 10fac5c085
commit 2a35d93a4d
15 changed files with 255 additions and 322 deletions

View File

@ -1,19 +1,19 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import sqlite3
from logging import Logger from logging import Logger
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies from invokeai.app.services.board_images import BoardImagesService
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.db import SqliteDatabase
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
@ -29,7 +29,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_service import ModelManagerService from ..services.model_manager_service import ModelManagerService
from ..services.processor import DefaultInvocationProcessor from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage from ..services.sqlite import SqliteItemStorage
from ..services.thread import lock
from .events import FastAPIEventService from .events import FastAPIEventService
@ -63,100 +62,64 @@ class ApiDependencies:
logger.info(f"Root directory = {str(config.root_path)}") logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}") logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
output_folder = config.output_path output_folder = config.output_path
# TODO: build a file/path manager? db = SqliteDatabase(config, logger)
if config.use_memory_db:
db_location = ":memory:"
else:
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
logger.info(f"Using database at {db_location}") configuration = config
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution logger = logger
if config.log_sql: board_image_records = SqliteBoardImageRecordStorage(db=db)
db_conn.set_trace_callback(print) board_images = BoardImagesService()
db_conn.execute("PRAGMA foreign_keys = ON;") board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( events = FastAPIEventService(event_handler_id)
conn=db_conn, table_name="graph_executions", lock=lock graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
) graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
urls = LocalUrlService() image_records = SqliteImageRecordStorage(db=db)
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock) images = ImageService()
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock) names = SimpleNameService()
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock) performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
boards = BoardService( queue = MemoryInvocationQueue()
services=BoardServiceDependencies( session_processor = DefaultSessionProcessor()
board_image_record_storage=board_image_record_storage, session_queue = SqliteSessionQueue(db=db)
board_record_storage=board_record_storage, urls = LocalUrlService()
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService(
services=ImageServiceDependencies(
board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices( services = InvocationServices(
model_manager=ModelManagerService(config, logger), board_image_records=board_image_records,
events=events,
latents=latents,
images=images,
boards=boards,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), board_records=board_records,
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"), boards=boards,
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), graph_library=graph_library,
configuration=config, image_files=image_files,
performance_statistics=InvocationStatsService(graph_execution_manager), image_records=image_records,
images=images,
invocation_cache=invocation_cache,
latents=latents,
logger=logger, logger=logger,
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock), model_manager=model_manager,
session_processor=DefaultSessionProcessor(), names=names,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size), performance_statistics=performance_statistics,
processor=processor,
queue=queue,
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
) )
create_system_graphs(services.graph_library) create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services) ApiDependencies.invoker = Invoker(services)
try: db.clean()
lock.acquire()
db_conn.execute("VACUUM;")
db_conn.commit()
logger.info("Cleaned database")
finally:
lock.release()
@staticmethod @staticmethod
def shutdown(): def shutdown():

View File

@ -5,6 +5,7 @@ from typing import Optional, cast
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.db import SqliteDatabase
class BoardImageRecordStorageBase(ABC): class BoardImageRecordStorageBase(ABC):
@ -57,13 +58,11 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: def __init__(self, db: SqliteDatabase) -> None:
super().__init__() super().__init__()
self._conn = conn self._lock = db.lock
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn = db.conn
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -1,12 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional from typing import Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase from invokeai.app.services.invoker import Invoker
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.models.board_record import BoardDTO from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC): class BoardImagesServiceABC(ABC):
@ -46,60 +43,36 @@ class BoardImagesServiceABC(ABC):
pass pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC): class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies __invoker: Invoker
def __init__(self, services: BoardImagesServiceDependencies): def start(self, invoker: Invoker) -> None:
self._services = services self.__invoker = invoker
def add_image_to_board( def add_image_to_board(
self, self,
board_id: str, board_id: str,
image_name: str, image_name: str,
) -> None: ) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name) self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board( def remove_image_from_board(
self, self,
image_name: str, image_name: str,
) -> None: ) -> None:
self._services.board_image_records.remove_image_from_board(image_name) self.__invoker.services.board_image_records.remove_image_from_board(image_name)
def get_all_board_image_names_for_board( def get_all_board_image_names_for_board(
self, self,
board_id: str, board_id: str,
) -> list[str]: ) -> list[str]:
return self._services.board_image_records.get_all_board_image_names_for_board(board_id) return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Optional[str]: ) -> Optional[str]:
board_id = self._services.board_image_records.get_board_for_image(image_name) board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
return board_id return board_id

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
from invokeai.app.services.shared.db import SqliteDatabase
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
@ -91,13 +92,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: def __init__(self, db: SqliteDatabase) -> None:
super().__init__() super().__init__()
self._conn = conn self._lock = db.lock
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn = db.conn
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -1,12 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto from invokeai.app.services.board_images import board_record_to_dto
from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase from invokeai.app.services.board_record_storage import BoardChanges
from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.models.board_record import BoardDTO from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardServiceABC(ABC): class BoardServiceABC(ABC):
@ -62,51 +60,27 @@ class BoardServiceABC(ABC):
pass pass
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardService(BoardServiceABC): class BoardService(BoardServiceABC):
_services: BoardServiceDependencies __invoker: Invoker
def __init__(self, services: BoardServiceDependencies): def start(self, invoker: Invoker) -> None:
self._services = services self.__invoker = invoker
def create( def create(
self, self,
board_name: str, board_name: str,
) -> BoardDTO: ) -> BoardDTO:
board_record = self._services.board_records.save(board_name) board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0) return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO: def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id) board_record = self.__invoker.services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id) cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(board_id) image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count) return board_record_to_dto(board_record, cover_image_name, image_count)
def update( def update(
@ -114,45 +88,45 @@ class BoardService(BoardServiceABC):
board_id: str, board_id: str,
changes: BoardChanges, changes: BoardChanges,
) -> BoardDTO: ) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes) board_record = self.__invoker.services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id) cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(board_id) image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count) return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None: def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id) self.__invoker.services.board_records.delete(board_id)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]: def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit) board_records = self.__invoker.services.board_records.get_many(offset, limit)
board_dtos = [] board_dtos = []
for r in board_records.items: for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id) cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id) image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)) return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self) -> list[BoardDTO]: def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all() board_records = self.__invoker.services.board_records.get_all()
board_dtos = [] board_dtos = []
for r in board_records: for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id) cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image: if cover_image:
cover_image_name = cover_image.image_name cover_image_name = cover_image.image_name
else: else:
cover_image_name = None cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id) image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos return board_dtos

View File

@ -10,6 +10,7 @@ from pydantic.generics import GenericModel
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
from invokeai.app.services.shared.db import SqliteDatabase
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
@ -152,13 +153,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: def __init__(self, db: SqliteDatabase) -> None:
super().__init__() super().__init__()
self._conn = conn self._lock = db.lock
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn = db.conn
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()
@ -204,6 +203,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
if "workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN workflow_id TEXT;
-- TODO: This requires a migration:
-- FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE SET NULL;
"""
)
# Create the `images` table indices. # Create the `images` table indices.
self._cursor.execute( self._cursor.execute(
"""--sql """--sql

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
@ -11,29 +10,21 @@ from invokeai.app.models.image import (
InvalidOriginException, InvalidOriginException,
ResourceOrigin, ResourceOrigin,
) )
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
ImageFileDeleteException, ImageFileDeleteException,
ImageFileNotFoundException, ImageFileNotFoundException,
ImageFileSaveException, ImageFileSaveException,
ImageFileStorageBase,
) )
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException, ImageRecordDeleteException,
ImageRecordNotFoundException, ImageRecordNotFoundException,
ImageRecordSaveException, ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults, OffsetPaginatedResults,
) )
from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.invoker import Invoker
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC): class ImageServiceABC(ABC):
"""High-level service for image management.""" """High-level service for image management."""
@ -150,42 +141,11 @@ class ImageServiceABC(ABC):
pass pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.image_records = image_record_storage
self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.urls = url
self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC): class ImageService(ImageServiceABC):
_services: ImageServiceDependencies __invoker: Invoker
def __init__(self, services: ImageServiceDependencies): def start(self, invoker: Invoker) -> None:
super().__init__() self.__invoker = invoker
self._services = services
def create( def create(
self, self,
@ -205,24 +165,13 @@ class ImageService(ImageServiceABC):
if image_category not in ImageCategory: if image_category not in ImageCategory:
raise InvalidImageCategoryException raise InvalidImageCategoryException
image_name = self._services.names.create_image_name() image_name = self.__invoker.services.names.create_image_name()
# TODO: Do we want to store the graph in the image at all? I don't think so...
# graph = None
# if session_id is not None:
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
# if session_raw is not None:
# try:
# graph = get_metadata_graph_from_raw_session(session_raw)
# except Exception as e:
# self._services.logger.warn(f"Failed to parse session graph: {e}")
# graph = None
(width, height) = image.size (width, height) = image.size
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.image_records.save( self.__invoker.services.image_records.save(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_origin=image_origin, image_origin=image_origin,
@ -237,20 +186,22 @@ class ImageService(ImageServiceABC):
session_id=session_id, session_id=session_id,
) )
if board_id is not None: if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow) self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
self._on_changed(image_dto) self._on_changed(image_dto)
return image_dto return image_dto
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to save image record") self.__invoker.services.logger.error("Failed to save image record")
raise raise
except ImageFileSaveException: except ImageFileSaveException:
self._services.logger.error("Failed to save image file") self.__invoker.services.logger.error("Failed to save image file")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error(f"Problem saving image record and file: {str(e)}") self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
raise e raise e
def update( def update(
@ -259,101 +210,101 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.image_records.update(image_name, changes) self.__invoker.services.image_records.update(image_name, changes)
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
self._on_changed(image_dto) self._on_changed(image_dto)
return image_dto return image_dto
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self.__invoker.services.logger.error("Failed to update image record")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem updating image record") self.__invoker.services.logger.error("Problem updating image record")
raise e raise e
def get_pil_image(self, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
try: try:
return self._services.image_files.get(image_name) return self.__invoker.services.image_files.get(image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self.__invoker.services.logger.error("Failed to get image file")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image file") self.__invoker.services.logger.error("Problem getting image file")
raise e raise e
def get_record(self, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
try: try:
return self._services.image_records.get(image_name) return self.__invoker.services.image_records.get(image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self.__invoker.services.logger.error("Image record not found")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image record") self.__invoker.services.logger.error("Problem getting image record")
raise e raise e
def get_dto(self, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.image_records.get(image_name) image_record = self.__invoker.services.image_records.get(image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_name), self.__invoker.services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True), self.__invoker.services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name), self.__invoker.services.board_image_records.get_board_for_image(image_name),
) )
return image_dto return image_dto
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self.__invoker.services.logger.error("Image record not found")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image DTO") self.__invoker.services.logger.error("Problem getting image DTO")
raise e raise e
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]: def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try: try:
image_record = self._services.image_records.get(image_name) image_record = self.__invoker.services.image_records.get(image_name)
metadata = self._services.image_records.get_metadata(image_name) metadata = self.__invoker.services.image_records.get_metadata(image_name)
if not image_record.session_id: if not image_record.session_id:
return ImageMetadata(metadata=metadata) return ImageMetadata(metadata=metadata)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id) session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
graph = None graph = None
if session_raw: if session_raw:
try: try:
graph = get_metadata_graph_from_raw_session(session_raw) graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e: except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}") self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
graph = None graph = None
return ImageMetadata(graph=graph, metadata=metadata) return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self.__invoker.services.logger.error("Image record not found")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image DTO") self.__invoker.services.logger.error("Problem getting image DTO")
raise e raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.image_files.get_path(image_name, thumbnail) return self.__invoker.services.image_files.get_path(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self.__invoker.services.logger.error("Problem getting image path")
raise e raise e
def validate_path(self, path: str) -> bool: def validate_path(self, path: str) -> bool:
try: try:
return self._services.image_files.validate_path(path) return self.__invoker.services.image_files.validate_path(path)
except Exception as e: except Exception as e:
self._services.logger.error("Problem validating image path") self.__invoker.services.logger.error("Problem validating image path")
raise e raise e
def get_url(self, image_name: str, thumbnail: bool = False) -> str: def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.urls.get_image_url(image_name, thumbnail) return self.__invoker.services.urls.get_image_url(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self.__invoker.services.logger.error("Problem getting image path")
raise e raise e
def get_many( def get_many(
@ -366,7 +317,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None, board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
try: try:
results = self._services.image_records.get_many( results = self.__invoker.services.image_records.get_many(
offset, offset,
limit, limit,
image_origin, image_origin,
@ -379,9 +330,9 @@ class ImageService(ImageServiceABC):
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, r,
self._services.urls.get_image_url(r.image_name), self.__invoker.services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True), self.__invoker.services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(r.image_name), self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
), ),
results.items, results.items,
) )
@ -394,56 +345,56 @@ class ImageService(ImageServiceABC):
total=results.total, total=results.total,
) )
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs") self.__invoker.services.logger.error("Problem getting paginated image DTOs")
raise e raise e
def delete(self, image_name: str): def delete(self, image_name: str):
try: try:
self._services.image_files.delete(image_name) self.__invoker.services.image_files.delete(image_name)
self._services.image_records.delete(image_name) self.__invoker.services.image_records.delete(image_name)
self._on_deleted(image_name) self._on_deleted(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record") self.__invoker.services.logger.error("Failed to delete image record")
raise raise
except ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error("Failed to delete image file") self.__invoker.services.logger.error("Failed to delete image file")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image record and file") self.__invoker.services.logger.error("Problem deleting image record and file")
raise e raise e
def delete_images_on_board(self, board_id: str): def delete_images_on_board(self, board_id: str):
try: try:
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id) image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
for image_name in image_names: for image_name in image_names:
self._services.image_files.delete(image_name) self.__invoker.services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names) self.__invoker.services.image_records.delete_many(image_names)
for image_name in image_names: for image_name in image_names:
self._on_deleted(image_name) self._on_deleted(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records") self.__invoker.services.logger.error("Failed to delete image records")
raise raise
except ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error("Failed to delete image files") self.__invoker.services.logger.error("Failed to delete image files")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self.__invoker.services.logger.error("Problem deleting image records and files")
raise e raise e
def delete_intermediates(self) -> int: def delete_intermediates(self) -> int:
try: try:
image_names = self._services.image_records.delete_intermediates() image_names = self.__invoker.services.image_records.delete_intermediates()
count = len(image_names) count = len(image_names)
for image_name in image_names: for image_name in image_names:
self._services.image_files.delete(image_name) self.__invoker.services.image_files.delete(image_name)
self._on_deleted(image_name) self._on_deleted(image_name)
return count return count
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records") self.__invoker.services.logger.error("Failed to delete image records")
raise raise
except ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error("Failed to delete image files") self.__invoker.services.logger.error("Failed to delete image files")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self.__invoker.services.logger.error("Problem deleting image records and files")
raise e raise e

View File

@ -6,11 +6,15 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from logging import Logger from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.board_record_storage import BoardRecordStorageBase
from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.images import ImageServiceABC from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_queue import InvocationQueueABC from invokeai.app.services.invocation_queue import InvocationQueueABC
@ -19,8 +23,10 @@ if TYPE_CHECKING:
from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.urls import UrlServiceBase
class InvocationServices: class InvocationServices:
@ -28,12 +34,16 @@ class InvocationServices:
# TODO: Just forward-declared everything due to circular dependencies. Fix structure. # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC" board_images: "BoardImagesServiceABC"
board_image_record_storage: "BoardImageRecordStorageBase"
boards: "BoardServiceABC" boards: "BoardServiceABC"
board_records: "BoardRecordStorageBase"
configuration: "InvokeAIAppConfig" configuration: "InvokeAIAppConfig"
events: "EventServiceBase" events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC[GraphExecutionState]" graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]" graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC" images: "ImageServiceABC"
image_records: "ImageRecordStorageBase"
image_files: "ImageFileStorageBase"
latents: "LatentsStorageBase" latents: "LatentsStorageBase"
logger: "Logger" logger: "Logger"
model_manager: "ModelManagerServiceBase" model_manager: "ModelManagerServiceBase"
@ -43,16 +53,22 @@ class InvocationServices:
session_queue: "SessionQueueBase" session_queue: "SessionQueueBase"
session_processor: "SessionProcessorBase" session_processor: "SessionProcessorBase"
invocation_cache: "InvocationCacheBase" invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
def __init__( def __init__(
self, self,
board_images: "BoardImagesServiceABC", board_images: "BoardImagesServiceABC",
board_image_records: "BoardImageRecordStorageBase",
boards: "BoardServiceABC", boards: "BoardServiceABC",
board_records: "BoardRecordStorageBase",
configuration: "InvokeAIAppConfig", configuration: "InvokeAIAppConfig",
events: "EventServiceBase", events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]", graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]", graph_library: "ItemStorageABC[LibraryGraph]",
images: "ImageServiceABC", images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
latents: "LatentsStorageBase", latents: "LatentsStorageBase",
logger: "Logger", logger: "Logger",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
@ -62,14 +78,20 @@ class InvocationServices:
session_queue: "SessionQueueBase", session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase", session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase", invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
): ):
self.board_images = board_images self.board_images = board_images
self.board_image_records = board_image_records
self.boards = boards self.boards = boards
self.board_records = board_records
self.configuration = configuration self.configuration = configuration
self.events = events self.events = events
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library self.graph_library = graph_library
self.images = images self.images = images
self.image_files = image_files
self.image_records = image_records
self.latents = latents self.latents = latents
self.logger = logger self.logger = logger
self.model_manager = model_manager self.model_manager = model_manager
@ -79,3 +101,5 @@ class InvocationServices:
self.session_queue = session_queue self.session_queue = session_queue
self.session_processor = session_processor self.session_processor = session_processor
self.invocation_cache = invocation_cache self.invocation_cache = invocation_cache
self.names = names
self.urls = urls

View File

@ -38,12 +38,11 @@ import psutil
import torch import torch
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_management.model_cache import CacheStats from invokeai.backend.model_management.model_cache import CacheStats
from ..invocations.baseinvocation import BaseInvocation from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState from .model_manager_service import ModelManagerServiceBase
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
# size of GIG in bytes # size of GIG in bytes
GIG = 1073741824 GIG = 1073741824
@ -72,7 +71,6 @@ class NodeLog:
class InvocationStatsServiceBase(ABC): class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics" "Abstract base class for recording node memory/time performance statistics"
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog} # {graph_id => NodeLog}
_stats: Dict[str, NodeLog] _stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats] _cache_stats: Dict[str, CacheStats]
@ -80,10 +78,9 @@ class InvocationStatsServiceBase(ABC):
ram_changed: float ram_changed: float
@abstractmethod @abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): def __init__(self):
""" """
Initialize the InvocationStatsService and reset counters to zero Initialize the InvocationStatsService and reset counters to zero
:param graph_execution_manager: Graph execution manager for this session
""" """
pass pass
@ -158,14 +155,18 @@ class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node, """Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems""" as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): _invoker: Invoker
self.graph_execution_manager = graph_execution_manager
def __init__(self):
# {graph_id => NodeLog} # {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {} self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {} self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0 self.ram_used: float = 0.0
self.ram_changed: float = 0.0 self.ram_changed: float = 0.0
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
class StatsContext: class StatsContext:
"""Context manager for collecting statistics.""" """Context manager for collecting statistics."""
@ -174,13 +175,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_id: str graph_id: str
start_time: float start_time: float
ram_used: int ram_used: int
model_manager: ModelManagerService model_manager: ModelManagerServiceBase
def __init__( def __init__(
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
graph_id: str, graph_id: str,
model_manager: ModelManagerService, model_manager: ModelManagerServiceBase,
collector: "InvocationStatsServiceBase", collector: "InvocationStatsServiceBase",
): ):
"""Initialize statistics for this run.""" """Initialize statistics for this run."""
@ -217,12 +218,11 @@ class InvocationStatsService(InvocationStatsServiceBase):
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
graph_execution_state_id: str, graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext: ) -> StatsContext:
if not self._stats.get(graph_execution_state_id): # first time we're seeing this if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog() self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats() self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self) return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
def reset_all_stats(self): def reset_all_stats(self):
"""Zero all statistics""" """Zero all statistics"""
@ -261,7 +261,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
errored = set() errored = set()
for graph_id, node_log in self._stats.items(): for graph_id, node_log in self._stats.items():
try: try:
current_graph_state = self.graph_execution_manager.get(graph_id) current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception: except Exception:
errored.add(graph_id) errored.add(graph_id)
continue continue

View File

@ -8,7 +8,6 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ..models.exceptions import CanceledException from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker
@ -37,7 +36,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event): def __process(self, stop_event: Event):
try: try:
self.__threadLimit.acquire() self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
queue_item: Optional[InvocationQueueItem] = None queue_item: Optional[InvocationQueueItem] = None
while not stop_event.is_set(): while not stop_event.is_set():
@ -97,8 +95,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke # Invoke
try: try:
graph_id = graph_execution_state.id graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method, # use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things: # which handles a few things:
# - nodes that require a value, but get it only from a connection # - nodes that require a value, but get it only from a connection
@ -133,13 +130,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
source_node_id=source_node_id, source_node_id=source_node_id,
result=outputs.dict(), result=outputs.dict(),
) )
statistics.log_stats() self.__invoker.services.performance_statistics.log_stats()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except CanceledException: except CanceledException:
statistics.reset_stats(graph_execution_state.id) self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
pass pass
except Exception as e: except Exception as e:
@ -164,7 +161,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error_type=e.__class__.__name__, error_type=e.__class__.__name__,
error=error, error=error,
) )
statistics.reset_stats(graph_execution_state.id) self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
pass pass
# Check queue to see if this is canceled, and skip if so # Check queue to see if this is canceled, and skip if so

View File

@ -29,6 +29,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count, calc_session_count,
prepare_values_to_insert, prepare_values_to_insert,
) )
from invokeai.app.services.shared.db import SqliteDatabase
from invokeai.app.services.shared.models import CursorPaginatedResults from invokeai.app.services.shared.models import CursorPaginatedResults
@ -45,13 +46,11 @@ class SqliteSessionQueue(SessionQueueBase):
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: def __init__(self, db: SqliteDatabase) -> None:
super().__init__() super().__init__()
self.__conn = conn self.__lock = db.lock
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) self.__conn = db.conn
self.__conn.row_factory = sqlite3.Row
self.__cursor = self.__conn.cursor() self.__cursor = self.__conn.cursor()
self.__lock = lock
self._create_tables() self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:

View File

View File

@ -0,0 +1,46 @@
import sqlite3
import threading
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
class SqliteDatabase:
conn: sqlite3.Connection
lock: threading.Lock
_logger: Logger
_config: InvokeAIAppConfig
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
if self._config.use_memory_db:
location = ":memory:"
logger.info("Using in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
location = str(db_path)
self._logger.info(f"Using database at {location}")
self.conn = sqlite3.connect(location, check_same_thread=False)
self.lock = threading.Lock()
self.conn.row_factory = sqlite3.Row
if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;")
def clean(self) -> None:
try:
self.lock.acquire()
self.conn.execute("VACUUM;")
self.conn.commit()
self._logger.info("Cleaned database")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
raise e
finally:
self.lock.release()

View File

@ -4,6 +4,8 @@ from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_raw_as
from invokeai.app.services.shared.db import SqliteDatabase
from .item_storage import ItemStorageABC, PaginatedResults from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
@ -18,13 +20,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
_id_field: str _id_field: str
_lock: threading.Lock _lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"): def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
super().__init__() super().__init__()
self._lock = db.lock
self._conn = db.conn
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
self._lock = lock
self._conn = conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_table() self._create_table()

View File

@ -1,3 +0,0 @@
import threading
lock = threading.Lock()