diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9db35fb5c3..aa17bf08d7 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -1,19 +1,19 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -import sqlite3 from logging import Logger from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage -from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies +from invokeai.app.services.board_images import BoardImagesService from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage -from invokeai.app.services.boards import BoardService, BoardServiceDependencies +from invokeai.app.services.boards import BoardService from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.image_record_storage import SqliteImageRecordStorage -from invokeai.app.services.images import ImageService, ImageServiceDependencies +from invokeai.app.services.images import ImageService from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.db import SqliteDatabase from invokeai.app.services.urls import LocalUrlService from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -29,7 +29,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto from ..services.model_manager_service import ModelManagerService from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage -from ..services.thread import lock from .events import FastAPIEventService @@ -63,100 +62,64 @@ class ApiDependencies: logger.info(f"Root directory = {str(config.root_path)}") logger.debug(f"Internet connectivity is {config.internet_available}") - events = FastAPIEventService(event_handler_id) - output_folder = config.output_path - # TODO: build a file/path manager? - if config.use_memory_db: - db_location = ":memory:" - else: - db_path = config.db_path - db_path.parent.mkdir(parents=True, exist_ok=True) - db_location = str(db_path) + db = SqliteDatabase(config, logger) - logger.info(f"Using database at {db_location}") - db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution + configuration = config + logger = logger - if config.log_sql: - db_conn.set_trace_callback(print) - db_conn.execute("PRAGMA foreign_keys = ON;") - - graph_execution_manager = SqliteItemStorage[GraphExecutionState]( - conn=db_conn, table_name="graph_executions", lock=lock - ) - - urls = LocalUrlService() - image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock) - image_file_storage = DiskImageFileStorage(f"{output_folder}/images") - names = SimpleNameService() + board_image_records = SqliteBoardImageRecordStorage(db=db) + board_images = BoardImagesService() + board_records = SqliteBoardRecordStorage(db=db) + boards = BoardService() + events = FastAPIEventService(event_handler_id) + graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") + graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs") + image_files = DiskImageFileStorage(f"{output_folder}/images") + image_records = SqliteImageRecordStorage(db=db) + images = ImageService() + invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) - - board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock) - board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock) - - boards = BoardService( - services=BoardServiceDependencies( - board_image_record_storage=board_image_record_storage, - board_record_storage=board_record_storage, - image_record_storage=image_record_storage, - url=urls, - logger=logger, - ) - ) - - board_images = BoardImagesService( - services=BoardImagesServiceDependencies( - board_image_record_storage=board_image_record_storage, - board_record_storage=board_record_storage, - image_record_storage=image_record_storage, - url=urls, - logger=logger, - ) - ) - - images = ImageService( - services=ImageServiceDependencies( - board_image_record_storage=board_image_record_storage, - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - url=urls, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, - ) - ) + model_manager = ModelManagerService(config, logger) + names = SimpleNameService() + performance_statistics = InvocationStatsService() + processor = DefaultInvocationProcessor() + queue = MemoryInvocationQueue() + session_processor = DefaultSessionProcessor() + session_queue = SqliteSessionQueue(db=db) + urls = LocalUrlService() services = InvocationServices( - model_manager=ModelManagerService(config, logger), - events=events, - latents=latents, - images=images, - boards=boards, + board_image_records=board_image_records, board_images=board_images, - queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"), + board_records=board_records, + boards=boards, + configuration=configuration, + events=events, graph_execution_manager=graph_execution_manager, - processor=DefaultInvocationProcessor(), - configuration=config, - performance_statistics=InvocationStatsService(graph_execution_manager), + graph_library=graph_library, + image_files=image_files, + image_records=image_records, + images=images, + invocation_cache=invocation_cache, + latents=latents, logger=logger, - session_queue=SqliteSessionQueue(conn=db_conn, lock=lock), - session_processor=DefaultSessionProcessor(), - invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size), + model_manager=model_manager, + names=names, + performance_statistics=performance_statistics, + processor=processor, + queue=queue, + session_processor=session_processor, + session_queue=session_queue, + urls=urls, ) create_system_graphs(services.graph_library) ApiDependencies.invoker = Invoker(services) - try: - lock.acquire() - db_conn.execute("VACUUM;") - db_conn.commit() - logger.info("Cleaned database") - finally: - lock.release() + db.clean() @staticmethod def shutdown(): diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py index c4d06ec131..e8ec803992 100644 --- a/invokeai/app/services/board_image_record_storage.py +++ b/invokeai/app/services/board_image_record_storage.py @@ -5,6 +5,7 @@ from typing import Optional, cast from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record +from invokeai.app.services.shared.db import SqliteDatabase class BoardImageRecordStorageBase(ABC): @@ -57,13 +58,11 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: + def __init__(self, db: SqliteDatabase) -> None: super().__init__() - self._conn = conn - # Enable row factory to get rows as dictionaries (must be done before making the cursor!) - self._conn.row_factory = sqlite3.Row + self._lock = db.lock + self._conn = db.conn self._cursor = self._conn.cursor() - self._lock = lock try: self._lock.acquire() diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py index 788722ec37..1cbc026dc9 100644 --- a/invokeai/app/services/board_images.py +++ b/invokeai/app/services/board_images.py @@ -1,12 +1,9 @@ from abc import ABC, abstractmethod -from logging import Logger from typing import Optional -from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase -from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase -from invokeai.app.services.image_record_storage import ImageRecordStorageBase +from invokeai.app.services.board_record_storage import BoardRecord +from invokeai.app.services.invoker import Invoker from invokeai.app.services.models.board_record import BoardDTO -from invokeai.app.services.urls import UrlServiceBase class BoardImagesServiceABC(ABC): @@ -46,60 +43,36 @@ class BoardImagesServiceABC(ABC): pass -class BoardImagesServiceDependencies: - """Service dependencies for the BoardImagesService.""" - - board_image_records: BoardImageRecordStorageBase - board_records: BoardRecordStorageBase - image_records: ImageRecordStorageBase - urls: UrlServiceBase - logger: Logger - - def __init__( - self, - board_image_record_storage: BoardImageRecordStorageBase, - image_record_storage: ImageRecordStorageBase, - board_record_storage: BoardRecordStorageBase, - url: UrlServiceBase, - logger: Logger, - ): - self.board_image_records = board_image_record_storage - self.image_records = image_record_storage - self.board_records = board_record_storage - self.urls = url - self.logger = logger - - class BoardImagesService(BoardImagesServiceABC): - _services: BoardImagesServiceDependencies + __invoker: Invoker - def __init__(self, services: BoardImagesServiceDependencies): - self._services = services + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker def add_image_to_board( self, board_id: str, image_name: str, ) -> None: - self._services.board_image_records.add_image_to_board(board_id, image_name) + self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name) def remove_image_from_board( self, image_name: str, ) -> None: - self._services.board_image_records.remove_image_from_board(image_name) + self.__invoker.services.board_image_records.remove_image_from_board(image_name) def get_all_board_image_names_for_board( self, board_id: str, ) -> list[str]: - return self._services.board_image_records.get_all_board_image_names_for_board(board_id) + return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) def get_board_for_image( self, image_name: str, ) -> Optional[str]: - board_id = self._services.board_image_records.get_board_for_image(image_name) + board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name) return board_id diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py index c12a9c8eea..25d79a4214 100644 --- a/invokeai/app/services/board_record_storage.py +++ b/invokeai/app/services/board_record_storage.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Extra, Field from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record +from invokeai.app.services.shared.db import SqliteDatabase from invokeai.app.util.misc import uuid_string @@ -91,13 +92,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: + def __init__(self, db: SqliteDatabase) -> None: super().__init__() - self._conn = conn - # Enable row factory to get rows as dictionaries (must be done before making the cursor!) - self._conn.row_factory = sqlite3.Row + self._lock = db.lock + self._conn = db.conn self._cursor = self._conn.cursor() - self._lock = lock try: self._lock.acquire() diff --git a/invokeai/app/services/boards.py b/invokeai/app/services/boards.py index e7a516da65..36f9a3cf32 100644 --- a/invokeai/app/services/boards.py +++ b/invokeai/app/services/boards.py @@ -1,12 +1,10 @@ from abc import ABC, abstractmethod -from logging import Logger -from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_images import board_record_to_dto -from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase -from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults +from invokeai.app.services.board_record_storage import BoardChanges +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.invoker import Invoker from invokeai.app.services.models.board_record import BoardDTO -from invokeai.app.services.urls import UrlServiceBase class BoardServiceABC(ABC): @@ -62,51 +60,27 @@ class BoardServiceABC(ABC): pass -class BoardServiceDependencies: - """Service dependencies for the BoardService.""" - - board_image_records: BoardImageRecordStorageBase - board_records: BoardRecordStorageBase - image_records: ImageRecordStorageBase - urls: UrlServiceBase - logger: Logger - - def __init__( - self, - board_image_record_storage: BoardImageRecordStorageBase, - image_record_storage: ImageRecordStorageBase, - board_record_storage: BoardRecordStorageBase, - url: UrlServiceBase, - logger: Logger, - ): - self.board_image_records = board_image_record_storage - self.image_records = image_record_storage - self.board_records = board_record_storage - self.urls = url - self.logger = logger - - class BoardService(BoardServiceABC): - _services: BoardServiceDependencies + __invoker: Invoker - def __init__(self, services: BoardServiceDependencies): - self._services = services + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker def create( self, board_name: str, ) -> BoardDTO: - board_record = self._services.board_records.save(board_name) + board_record = self.__invoker.services.board_records.save(board_name) return board_record_to_dto(board_record, None, 0) def get_dto(self, board_id: str) -> BoardDTO: - board_record = self._services.board_records.get(board_id) - cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id) + board_record = self.__invoker.services.board_records.get(board_id) + cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board(board_id) + image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id) return board_record_to_dto(board_record, cover_image_name, image_count) def update( @@ -114,45 +88,45 @@ class BoardService(BoardServiceABC): board_id: str, changes: BoardChanges, ) -> BoardDTO: - board_record = self._services.board_records.update(board_id, changes) - cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id) + board_record = self.__invoker.services.board_records.update(board_id, changes) + cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board(board_id) + image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id) return board_record_to_dto(board_record, cover_image_name, image_count) def delete(self, board_id: str) -> None: - self._services.board_records.delete(board_id) + self.__invoker.services.board_records.delete(board_id) def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]: - board_records = self._services.board_records.get_many(offset, limit) + board_records = self.__invoker.services.board_records.get_many(offset, limit) board_dtos = [] for r in board_records.items: - cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id) + cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board(r.board_id) + image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)) def get_all(self) -> list[BoardDTO]: - board_records = self._services.board_records.get_all() + board_records = self.__invoker.services.board_records.get_all() board_dtos = [] for r in board_records: - cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id) + cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board(r.board_id) + image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) return board_dtos diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 21afcaf0bf..77f3f6216d 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -10,6 +10,7 @@ from pydantic.generics import GenericModel from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record +from invokeai.app.services.shared.db import SqliteDatabase T = TypeVar("T", bound=BaseModel) @@ -152,13 +153,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: + def __init__(self, db: SqliteDatabase) -> None: super().__init__() - self._conn = conn - # Enable row factory to get rows as dictionaries (must be done before making the cursor!) - self._conn.row_factory = sqlite3.Row + self._lock = db.lock + self._conn = db.conn self._cursor = self._conn.cursor() - self._lock = lock try: self._lock.acquire() @@ -204,6 +203,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) + if "workflow" not in columns: + self._cursor.execute( + """--sql + ALTER TABLE images + ADD COLUMN workflow_id TEXT; + -- TODO: This requires a migration: + -- FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE SET NULL; + """ + ) + # Create the `images` table indices. self._cursor.execute( """--sql diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 08d5093a70..97fdb89118 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from logging import Logger -from typing import TYPE_CHECKING, Callable, Optional +from typing import Callable, Optional from PIL.Image import Image as PILImageType @@ -11,29 +10,21 @@ from invokeai.app.models.image import ( InvalidOriginException, ResourceOrigin, ) -from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.image_file_storage import ( ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException, - ImageFileStorageBase, ) from invokeai.app.services.image_record_storage import ( ImageRecordDeleteException, ImageRecordNotFoundException, ImageRecordSaveException, - ImageRecordStorageBase, OffsetPaginatedResults, ) -from invokeai.app.services.item_storage import ItemStorageABC +from invokeai.app.services.invoker import Invoker from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto -from invokeai.app.services.resource_name import NameServiceBase -from invokeai.app.services.urls import UrlServiceBase from invokeai.app.util.metadata import get_metadata_graph_from_raw_session -if TYPE_CHECKING: - from invokeai.app.services.graph import GraphExecutionState - class ImageServiceABC(ABC): """High-level service for image management.""" @@ -150,42 +141,11 @@ class ImageServiceABC(ABC): pass -class ImageServiceDependencies: - """Service dependencies for the ImageService.""" - - image_records: ImageRecordStorageBase - image_files: ImageFileStorageBase - board_image_records: BoardImageRecordStorageBase - urls: UrlServiceBase - logger: Logger - names: NameServiceBase - graph_execution_manager: ItemStorageABC["GraphExecutionState"] - - def __init__( - self, - image_record_storage: ImageRecordStorageBase, - image_file_storage: ImageFileStorageBase, - board_image_record_storage: BoardImageRecordStorageBase, - url: UrlServiceBase, - logger: Logger, - names: NameServiceBase, - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - ): - self.image_records = image_record_storage - self.image_files = image_file_storage - self.board_image_records = board_image_record_storage - self.urls = url - self.logger = logger - self.names = names - self.graph_execution_manager = graph_execution_manager - - class ImageService(ImageServiceABC): - _services: ImageServiceDependencies + __invoker: Invoker - def __init__(self, services: ImageServiceDependencies): - super().__init__() - self._services = services + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker def create( self, @@ -205,24 +165,13 @@ class ImageService(ImageServiceABC): if image_category not in ImageCategory: raise InvalidImageCategoryException - image_name = self._services.names.create_image_name() - - # TODO: Do we want to store the graph in the image at all? I don't think so... - # graph = None - # if session_id is not None: - # session_raw = self._services.graph_execution_manager.get_raw(session_id) - # if session_raw is not None: - # try: - # graph = get_metadata_graph_from_raw_session(session_raw) - # except Exception as e: - # self._services.logger.warn(f"Failed to parse session graph: {e}") - # graph = None + image_name = self.__invoker.services.names.create_image_name() (width, height) = image.size try: # TODO: Consider using a transaction here to ensure consistency between storage and database - self._services.image_records.save( + self.__invoker.services.image_records.save( # Non-nullable fields image_name=image_name, image_origin=image_origin, @@ -237,20 +186,22 @@ class ImageService(ImageServiceABC): session_id=session_id, ) if board_id is not None: - self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) - self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow) + self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) + self.__invoker.services.image_files.save( + image_name=image_name, image=image, metadata=metadata, workflow=workflow + ) image_dto = self.get_dto(image_name) self._on_changed(image_dto) return image_dto except ImageRecordSaveException: - self._services.logger.error("Failed to save image record") + self.__invoker.services.logger.error("Failed to save image record") raise except ImageFileSaveException: - self._services.logger.error("Failed to save image file") + self.__invoker.services.logger.error("Failed to save image file") raise except Exception as e: - self._services.logger.error(f"Problem saving image record and file: {str(e)}") + self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}") raise e def update( @@ -259,101 +210,101 @@ class ImageService(ImageServiceABC): changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.image_records.update(image_name, changes) + self.__invoker.services.image_records.update(image_name, changes) image_dto = self.get_dto(image_name) self._on_changed(image_dto) return image_dto except ImageRecordSaveException: - self._services.logger.error("Failed to update image record") + self.__invoker.services.logger.error("Failed to update image record") raise except Exception as e: - self._services.logger.error("Problem updating image record") + self.__invoker.services.logger.error("Problem updating image record") raise e def get_pil_image(self, image_name: str) -> PILImageType: try: - return self._services.image_files.get(image_name) + return self.__invoker.services.image_files.get(image_name) except ImageFileNotFoundException: - self._services.logger.error("Failed to get image file") + self.__invoker.services.logger.error("Failed to get image file") raise except Exception as e: - self._services.logger.error("Problem getting image file") + self.__invoker.services.logger.error("Problem getting image file") raise e def get_record(self, image_name: str) -> ImageRecord: try: - return self._services.image_records.get(image_name) + return self.__invoker.services.image_records.get(image_name) except ImageRecordNotFoundException: - self._services.logger.error("Image record not found") + self.__invoker.services.logger.error("Image record not found") raise except Exception as e: - self._services.logger.error("Problem getting image record") + self.__invoker.services.logger.error("Problem getting image record") raise e def get_dto(self, image_name: str) -> ImageDTO: try: - image_record = self._services.image_records.get(image_name) + image_record = self.__invoker.services.image_records.get(image_name) image_dto = image_record_to_dto( image_record, - self._services.urls.get_image_url(image_name), - self._services.urls.get_image_url(image_name, True), - self._services.board_image_records.get_board_for_image(image_name), + self.__invoker.services.urls.get_image_url(image_name), + self.__invoker.services.urls.get_image_url(image_name, True), + self.__invoker.services.board_image_records.get_board_for_image(image_name), ) return image_dto except ImageRecordNotFoundException: - self._services.logger.error("Image record not found") + self.__invoker.services.logger.error("Image record not found") raise except Exception as e: - self._services.logger.error("Problem getting image DTO") + self.__invoker.services.logger.error("Problem getting image DTO") raise e def get_metadata(self, image_name: str) -> Optional[ImageMetadata]: try: - image_record = self._services.image_records.get(image_name) - metadata = self._services.image_records.get_metadata(image_name) + image_record = self.__invoker.services.image_records.get(image_name) + metadata = self.__invoker.services.image_records.get_metadata(image_name) if not image_record.session_id: return ImageMetadata(metadata=metadata) - session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id) + session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id) graph = None if session_raw: try: graph = get_metadata_graph_from_raw_session(session_raw) except Exception as e: - self._services.logger.warn(f"Failed to parse session graph: {e}") + self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}") graph = None return ImageMetadata(graph=graph, metadata=metadata) except ImageRecordNotFoundException: - self._services.logger.error("Image record not found") + self.__invoker.services.logger.error("Image record not found") raise except Exception as e: - self._services.logger.error("Problem getting image DTO") + self.__invoker.services.logger.error("Problem getting image DTO") raise e def get_path(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.image_files.get_path(image_name, thumbnail) + return self.__invoker.services.image_files.get_path(image_name, thumbnail) except Exception as e: - self._services.logger.error("Problem getting image path") + self.__invoker.services.logger.error("Problem getting image path") raise e def validate_path(self, path: str) -> bool: try: - return self._services.image_files.validate_path(path) + return self.__invoker.services.image_files.validate_path(path) except Exception as e: - self._services.logger.error("Problem validating image path") + self.__invoker.services.logger.error("Problem validating image path") raise e def get_url(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.urls.get_image_url(image_name, thumbnail) + return self.__invoker.services.urls.get_image_url(image_name, thumbnail) except Exception as e: - self._services.logger.error("Problem getting image path") + self.__invoker.services.logger.error("Problem getting image path") raise e def get_many( @@ -366,7 +317,7 @@ class ImageService(ImageServiceABC): board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: try: - results = self._services.image_records.get_many( + results = self.__invoker.services.image_records.get_many( offset, limit, image_origin, @@ -379,9 +330,9 @@ class ImageService(ImageServiceABC): map( lambda r: image_record_to_dto( r, - self._services.urls.get_image_url(r.image_name), - self._services.urls.get_image_url(r.image_name, True), - self._services.board_image_records.get_board_for_image(r.image_name), + self.__invoker.services.urls.get_image_url(r.image_name), + self.__invoker.services.urls.get_image_url(r.image_name, True), + self.__invoker.services.board_image_records.get_board_for_image(r.image_name), ), results.items, ) @@ -394,56 +345,56 @@ class ImageService(ImageServiceABC): total=results.total, ) except Exception as e: - self._services.logger.error("Problem getting paginated image DTOs") + self.__invoker.services.logger.error("Problem getting paginated image DTOs") raise e def delete(self, image_name: str): try: - self._services.image_files.delete(image_name) - self._services.image_records.delete(image_name) + self.__invoker.services.image_files.delete(image_name) + self.__invoker.services.image_records.delete(image_name) self._on_deleted(image_name) except ImageRecordDeleteException: - self._services.logger.error("Failed to delete image record") + self.__invoker.services.logger.error("Failed to delete image record") raise except ImageFileDeleteException: - self._services.logger.error("Failed to delete image file") + self.__invoker.services.logger.error("Failed to delete image file") raise except Exception as e: - self._services.logger.error("Problem deleting image record and file") + self.__invoker.services.logger.error("Problem deleting image record and file") raise e def delete_images_on_board(self, board_id: str): try: - image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id) + image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) for image_name in image_names: - self._services.image_files.delete(image_name) - self._services.image_records.delete_many(image_names) + self.__invoker.services.image_files.delete(image_name) + self.__invoker.services.image_records.delete_many(image_names) for image_name in image_names: self._on_deleted(image_name) except ImageRecordDeleteException: - self._services.logger.error("Failed to delete image records") + self.__invoker.services.logger.error("Failed to delete image records") raise except ImageFileDeleteException: - self._services.logger.error("Failed to delete image files") + self.__invoker.services.logger.error("Failed to delete image files") raise except Exception as e: - self._services.logger.error("Problem deleting image records and files") + self.__invoker.services.logger.error("Problem deleting image records and files") raise e def delete_intermediates(self) -> int: try: - image_names = self._services.image_records.delete_intermediates() + image_names = self.__invoker.services.image_records.delete_intermediates() count = len(image_names) for image_name in image_names: - self._services.image_files.delete(image_name) + self.__invoker.services.image_files.delete(image_name) self._on_deleted(image_name) return count except ImageRecordDeleteException: - self._services.logger.error("Failed to delete image records") + self.__invoker.services.logger.error("Failed to delete image records") raise except ImageFileDeleteException: - self._services.logger.error("Failed to delete image files") + self.__invoker.services.logger.error("Failed to delete image files") raise except Exception as e: - self._services.logger.error("Problem deleting image records and files") + self.__invoker.services.logger.error("Problem deleting image records and files") raise e diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index e496ff80f2..09a5df0cd9 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -6,11 +6,15 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from logging import Logger + from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_images import BoardImagesServiceABC + from invokeai.app.services.board_record_storage import BoardRecordStorageBase from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import GraphExecutionState, LibraryGraph + from invokeai.app.services.image_file_storage import ImageFileStorageBase + from invokeai.app.services.image_record_storage import ImageRecordStorageBase from invokeai.app.services.images import ImageServiceABC from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase from invokeai.app.services.invocation_queue import InvocationQueueABC @@ -19,8 +23,10 @@ if TYPE_CHECKING: from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.model_manager_service import ModelManagerServiceBase + from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase + from invokeai.app.services.urls import UrlServiceBase class InvocationServices: @@ -28,12 +34,16 @@ class InvocationServices: # TODO: Just forward-declared everything due to circular dependencies. Fix structure. board_images: "BoardImagesServiceABC" + board_image_record_storage: "BoardImageRecordStorageBase" boards: "BoardServiceABC" + board_records: "BoardRecordStorageBase" configuration: "InvokeAIAppConfig" events: "EventServiceBase" graph_execution_manager: "ItemStorageABC[GraphExecutionState]" graph_library: "ItemStorageABC[LibraryGraph]" images: "ImageServiceABC" + image_records: "ImageRecordStorageBase" + image_files: "ImageFileStorageBase" latents: "LatentsStorageBase" logger: "Logger" model_manager: "ModelManagerServiceBase" @@ -43,16 +53,22 @@ class InvocationServices: session_queue: "SessionQueueBase" session_processor: "SessionProcessorBase" invocation_cache: "InvocationCacheBase" + names: "NameServiceBase" + urls: "UrlServiceBase" def __init__( self, board_images: "BoardImagesServiceABC", + board_image_records: "BoardImageRecordStorageBase", boards: "BoardServiceABC", + board_records: "BoardRecordStorageBase", configuration: "InvokeAIAppConfig", events: "EventServiceBase", graph_execution_manager: "ItemStorageABC[GraphExecutionState]", graph_library: "ItemStorageABC[LibraryGraph]", images: "ImageServiceABC", + image_files: "ImageFileStorageBase", + image_records: "ImageRecordStorageBase", latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", @@ -62,14 +78,20 @@ class InvocationServices: session_queue: "SessionQueueBase", session_processor: "SessionProcessorBase", invocation_cache: "InvocationCacheBase", + names: "NameServiceBase", + urls: "UrlServiceBase", ): self.board_images = board_images + self.board_image_records = board_image_records self.boards = boards + self.board_records = board_records self.configuration = configuration self.events = events self.graph_execution_manager = graph_execution_manager self.graph_library = graph_library self.images = images + self.image_files = image_files + self.image_records = image_records self.latents = latents self.logger = logger self.model_manager = model_manager @@ -79,3 +101,5 @@ class InvocationServices: self.session_queue = session_queue self.session_processor = session_processor self.invocation_cache = invocation_cache + self.names = names + self.urls = urls diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 33932f73aa..6799031eff 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -38,12 +38,11 @@ import psutil import torch import invokeai.backend.util.logging as logger +from invokeai.app.services.invoker import Invoker from invokeai.backend.model_management.model_cache import CacheStats from ..invocations.baseinvocation import BaseInvocation -from .graph import GraphExecutionState -from .item_storage import ItemStorageABC -from .model_manager_service import ModelManagerService +from .model_manager_service import ModelManagerServiceBase # size of GIG in bytes GIG = 1073741824 @@ -72,7 +71,6 @@ class NodeLog: class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" - graph_execution_manager: ItemStorageABC["GraphExecutionState"] # {graph_id => NodeLog} _stats: Dict[str, NodeLog] _cache_stats: Dict[str, CacheStats] @@ -80,10 +78,9 @@ class InvocationStatsServiceBase(ABC): ram_changed: float @abstractmethod - def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): + def __init__(self): """ Initialize the InvocationStatsService and reset counters to zero - :param graph_execution_manager: Graph execution manager for this session """ pass @@ -158,14 +155,18 @@ class InvocationStatsService(InvocationStatsServiceBase): """Accumulate performance information about a running graph. Collects time spent in each node, as well as the maximum and current VRAM utilisation for CUDA systems""" - def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): - self.graph_execution_manager = graph_execution_manager + _invoker: Invoker + + def __init__(self): # {graph_id => NodeLog} self._stats: Dict[str, NodeLog] = {} self._cache_stats: Dict[str, CacheStats] = {} self.ram_used: float = 0.0 self.ram_changed: float = 0.0 + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + class StatsContext: """Context manager for collecting statistics.""" @@ -174,13 +175,13 @@ class InvocationStatsService(InvocationStatsServiceBase): graph_id: str start_time: float ram_used: int - model_manager: ModelManagerService + model_manager: ModelManagerServiceBase def __init__( self, invocation: BaseInvocation, graph_id: str, - model_manager: ModelManagerService, + model_manager: ModelManagerServiceBase, collector: "InvocationStatsServiceBase", ): """Initialize statistics for this run.""" @@ -217,12 +218,11 @@ class InvocationStatsService(InvocationStatsServiceBase): self, invocation: BaseInvocation, graph_execution_state_id: str, - model_manager: ModelManagerService, ) -> StatsContext: if not self._stats.get(graph_execution_state_id): # first time we're seeing this self._stats[graph_execution_state_id] = NodeLog() self._cache_stats[graph_execution_state_id] = CacheStats() - return self.StatsContext(invocation, graph_execution_state_id, model_manager, self) + return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self) def reset_all_stats(self): """Zero all statistics""" @@ -261,7 +261,7 @@ class InvocationStatsService(InvocationStatsServiceBase): errored = set() for graph_id, node_log in self._stats.items(): try: - current_graph_state = self.graph_execution_manager.get(graph_id) + current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id) except Exception: errored.add(graph_id) continue diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index b4c894c52d..226920bdaf 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -8,7 +8,6 @@ import invokeai.backend.util.logging as logger from ..invocations.baseinvocation import InvocationContext from ..models.exceptions import CanceledException from .invocation_queue import InvocationQueueItem -from .invocation_stats import InvocationStatsServiceBase from .invoker import InvocationProcessorABC, Invoker @@ -37,7 +36,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: self.__threadLimit.acquire() - statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics queue_item: Optional[InvocationQueueItem] = None while not stop_event.is_set(): @@ -97,8 +95,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: graph_id = graph_execution_state.id - model_manager = self.__invoker.services.model_manager - with statistics.collect_stats(invocation, graph_id, model_manager): + with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id): # use the internal invoke_internal(), which wraps the node's invoke() method, # which handles a few things: # - nodes that require a value, but get it only from a connection @@ -133,13 +130,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC): source_node_id=source_node_id, result=outputs.dict(), ) - statistics.log_stats() + self.__invoker.services.performance_statistics.log_stats() except KeyboardInterrupt: pass except CanceledException: - statistics.reset_stats(graph_execution_state.id) + self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id) pass except Exception as e: @@ -164,7 +161,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error_type=e.__class__.__name__, error=error, ) - statistics.reset_stats(graph_execution_state.id) + self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id) pass # Check queue to see if this is canceled, and skip if so diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index f995576311..674593b550 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -29,6 +29,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( calc_session_count, prepare_values_to_insert, ) +from invokeai.app.services.shared.db import SqliteDatabase from invokeai.app.services.shared.models import CursorPaginatedResults @@ -45,13 +46,11 @@ class SqliteSessionQueue(SessionQueueBase): local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") - def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None: + def __init__(self, db: SqliteDatabase) -> None: super().__init__() - self.__conn = conn - # Enable row factory to get rows as dictionaries (must be done before making the cursor!) - self.__conn.row_factory = sqlite3.Row + self.__lock = db.lock + self.__conn = db.conn self.__cursor = self.__conn.cursor() - self.__lock = lock self._create_tables() def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: diff --git a/invokeai/app/services/shared/__init__.py b/invokeai/app/services/shared/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/shared/db.py b/invokeai/app/services/shared/db.py new file mode 100644 index 0000000000..6b3b86f25f --- /dev/null +++ b/invokeai/app/services/shared/db.py @@ -0,0 +1,46 @@ +import sqlite3 +import threading +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig + + +class SqliteDatabase: + conn: sqlite3.Connection + lock: threading.Lock + _logger: Logger + _config: InvokeAIAppConfig + + def __init__(self, config: InvokeAIAppConfig, logger: Logger): + self._logger = logger + self._config = config + + if self._config.use_memory_db: + location = ":memory:" + logger.info("Using in-memory database") + else: + db_path = self._config.db_path + db_path.parent.mkdir(parents=True, exist_ok=True) + location = str(db_path) + self._logger.info(f"Using database at {location}") + + self.conn = sqlite3.connect(location, check_same_thread=False) + self.lock = threading.Lock() + self.conn.row_factory = sqlite3.Row + + if self._config.log_sql: + self.conn.set_trace_callback(self._logger.debug) + + self.conn.execute("PRAGMA foreign_keys = ON;") + + def clean(self) -> None: + try: + self.lock.acquire() + self.conn.execute("VACUUM;") + self.conn.commit() + self._logger.info("Cleaned database") + except Exception as e: + self._logger.error(f"Error cleaning database: {e}") + raise e + finally: + self.lock.release() diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 63f3356b3c..989fa5132e 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -4,6 +4,8 @@ from typing import Generic, Optional, TypeVar, get_args from pydantic import BaseModel, parse_raw_as +from invokeai.app.services.shared.db import SqliteDatabase + from .item_storage import ItemStorageABC, PaginatedResults T = TypeVar("T", bound=BaseModel) @@ -18,13 +20,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): _id_field: str _lock: threading.Lock - def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"): + def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"): super().__init__() + self._lock = db.lock + self._conn = db.conn self._table_name = table_name self._id_field = id_field # TODO: validate that T has this field - self._lock = lock - self._conn = conn self._cursor = self._conn.cursor() self._create_table() diff --git a/invokeai/app/services/thread.py b/invokeai/app/services/thread.py deleted file mode 100644 index 3fd88295b1..0000000000 --- a/invokeai/app/services/thread.py +++ /dev/null @@ -1,3 +0,0 @@ -import threading - -lock = threading.Lock()