mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): organise service dependencies
**Service Dependencies** Services that depend on other services now access those services via the `Invoker` object. This object is provided to the service as a kwarg to its `start()` method. Until now, most services did not utilize this feature, and several services required their dependencies to be initialized and passed in on init. Additionally, _all_ services are now registered as invocation services - including the low-level services. This obviates issues with inter-dependent services we would otherwise experience as we add workflow storage. **Database Access** Previously, we were passing in a separate sqlite connection and corresponding lock as args to services in their init. A good amount of posturing was done in each service that uses the db. These objects, along with the sqlite startup and cleanup logic, is now abstracted into a simple `SqliteDatabase` class. This creates the shared connection and lock objects, enables foreign keys, and provides a `clean()` method to do startup db maintenance. This is not a service as it's only used by sqlite services.
This commit is contained in:
parent
10fac5c085
commit
2a35d93a4d
@ -1,19 +1,19 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||||
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
from invokeai.app.services.board_images import BoardImagesService
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService
|
||||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
|
from invokeai.app.services.shared.db import SqliteDatabase
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
@ -29,7 +29,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
|
|||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.thread import lock
|
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -63,100 +62,64 @@ class ApiDependencies:
|
|||||||
logger.info(f"Root directory = {str(config.root_path)}")
|
logger.info(f"Root directory = {str(config.root_path)}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
db = SqliteDatabase(config, logger)
|
||||||
if config.use_memory_db:
|
|
||||||
db_location = ":memory:"
|
|
||||||
else:
|
|
||||||
db_path = config.db_path
|
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
db_location = str(db_path)
|
|
||||||
|
|
||||||
logger.info(f"Using database at {db_location}")
|
configuration = config
|
||||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
logger = logger
|
||||||
|
|
||||||
if config.log_sql:
|
board_image_records = SqliteBoardImageRecordStorage(db=db)
|
||||||
db_conn.set_trace_callback(print)
|
board_images = BoardImagesService()
|
||||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
board_records = SqliteBoardRecordStorage(db=db)
|
||||||
|
boards = BoardService()
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
events = FastAPIEventService(event_handler_id)
|
||||||
conn=db_conn, table_name="graph_executions", lock=lock
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||||
)
|
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||||
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
urls = LocalUrlService()
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
images = ImageService()
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
names = SimpleNameService()
|
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
model_manager = ModelManagerService(config, logger)
|
||||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
names = SimpleNameService()
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
performance_statistics = InvocationStatsService()
|
||||||
|
processor = DefaultInvocationProcessor()
|
||||||
boards = BoardService(
|
queue = MemoryInvocationQueue()
|
||||||
services=BoardServiceDependencies(
|
session_processor = DefaultSessionProcessor()
|
||||||
board_image_record_storage=board_image_record_storage,
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
board_record_storage=board_record_storage,
|
urls = LocalUrlService()
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
url=urls,
|
|
||||||
logger=logger,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
board_images = BoardImagesService(
|
|
||||||
services=BoardImagesServiceDependencies(
|
|
||||||
board_image_record_storage=board_image_record_storage,
|
|
||||||
board_record_storage=board_record_storage,
|
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
url=urls,
|
|
||||||
logger=logger,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
images = ImageService(
|
|
||||||
services=ImageServiceDependencies(
|
|
||||||
board_image_record_storage=board_image_record_storage,
|
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
image_file_storage=image_file_storage,
|
|
||||||
url=urls,
|
|
||||||
logger=logger,
|
|
||||||
names=names,
|
|
||||||
graph_execution_manager=graph_execution_manager,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config, logger),
|
board_image_records=board_image_records,
|
||||||
events=events,
|
|
||||||
latents=latents,
|
|
||||||
images=images,
|
|
||||||
boards=boards,
|
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
board_records=board_records,
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
boards=boards,
|
||||||
|
configuration=configuration,
|
||||||
|
events=events,
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
graph_library=graph_library,
|
||||||
configuration=config,
|
image_files=image_files,
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
image_records=image_records,
|
||||||
|
images=images,
|
||||||
|
invocation_cache=invocation_cache,
|
||||||
|
latents=latents,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
model_manager=model_manager,
|
||||||
session_processor=DefaultSessionProcessor(),
|
names=names,
|
||||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
performance_statistics=performance_statistics,
|
||||||
|
processor=processor,
|
||||||
|
queue=queue,
|
||||||
|
session_processor=session_processor,
|
||||||
|
session_queue=session_queue,
|
||||||
|
urls=urls,
|
||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
|
||||||
try:
|
db.clean()
|
||||||
lock.acquire()
|
|
||||||
db_conn.execute("VACUUM;")
|
|
||||||
db_conn.commit()
|
|
||||||
logger.info("Cleaned database")
|
|
||||||
finally:
|
|
||||||
lock.release()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shutdown():
|
def shutdown():
|
||||||
|
@ -5,6 +5,7 @@ from typing import Optional, cast
|
|||||||
|
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
|
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
|
||||||
|
from invokeai.app.services.shared.db import SqliteDatabase
|
||||||
|
|
||||||
|
|
||||||
class BoardImageRecordStorageBase(ABC):
|
class BoardImageRecordStorageBase(ABC):
|
||||||
@ -57,13 +58,11 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._conn = conn
|
self._lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self._conn = db.conn
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
from invokeai.app.services.board_record_storage import BoardRecord
|
||||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesServiceABC(ABC):
|
class BoardImagesServiceABC(ABC):
|
||||||
@ -46,60 +43,36 @@ class BoardImagesServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesServiceDependencies:
|
|
||||||
"""Service dependencies for the BoardImagesService."""
|
|
||||||
|
|
||||||
board_image_records: BoardImageRecordStorageBase
|
|
||||||
board_records: BoardRecordStorageBase
|
|
||||||
image_records: ImageRecordStorageBase
|
|
||||||
urls: UrlServiceBase
|
|
||||||
logger: Logger
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
board_record_storage: BoardRecordStorageBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
):
|
|
||||||
self.board_image_records = board_image_record_storage
|
|
||||||
self.image_records = image_record_storage
|
|
||||||
self.board_records = board_record_storage
|
|
||||||
self.urls = url
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesService(BoardImagesServiceABC):
|
class BoardImagesService(BoardImagesServiceABC):
|
||||||
_services: BoardImagesServiceDependencies
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, services: BoardImagesServiceDependencies):
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._services = services
|
self.__invoker = invoker
|
||||||
|
|
||||||
def add_image_to_board(
|
def add_image_to_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
board_id: str,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._services.board_image_records.remove_image_from_board(image_name)
|
self.__invoker.services.board_image_records.remove_image_from_board(image_name)
|
||||||
|
|
||||||
def get_all_board_image_names_for_board(
|
def get_all_board_image_names_for_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
board_id: str,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
|
|
||||||
def get_board_for_image(
|
def get_board_for_image(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
|
||||||
return board_id
|
return board_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
|
|||||||
|
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
||||||
|
from invokeai.app.services.shared.db import SqliteDatabase
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
@ -91,13 +92,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._conn = conn
|
self._lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self._conn = db.conn
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
|
||||||
from invokeai.app.services.board_images import board_record_to_dto
|
from invokeai.app.services.board_images import board_record_to_dto
|
||||||
from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase
|
from invokeai.app.services.board_record_storage import BoardChanges
|
||||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class BoardServiceABC(ABC):
|
class BoardServiceABC(ABC):
|
||||||
@ -62,51 +60,27 @@ class BoardServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BoardServiceDependencies:
|
|
||||||
"""Service dependencies for the BoardService."""
|
|
||||||
|
|
||||||
board_image_records: BoardImageRecordStorageBase
|
|
||||||
board_records: BoardRecordStorageBase
|
|
||||||
image_records: ImageRecordStorageBase
|
|
||||||
urls: UrlServiceBase
|
|
||||||
logger: Logger
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
board_record_storage: BoardRecordStorageBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
):
|
|
||||||
self.board_image_records = board_image_record_storage
|
|
||||||
self.image_records = image_record_storage
|
|
||||||
self.board_records = board_record_storage
|
|
||||||
self.urls = url
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
|
|
||||||
class BoardService(BoardServiceABC):
|
class BoardService(BoardServiceABC):
|
||||||
_services: BoardServiceDependencies
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, services: BoardServiceDependencies):
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._services = services
|
self.__invoker = invoker
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
board_name: str,
|
board_name: str,
|
||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
board_record = self._services.board_records.save(board_name)
|
board_record = self.__invoker.services.board_records.save(board_name)
|
||||||
return board_record_to_dto(board_record, None, 0)
|
return board_record_to_dto(board_record, None, 0)
|
||||||
|
|
||||||
def get_dto(self, board_id: str) -> BoardDTO:
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
board_record = self._services.board_records.get(board_id)
|
board_record = self.__invoker.services.board_records.get(board_id)
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
@ -114,45 +88,45 @@ class BoardService(BoardServiceABC):
|
|||||||
board_id: str,
|
board_id: str,
|
||||||
changes: BoardChanges,
|
changes: BoardChanges,
|
||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
board_record = self._services.board_records.update(board_id, changes)
|
board_record = self.__invoker.services.board_records.update(board_id, changes)
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
def delete(self, board_id: str) -> None:
|
def delete(self, board_id: str) -> None:
|
||||||
self._services.board_records.delete(board_id)
|
self.__invoker.services.board_records.delete(board_id)
|
||||||
|
|
||||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
board_records = self._services.board_records.get_many(offset, limit)
|
board_records = self.__invoker.services.board_records.get_many(offset, limit)
|
||||||
board_dtos = []
|
board_dtos = []
|
||||||
for r in board_records.items:
|
for r in board_records.items:
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||||
|
|
||||||
def get_all(self) -> list[BoardDTO]:
|
def get_all(self) -> list[BoardDTO]:
|
||||||
board_records = self._services.board_records.get_all()
|
board_records = self.__invoker.services.board_records.get_all()
|
||||||
board_dtos = []
|
board_dtos = []
|
||||||
for r in board_records:
|
for r in board_records:
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
return board_dtos
|
return board_dtos
|
||||||
|
@ -10,6 +10,7 @@ from pydantic.generics import GenericModel
|
|||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
|
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
|
||||||
|
from invokeai.app.services.shared.db import SqliteDatabase
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
@ -152,13 +153,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._conn = conn
|
self._lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self._conn = db.conn
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@ -204,6 +203,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "workflow" not in columns:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
ALTER TABLE images
|
||||||
|
ADD COLUMN workflow_id TEXT;
|
||||||
|
-- TODO: This requires a migration:
|
||||||
|
-- FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE SET NULL;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Create the `images` table indices.
|
# Create the `images` table indices.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from typing import Callable, Optional
|
||||||
from typing import TYPE_CHECKING, Callable, Optional
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
@ -11,29 +10,21 @@ from invokeai.app.models.image import (
|
|||||||
InvalidOriginException,
|
InvalidOriginException,
|
||||||
ResourceOrigin,
|
ResourceOrigin,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
|
||||||
from invokeai.app.services.image_file_storage import (
|
from invokeai.app.services.image_file_storage import (
|
||||||
ImageFileDeleteException,
|
ImageFileDeleteException,
|
||||||
ImageFileNotFoundException,
|
ImageFileNotFoundException,
|
||||||
ImageFileSaveException,
|
ImageFileSaveException,
|
||||||
ImageFileStorageBase,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import (
|
||||||
ImageRecordDeleteException,
|
ImageRecordDeleteException,
|
||||||
ImageRecordNotFoundException,
|
ImageRecordNotFoundException,
|
||||||
ImageRecordSaveException,
|
ImageRecordSaveException,
|
||||||
ImageRecordStorageBase,
|
|
||||||
OffsetPaginatedResults,
|
OffsetPaginatedResults,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
|
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
|
||||||
from invokeai.app.services.resource_name import NameServiceBase
|
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
|
||||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.services.graph import GraphExecutionState
|
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceABC(ABC):
|
class ImageServiceABC(ABC):
|
||||||
"""High-level service for image management."""
|
"""High-level service for image management."""
|
||||||
@ -150,42 +141,11 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceDependencies:
|
|
||||||
"""Service dependencies for the ImageService."""
|
|
||||||
|
|
||||||
image_records: ImageRecordStorageBase
|
|
||||||
image_files: ImageFileStorageBase
|
|
||||||
board_image_records: BoardImageRecordStorageBase
|
|
||||||
urls: UrlServiceBase
|
|
||||||
logger: Logger
|
|
||||||
names: NameServiceBase
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
image_file_storage: ImageFileStorageBase,
|
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
names: NameServiceBase,
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
|
||||||
):
|
|
||||||
self.image_records = image_record_storage
|
|
||||||
self.image_files = image_file_storage
|
|
||||||
self.board_image_records = board_image_record_storage
|
|
||||||
self.urls = url
|
|
||||||
self.logger = logger
|
|
||||||
self.names = names
|
|
||||||
self.graph_execution_manager = graph_execution_manager
|
|
||||||
|
|
||||||
|
|
||||||
class ImageService(ImageServiceABC):
|
class ImageService(ImageServiceABC):
|
||||||
_services: ImageServiceDependencies
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, services: ImageServiceDependencies):
|
def start(self, invoker: Invoker) -> None:
|
||||||
super().__init__()
|
self.__invoker = invoker
|
||||||
self._services = services
|
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
@ -205,24 +165,13 @@ class ImageService(ImageServiceABC):
|
|||||||
if image_category not in ImageCategory:
|
if image_category not in ImageCategory:
|
||||||
raise InvalidImageCategoryException
|
raise InvalidImageCategoryException
|
||||||
|
|
||||||
image_name = self._services.names.create_image_name()
|
image_name = self.__invoker.services.names.create_image_name()
|
||||||
|
|
||||||
# TODO: Do we want to store the graph in the image at all? I don't think so...
|
|
||||||
# graph = None
|
|
||||||
# if session_id is not None:
|
|
||||||
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
|
||||||
# if session_raw is not None:
|
|
||||||
# try:
|
|
||||||
# graph = get_metadata_graph_from_raw_session(session_raw)
|
|
||||||
# except Exception as e:
|
|
||||||
# self._services.logger.warn(f"Failed to parse session graph: {e}")
|
|
||||||
# graph = None
|
|
||||||
|
|
||||||
(width, height) = image.size
|
(width, height) = image.size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
self._services.image_records.save(
|
self.__invoker.services.image_records.save(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_origin=image_origin,
|
image_origin=image_origin,
|
||||||
@ -237,20 +186,22 @@ class ImageService(ImageServiceABC):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
if board_id is not None:
|
if board_id is not None:
|
||||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
self.__invoker.services.image_files.save(
|
||||||
|
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||||
|
)
|
||||||
image_dto = self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
|
|
||||||
self._on_changed(image_dto)
|
self._on_changed(image_dto)
|
||||||
return image_dto
|
return image_dto
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to save image record")
|
self.__invoker.services.logger.error("Failed to save image record")
|
||||||
raise
|
raise
|
||||||
except ImageFileSaveException:
|
except ImageFileSaveException:
|
||||||
self._services.logger.error("Failed to save image file")
|
self.__invoker.services.logger.error("Failed to save image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
|
self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
@ -259,101 +210,101 @@ class ImageService(ImageServiceABC):
|
|||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
self._services.image_records.update(image_name, changes)
|
self.__invoker.services.image_records.update(image_name, changes)
|
||||||
image_dto = self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
self._on_changed(image_dto)
|
self._on_changed(image_dto)
|
||||||
return image_dto
|
return image_dto
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to update image record")
|
self.__invoker.services.logger.error("Failed to update image record")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem updating image record")
|
self.__invoker.services.logger.error("Problem updating image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
return self._services.image_files.get(image_name)
|
return self.__invoker.services.image_files.get(image_name)
|
||||||
except ImageFileNotFoundException:
|
except ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self.__invoker.services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image file")
|
self.__invoker.services.logger.error("Problem getting image file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_record(self, image_name: str) -> ImageRecord:
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
return self._services.image_records.get(image_name)
|
return self.__invoker.services.image_records.get(image_name)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image record")
|
self.__invoker.services.logger.error("Problem getting image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.image_records.get(image_name)
|
image_record = self.__invoker.services.image_records.get(image_name)
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
image_dto = image_record_to_dto(
|
||||||
image_record,
|
image_record,
|
||||||
self._services.urls.get_image_url(image_name),
|
self.__invoker.services.urls.get_image_url(image_name),
|
||||||
self._services.urls.get_image_url(image_name, True),
|
self.__invoker.services.urls.get_image_url(image_name, True),
|
||||||
self._services.board_image_records.get_board_for_image(image_name),
|
self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image DTO")
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.image_records.get(image_name)
|
image_record = self.__invoker.services.image_records.get(image_name)
|
||||||
metadata = self._services.image_records.get_metadata(image_name)
|
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
||||||
|
|
||||||
if not image_record.session_id:
|
if not image_record.session_id:
|
||||||
return ImageMetadata(metadata=metadata)
|
return ImageMetadata(metadata=metadata)
|
||||||
|
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
|
||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
if session_raw:
|
if session_raw:
|
||||||
try:
|
try:
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
return ImageMetadata(graph=graph, metadata=metadata)
|
return ImageMetadata(graph=graph, metadata=metadata)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image DTO")
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.image_files.get_path(image_name, thumbnail)
|
return self.__invoker.services.image_files.get_path(image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self.__invoker.services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
def validate_path(self, path: str) -> bool:
|
||||||
try:
|
try:
|
||||||
return self._services.image_files.validate_path(path)
|
return self.__invoker.services.image_files.validate_path(path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem validating image path")
|
self.__invoker.services.logger.error("Problem validating image path")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.urls.get_image_url(image_name, thumbnail)
|
return self.__invoker.services.urls.get_image_url(image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self.__invoker.services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
@ -366,7 +317,7 @@ class ImageService(ImageServiceABC):
|
|||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
try:
|
try:
|
||||||
results = self._services.image_records.get_many(
|
results = self.__invoker.services.image_records.get_many(
|
||||||
offset,
|
offset,
|
||||||
limit,
|
limit,
|
||||||
image_origin,
|
image_origin,
|
||||||
@ -379,9 +330,9 @@ class ImageService(ImageServiceABC):
|
|||||||
map(
|
map(
|
||||||
lambda r: image_record_to_dto(
|
lambda r: image_record_to_dto(
|
||||||
r,
|
r,
|
||||||
self._services.urls.get_image_url(r.image_name),
|
self.__invoker.services.urls.get_image_url(r.image_name),
|
||||||
self._services.urls.get_image_url(r.image_name, True),
|
self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||||
self._services.board_image_records.get_board_for_image(r.image_name),
|
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
)
|
)
|
||||||
@ -394,56 +345,56 @@ class ImageService(ImageServiceABC):
|
|||||||
total=results.total,
|
total=results.total,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting paginated image DTOs")
|
self.__invoker.services.logger.error("Problem getting paginated image DTOs")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete(self, image_name: str):
|
def delete(self, image_name: str):
|
||||||
try:
|
try:
|
||||||
self._services.image_files.delete(image_name)
|
self.__invoker.services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete(image_name)
|
self.__invoker.services.image_records.delete(image_name)
|
||||||
self._on_deleted(image_name)
|
self._on_deleted(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error("Failed to delete image record")
|
self.__invoker.services.logger.error("Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
except ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error("Failed to delete image file")
|
self.__invoker.services.logger.error("Failed to delete image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image record and file")
|
self.__invoker.services.logger.error("Problem deleting image record and file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete_images_on_board(self, board_id: str):
|
def delete_images_on_board(self, board_id: str):
|
||||||
try:
|
try:
|
||||||
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
self._services.image_files.delete(image_name)
|
self.__invoker.services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete_many(image_names)
|
self.__invoker.services.image_records.delete_many(image_names)
|
||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
self._on_deleted(image_name)
|
self._on_deleted(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error("Failed to delete image records")
|
self.__invoker.services.logger.error("Failed to delete image records")
|
||||||
raise
|
raise
|
||||||
except ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error("Failed to delete image files")
|
self.__invoker.services.logger.error("Failed to delete image files")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete_intermediates(self) -> int:
|
def delete_intermediates(self) -> int:
|
||||||
try:
|
try:
|
||||||
image_names = self._services.image_records.delete_intermediates()
|
image_names = self.__invoker.services.image_records.delete_intermediates()
|
||||||
count = len(image_names)
|
count = len(image_names)
|
||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
self._services.image_files.delete(image_name)
|
self.__invoker.services.image_files.delete(image_name)
|
||||||
self._on_deleted(image_name)
|
self._on_deleted(image_name)
|
||||||
return count
|
return count
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error("Failed to delete image records")
|
self.__invoker.services.logger.error("Failed to delete image records")
|
||||||
raise
|
raise
|
||||||
except ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error("Failed to delete image files")
|
self.__invoker.services.logger.error("Failed to delete image files")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||||
raise e
|
raise e
|
||||||
|
@ -6,11 +6,15 @@ from typing import TYPE_CHECKING
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
|
from invokeai.app.services.board_record_storage import BoardRecordStorageBase
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
|
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||||
|
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
@ -19,8 +23,10 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
|
from invokeai.app.services.resource_name import NameServiceBase
|
||||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
@ -28,12 +34,16 @@ class InvocationServices:
|
|||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
|
board_image_record_storage: "BoardImageRecordStorageBase"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
|
board_records: "BoardRecordStorageBase"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||||
images: "ImageServiceABC"
|
images: "ImageServiceABC"
|
||||||
|
image_records: "ImageRecordStorageBase"
|
||||||
|
image_files: "ImageFileStorageBase"
|
||||||
latents: "LatentsStorageBase"
|
latents: "LatentsStorageBase"
|
||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
@ -43,16 +53,22 @@ class InvocationServices:
|
|||||||
session_queue: "SessionQueueBase"
|
session_queue: "SessionQueueBase"
|
||||||
session_processor: "SessionProcessorBase"
|
session_processor: "SessionProcessorBase"
|
||||||
invocation_cache: "InvocationCacheBase"
|
invocation_cache: "InvocationCacheBase"
|
||||||
|
names: "NameServiceBase"
|
||||||
|
urls: "UrlServiceBase"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
|
board_image_records: "BoardImageRecordStorageBase",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
|
board_records: "BoardRecordStorageBase",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||||
images: "ImageServiceABC",
|
images: "ImageServiceABC",
|
||||||
|
image_files: "ImageFileStorageBase",
|
||||||
|
image_records: "ImageRecordStorageBase",
|
||||||
latents: "LatentsStorageBase",
|
latents: "LatentsStorageBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
@ -62,14 +78,20 @@ class InvocationServices:
|
|||||||
session_queue: "SessionQueueBase",
|
session_queue: "SessionQueueBase",
|
||||||
session_processor: "SessionProcessorBase",
|
session_processor: "SessionProcessorBase",
|
||||||
invocation_cache: "InvocationCacheBase",
|
invocation_cache: "InvocationCacheBase",
|
||||||
|
names: "NameServiceBase",
|
||||||
|
urls: "UrlServiceBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
|
self.board_image_records = board_image_records
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
self.board_records = board_records
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.events = events
|
self.events = events
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.images = images
|
self.images = images
|
||||||
|
self.image_files = image_files
|
||||||
|
self.image_records = image_records
|
||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
@ -79,3 +101,5 @@ class InvocationServices:
|
|||||||
self.session_queue = session_queue
|
self.session_queue = session_queue
|
||||||
self.session_processor = session_processor
|
self.session_processor = session_processor
|
||||||
self.invocation_cache = invocation_cache
|
self.invocation_cache = invocation_cache
|
||||||
|
self.names = names
|
||||||
|
self.urls = urls
|
||||||
|
@ -38,12 +38,11 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
from ..invocations.baseinvocation import BaseInvocation
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
from .graph import GraphExecutionState
|
from .model_manager_service import ModelManagerServiceBase
|
||||||
from .item_storage import ItemStorageABC
|
|
||||||
from .model_manager_service import ModelManagerService
|
|
||||||
|
|
||||||
# size of GIG in bytes
|
# size of GIG in bytes
|
||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
@ -72,7 +71,6 @@ class NodeLog:
|
|||||||
class InvocationStatsServiceBase(ABC):
|
class InvocationStatsServiceBase(ABC):
|
||||||
"Abstract base class for recording node memory/time performance statistics"
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
|
||||||
# {graph_id => NodeLog}
|
# {graph_id => NodeLog}
|
||||||
_stats: Dict[str, NodeLog]
|
_stats: Dict[str, NodeLog]
|
||||||
_cache_stats: Dict[str, CacheStats]
|
_cache_stats: Dict[str, CacheStats]
|
||||||
@ -80,10 +78,9 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
ram_changed: float
|
ram_changed: float
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
Initialize the InvocationStatsService and reset counters to zero
|
Initialize the InvocationStatsService and reset counters to zero
|
||||||
:param graph_execution_manager: Graph execution manager for this session
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -158,14 +155,18 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||||
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||||
|
|
||||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
_invoker: Invoker
|
||||||
self.graph_execution_manager = graph_execution_manager
|
|
||||||
|
def __init__(self):
|
||||||
# {graph_id => NodeLog}
|
# {graph_id => NodeLog}
|
||||||
self._stats: Dict[str, NodeLog] = {}
|
self._stats: Dict[str, NodeLog] = {}
|
||||||
self._cache_stats: Dict[str, CacheStats] = {}
|
self._cache_stats: Dict[str, CacheStats] = {}
|
||||||
self.ram_used: float = 0.0
|
self.ram_used: float = 0.0
|
||||||
self.ram_changed: float = 0.0
|
self.ram_changed: float = 0.0
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
|
||||||
class StatsContext:
|
class StatsContext:
|
||||||
"""Context manager for collecting statistics."""
|
"""Context manager for collecting statistics."""
|
||||||
|
|
||||||
@ -174,13 +175,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
start_time: float
|
start_time: float
|
||||||
ram_used: int
|
ram_used: int
|
||||||
model_manager: ModelManagerService
|
model_manager: ModelManagerServiceBase
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
model_manager: ModelManagerService,
|
model_manager: ModelManagerServiceBase,
|
||||||
collector: "InvocationStatsServiceBase",
|
collector: "InvocationStatsServiceBase",
|
||||||
):
|
):
|
||||||
"""Initialize statistics for this run."""
|
"""Initialize statistics for this run."""
|
||||||
@ -217,12 +218,11 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_manager: ModelManagerService,
|
|
||||||
) -> StatsContext:
|
) -> StatsContext:
|
||||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
self._stats[graph_execution_state_id] = NodeLog()
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
|
||||||
|
|
||||||
def reset_all_stats(self):
|
def reset_all_stats(self):
|
||||||
"""Zero all statistics"""
|
"""Zero all statistics"""
|
||||||
@ -261,7 +261,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
errored = set()
|
errored = set()
|
||||||
for graph_id, node_log in self._stats.items():
|
for graph_id, node_log in self._stats.items():
|
||||||
try:
|
try:
|
||||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
errored.add(graph_id)
|
errored.add(graph_id)
|
||||||
continue
|
continue
|
||||||
|
@ -8,7 +8,6 @@ import invokeai.backend.util.logging as logger
|
|||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from ..models.exceptions import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invocation_stats import InvocationStatsServiceBase
|
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +36,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
def __process(self, stop_event: Event):
|
def __process(self, stop_event: Event):
|
||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
|
||||||
queue_item: Optional[InvocationQueueItem] = None
|
queue_item: Optional[InvocationQueueItem] = None
|
||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
@ -97,8 +95,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
graph_id = graph_execution_state.id
|
graph_id = graph_execution_state.id
|
||||||
model_manager = self.__invoker.services.model_manager
|
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
||||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
|
||||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||||
# which handles a few things:
|
# which handles a few things:
|
||||||
# - nodes that require a value, but get it only from a connection
|
# - nodes that require a value, but get it only from a connection
|
||||||
@ -133,13 +130,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.dict(),
|
||||||
)
|
)
|
||||||
statistics.log_stats()
|
self.__invoker.services.performance_statistics.log_stats()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
statistics.reset_stats(graph_execution_state.id)
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -164,7 +161,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
statistics.reset_stats(graph_execution_state.id)
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
|
@ -29,6 +29,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
calc_session_count,
|
calc_session_count,
|
||||||
prepare_values_to_insert,
|
prepare_values_to_insert,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.shared.db import SqliteDatabase
|
||||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
|
|
||||||
|
|
||||||
@ -45,13 +46,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__conn = conn
|
self.__lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self.__conn = db.conn
|
||||||
self.__conn.row_factory = sqlite3.Row
|
|
||||||
self.__cursor = self.__conn.cursor()
|
self.__cursor = self.__conn.cursor()
|
||||||
self.__lock = lock
|
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
|
|
||||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||||
|
0
invokeai/app/services/shared/__init__.py
Normal file
0
invokeai/app/services/shared/__init__.py
Normal file
46
invokeai/app/services/shared/db.py
Normal file
46
invokeai/app/services/shared/db.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteDatabase:
|
||||||
|
conn: sqlite3.Connection
|
||||||
|
lock: threading.Lock
|
||||||
|
_logger: Logger
|
||||||
|
_config: InvokeAIAppConfig
|
||||||
|
|
||||||
|
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
||||||
|
self._logger = logger
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
if self._config.use_memory_db:
|
||||||
|
location = ":memory:"
|
||||||
|
logger.info("Using in-memory database")
|
||||||
|
else:
|
||||||
|
db_path = self._config.db_path
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
location = str(db_path)
|
||||||
|
self._logger.info(f"Using database at {location}")
|
||||||
|
|
||||||
|
self.conn = sqlite3.connect(location, check_same_thread=False)
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
if self._config.log_sql:
|
||||||
|
self.conn.set_trace_callback(self._logger.debug)
|
||||||
|
|
||||||
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
|
def clean(self) -> None:
|
||||||
|
try:
|
||||||
|
self.lock.acquire()
|
||||||
|
self.conn.execute("VACUUM;")
|
||||||
|
self.conn.commit()
|
||||||
|
self._logger.info("Cleaned database")
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error cleaning database: {e}")
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self.lock.release()
|
@ -4,6 +4,8 @@ from typing import Generic, Optional, TypeVar, get_args
|
|||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, parse_raw_as
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.db import SqliteDatabase
|
||||||
|
|
||||||
from .item_storage import ItemStorageABC, PaginatedResults
|
from .item_storage import ItemStorageABC, PaginatedResults
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
@ -18,13 +20,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
|
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self._lock = db.lock
|
||||||
|
self._conn = db.conn
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = lock
|
|
||||||
self._conn = conn
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
import threading
|
|
||||||
|
|
||||||
lock = threading.Lock()
|
|
Loading…
Reference in New Issue
Block a user