diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 8889c70674..60f8c1b09d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -12,7 +12,7 @@ from invokeai.app.services.board_images import ( from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.image_record_storage import SqliteImageRecordStorage -from invokeai.app.services.images import ImageService +from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService @@ -106,13 +106,16 @@ class ApiDependencies: ) images = ImageService( - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - metadata=metadata, - url=urls, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, + services=ImageServiceDependencies( + board_image_record_storage=board_image_record_storage, + image_record_storage=image_record_storage, + image_file_storage=image_file_storage, + metadata=metadata, + url=urls, + logger=logger, + names=names, + graph_execution_manager=graph_execution_manager, + ) ) services = InvocationServices( diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py index 851e8502e1..2f1603be82 100644 --- a/invokeai/app/services/board_image_record_storage.py +++ b/invokeai/app/services/board_image_record_storage.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import sqlite3 import threading -from typing import cast +from typing import Union, cast from invokeai.app.services.board_record_storage import BoardRecord from invokeai.app.services.image_record_storage import OffsetPaginatedResults @@ -41,11 +41,11 @@ class BoardImageRecordStorageBase(ABC): pass @abstractmethod - def get_boards_for_image( + def get_board_for_image( self, - board_id: str, - ) -> OffsetPaginatedResults[BoardRecord]: - """Gets boards for an image.""" + image_name: str, + ) -> Union[str, None]: + """Gets an image's board id, if it has one.""" pass @abstractmethod @@ -134,7 +134,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): board_id: str, image_name: str, ) -> None: - """Adds an image to a board.""" try: self._lock.acquire() self._cursor.execute( @@ -156,7 +155,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): board_id: str, image_name: str, ) -> None: - """Removes an image from a board.""" try: self._lock.acquire() self._cursor.execute( @@ -179,7 +177,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): offset: int = 0, limit: int = 10, ) -> OffsetPaginatedResults[ImageRecord]: - """Gets images for a board.""" try: self._lock.acquire() self._cursor.execute( @@ -211,46 +208,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): items=images, offset=offset, limit=limit, total=count ) - def get_boards_for_image( + def get_board_for_image( self, - board_id: str, - offset: int = 0, - limit: int = 10, - ) -> OffsetPaginatedResults[BoardRecord]: - """Gets boards for an image.""" + image_name: str, + ) -> Union[str, None]: try: self._lock.acquire() self._cursor.execute( """--sql - SELECT boards.* + SELECT board_id FROM board_images - INNER JOIN boards ON board_images.board_id = boards.board_id - WHERE board_images.image_name = ? - ORDER BY board_images.updated_at DESC; + WHERE image_name = ?; """, - (board_id,), + (image_name,), ) - result = cast(list[sqlite3.Row], self._cursor.fetchall()) - boards = list(map(lambda r: BoardRecord(**r), result)) - - self._cursor.execute( - """--sql - SELECT COUNT(*) FROM boards WHERE 1=1; - """ - ) - count = cast(int, self._cursor.fetchone()[0]) - + result = self._cursor.fetchone() + if result is None: + return None + return cast(str, result[0]) except sqlite3.Error as e: self._conn.rollback() raise e finally: self._lock.release() - return OffsetPaginatedResults( - items=boards, offset=offset, limit=limit, total=count - ) def get_image_count_for_board(self, board_id: str) -> int: - """Gets the number of images for a board.""" try: self._lock.acquire() self._cursor.execute( diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py index df2af4bbcf..cf16993a7a 100644 --- a/invokeai/app/services/board_images.py +++ b/invokeai/app/services/board_images.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from logging import Logger +from typing import Union from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_record_storage import ( BoardRecord, @@ -45,11 +46,11 @@ class BoardImagesServiceABC(ABC): pass @abstractmethod - def get_boards_for_image( + def get_board_for_image( self, image_name: str, - ) -> OffsetPaginatedResults[BoardDTO]: - """Gets boards for an image.""" + ) -> Union[str, None]: + """Gets an image's board id, if it has one.""" pass @@ -110,6 +111,7 @@ class BoardImagesService(BoardImagesServiceABC): r, self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name, True), + board_id, ), image_records.items, ) @@ -121,38 +123,12 @@ class BoardImagesService(BoardImagesServiceABC): total=image_records.total, ) - def get_boards_for_image( + def get_board_for_image( self, image_name: str, - ) -> OffsetPaginatedResults[BoardDTO]: - board_records = self._services.board_image_records.get_boards_for_image( - image_name - ) - board_dtos = [] - - for r in board_records.items: - cover_image_url = ( - self._services.urls.get_image_url(r.cover_image_name, True) - if r.cover_image_name - else None - ) - image_count = self._services.board_image_records.get_image_count_for_board( - r.board_id - ) - board_dtos.append( - board_record_to_dto( - r, - cover_image_url, - image_count, - ) - ) - - return OffsetPaginatedResults[BoardDTO]( - items=board_dtos, - offset=board_records.offset, - limit=board_records.limit, - total=board_records.total, - ) + ) -> Union[str, None]: + board_id = self._services.board_image_records.get_board_for_image(image_name) + return board_id def board_record_to_dto( diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index aa27e38d17..5959116161 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -10,6 +10,7 @@ from invokeai.app.models.image import ( InvalidOriginException, ) from invokeai.app.models.metadata import ImageMetadata +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.image_record_storage import ( ImageRecordDeleteException, ImageRecordNotFoundException, @@ -114,8 +115,9 @@ class ImageServiceABC(ABC): class ImageServiceDependencies: """Service dependencies for the ImageService.""" - records: ImageRecordStorageBase - files: ImageFileStorageBase + image_records: ImageRecordStorageBase + image_files: ImageFileStorageBase + board_image_records: BoardImageRecordStorageBase metadata: MetadataServiceBase urls: UrlServiceBase logger: Logger @@ -126,14 +128,16 @@ class ImageServiceDependencies: self, image_record_storage: ImageRecordStorageBase, image_file_storage: ImageFileStorageBase, + board_image_record_storage: BoardImageRecordStorageBase, metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): - self.records = image_record_storage - self.files = image_file_storage + self.image_records = image_record_storage + self.image_files = image_file_storage + self.board_image_records = board_image_record_storage self.metadata = metadata self.urls = url self.logger = logger @@ -144,25 +148,8 @@ class ImageServiceDependencies: class ImageService(ImageServiceABC): _services: ImageServiceDependencies - def __init__( - self, - image_record_storage: ImageRecordStorageBase, - image_file_storage: ImageFileStorageBase, - metadata: MetadataServiceBase, - url: UrlServiceBase, - logger: Logger, - names: NameServiceBase, - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - ): - self._services = ImageServiceDependencies( - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - metadata=metadata, - url=url, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, - ) + def __init__(self, services: ImageServiceDependencies): + self._services = services def create( self, @@ -187,7 +174,7 @@ class ImageService(ImageServiceABC): try: # TODO: Consider using a transaction here to ensure consistency between storage and database - created_at = self._services.records.save( + self._services.image_records.save( # Non-nullable fields image_name=image_name, image_origin=image_origin, @@ -202,35 +189,15 @@ class ImageService(ImageServiceABC): metadata=metadata, ) - self._services.files.save( + self._services.image_files.save( image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_name) - thumbnail_url = self._services.urls.get_image_url(image_name, True) + image_dto = self.get_dto(image_name) - return ImageDTO( - # Non-nullable fields - image_name=image_name, - image_origin=image_origin, - image_category=image_category, - width=width, - height=height, - # Nullable fields - node_id=node_id, - session_id=session_id, - metadata=metadata, - # Meta fields - created_at=created_at, - updated_at=created_at, # this is always the same as the created_at at this time - deleted_at=None, - is_intermediate=is_intermediate, - # Extra non-nullable fields for DTO - image_url=image_url, - thumbnail_url=thumbnail_url, - ) + return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to save image record") raise @@ -247,7 +214,7 @@ class ImageService(ImageServiceABC): changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.records.update(image_name, changes) + self._services.image_records.update(image_name, changes) return self.get_dto(image_name) except ImageRecordSaveException: self._services.logger.error("Failed to update image record") @@ -258,7 +225,7 @@ class ImageService(ImageServiceABC): def get_pil_image(self, image_name: str) -> PILImageType: try: - return self._services.files.get(image_name) + return self._services.image_files.get(image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -268,7 +235,7 @@ class ImageService(ImageServiceABC): def get_record(self, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_name) + return self._services.image_records.get(image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -278,12 +245,13 @@ class ImageService(ImageServiceABC): def get_dto(self, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_name) + image_record = self._services.image_records.get(image_name) image_dto = image_record_to_dto( image_record, self._services.urls.get_image_url(image_name), self._services.urls.get_image_url(image_name, True), + self._services.board_image_records.get_board_for_image(image_name), ) return image_dto @@ -296,14 +264,14 @@ class ImageService(ImageServiceABC): def get_path(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.files.get_path(image_name, thumbnail) + return self._services.image_files.get_path(image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e def validate_path(self, path: str) -> bool: try: - return self._services.files.validate_path(path) + return self._services.image_files.validate_path(path) except Exception as e: self._services.logger.error("Problem validating image path") raise e @@ -324,7 +292,7 @@ class ImageService(ImageServiceABC): is_intermediate: Optional[bool] = None, ) -> OffsetPaginatedResults[ImageDTO]: try: - results = self._services.records.get_many( + results = self._services.image_records.get_many( offset, limit, image_origin, @@ -338,6 +306,9 @@ class ImageService(ImageServiceABC): r, self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name, True), + self._services.board_image_records.get_board_for_image( + r.image_name + ), ), results.items, ) @@ -355,8 +326,8 @@ class ImageService(ImageServiceABC): def delete(self, image_name: str): try: - self._services.files.delete(image_name) - self._services.records.delete(image_name) + self._services.image_files.delete(image_name) + self._services.image_records.delete(image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index d971d65916..cc02016cf9 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel): class ImageDTO(ImageRecord, ImageUrlsDTO): - """Deserialized image record, enriched for the frontend with URLs.""" + """Deserialized image record, enriched for the frontend.""" + board_id: Union[str, None] = Field( + description="The id of the board the image belongs to, if one exists." + ) + """The id of the board the image belongs to, if one exists.""" pass def image_record_to_dto( - image_record: ImageRecord, image_url: str, thumbnail_url: str + image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None] ) -> ImageDTO: """Converts an image record to an image DTO.""" return ImageDTO( **image_record.dict(), image_url=image_url, thumbnail_url=thumbnail_url, + board_id=board_id, )