From 72e9ced88997bdfd84c06ea4127250284178b9bc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Jun 2023 00:07:20 +1000 Subject: [PATCH] feat(nodes): add boards and board_images services --- invokeai/app/api/dependencies.py | 39 +- invokeai/app/api/routers/boards.py | 145 ++++---- invokeai/app/api/routers/images.py | 4 - invokeai/app/api_app.py | 2 +- .../services/board_image_record_storage.py | 253 +++++++++++++ invokeai/app/services/board_images.py | 166 +++++++++ invokeai/app/services/board_record_storage.py | 331 +++++++++++++++++ invokeai/app/services/boards.py | 340 +++++++----------- invokeai/app/services/image_record_storage.py | 45 +-- invokeai/app/services/images.py | 6 +- invokeai/app/services/invocation_services.py | 19 +- invokeai/app/services/models/image_record.py | 11 - 12 files changed, 993 insertions(+), 368 deletions(-) create mode 100644 invokeai/app/services/board_image_record_storage.py create mode 100644 invokeai/app/services/board_images.py create mode 100644 invokeai/app/services/board_record_storage.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 8aa61d08aa..8889c70674 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,7 +2,15 @@ from logging import Logger import os -from invokeai.app.services import boards +from invokeai.app.services.board_image_record_storage import ( + SqliteBoardImageRecordStorage, +) +from invokeai.app.services.board_images import ( + BoardImagesService, + BoardImagesServiceDependencies, +) +from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage +from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService from invokeai.app.services.metadata import CoreMetadataService @@ -59,7 +67,7 @@ class ApiDependencies: # TODO: build a file/path manager? db_location = config.db_path - db_location.parent.mkdir(parents=True,exist_ok=True) + db_location.parent.mkdir(parents=True, exist_ok=True) graph_execution_manager = SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" @@ -73,7 +81,29 @@ class ApiDependencies: latents = ForwardCacheLatentsStorage( DiskLatentsStorage(f"{output_folder}/latents") ) - boards = SqliteBoardStorage(db_location) + + board_record_storage = SqliteBoardRecordStorage(db_location) + board_image_record_storage = SqliteBoardImageRecordStorage(db_location) + + boards = BoardService( + services=BoardServiceDependencies( + board_image_record_storage=board_image_record_storage, + board_record_storage=board_record_storage, + image_record_storage=image_record_storage, + url=urls, + logger=logger, + ) + ) + + board_images = BoardImagesService( + services=BoardImagesServiceDependencies( + board_image_record_storage=board_image_record_storage, + board_record_storage=board_record_storage, + image_record_storage=image_record_storage, + url=urls, + logger=logger, + ) + ) images = ImageService( image_record_storage=image_record_storage, @@ -90,6 +120,8 @@ class ApiDependencies: events=events, latents=latents, images=images, + boards=boards, + board_images=board_images, queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph]( filename=db_location, table_name="graphs" @@ -99,7 +131,6 @@ class ApiDependencies: restoration=RestorationServices(config, logger), configuration=config, logger=logger, - boards=boards ) create_system_graphs(services.graph_library) diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index c8e877ca59..f3a76e08d3 100644 --- a/invokeai/app/api/routers/boards.py +++ b/invokeai/app/api/routers/boards.py @@ -1,91 +1,86 @@ -from fastapi import Body, HTTPException, Path, Query -from fastapi.routing import APIRouter -from invokeai.app.services.boards import BoardRecord, BoardRecordChanges -from invokeai.app.services.image_record_storage import OffsetPaginatedResults +# from fastapi import Body, HTTPException, Path, Query +# from fastapi.routing import APIRouter +# from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges +# from invokeai.app.services.image_record_storage import OffsetPaginatedResults -from ..dependencies import ApiDependencies +# from ..dependencies import ApiDependencies -boards_router = APIRouter(prefix="/v1/boards", tags=["boards"]) +# boards_router = APIRouter(prefix="/v1/boards", tags=["boards"]) -@boards_router.post( - "/", - operation_id="create_board", - responses={ - 201: {"description": "The board was created successfully"}, - }, - status_code=201, -) -async def create_board( - board_name: str = Body(description="The name of the board to create"), -): - """Creates a board""" - try: - result = ApiDependencies.invoker.services.boards.save(board_name=board_name) - return result - except Exception as e: - raise HTTPException(status_code=500, detail="Failed to create board") +# @boards_router.post( +# "/", +# operation_id="create_board", +# responses={ +# 201: {"description": "The board was created successfully"}, +# }, +# status_code=201, +# ) +# async def create_board( +# board_name: str = Body(description="The name of the board to create"), +# ): +# """Creates a board""" +# try: +# result = ApiDependencies.invoker.services.boards.save(board_name=board_name) +# return result +# except Exception as e: +# raise HTTPException(status_code=500, detail="Failed to create board") -@boards_router.delete("/{board_id}", operation_id="delete_board") -async def delete_board( - board_id: str = Path(description="The id of board to delete"), -) -> None: - """Deletes a board""" +# @boards_router.delete("/{board_id}", operation_id="delete_board") +# async def delete_board( +# board_id: str = Path(description="The id of board to delete"), +# ) -> None: +# """Deletes a board""" - try: - ApiDependencies.invoker.services.boards.delete(board_id=board_id) - except Exception as e: - # TODO: Does this need any exception handling at all? - pass +# try: +# ApiDependencies.invoker.services.boards.delete(board_id=board_id) +# except Exception as e: +# # TODO: Does this need any exception handling at all? +# pass -@boards_router.get( - "/", - operation_id="list_boards", - response_model=OffsetPaginatedResults[BoardRecord], -) -async def list_boards( - offset: int = Query(default=0, description="The page offset"), - limit: int = Query(default=10, description="The number of boards per page"), -) -> OffsetPaginatedResults[BoardRecord]: - """Gets a list of boards""" +# @boards_router.get( +# "/", +# operation_id="list_boards", +# response_model=OffsetPaginatedResults[BoardRecord], +# ) +# async def list_boards( +# offset: int = Query(default=0, description="The page offset"), +# limit: int = Query(default=10, description="The number of boards per page"), +# ) -> OffsetPaginatedResults[BoardRecord]: +# """Gets a list of boards""" - results = ApiDependencies.invoker.services.boards.get_many( - offset, - limit, - ) +# results = ApiDependencies.invoker.services.boards.get_many( +# offset, +# limit, +# ) - boards = list( - map( - lambda r: board_record_to_dto( - r, - generate_cover_photo_url(r.id) - ), - results.boards, - ) - ) +# boards = list( +# map( +# lambda r: board_record_to_dto( +# r, +# generate_cover_photo_url(r.id) +# ), +# results.boards, +# ) +# ) - return boards +# return boards -class BoardDTO(BaseModel): - """A DTO for an image""" - id: str - name: str - cover_image_url: str -def board_record_to_dto( - board_record: BoardRecord, cover_image_url: str -) -> BoardDTO: - """Converts an image record to an image DTO.""" - return BoardDTO( - **board_record.dict(), - cover_image_url=cover_image_url, - ) +# def board_record_to_dto( +# board_record: BoardRecord, cover_image_url: str +# ) -> BoardDTO: +# """Converts an image record to an image DTO.""" +# return BoardDTO( +# **board_record.dict(), +# cover_image_url=cover_image_url, +# ) -def generate_cover_photo_url(board_id: str) -> str | None: - cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id) - if cover_photo is not None: - url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name) - return url +# def generate_cover_photo_url(board_id: str) -> str | None: +# cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id) +# if cover_photo is not None: +# url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name) +# return url diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 24bb716635..11453d97f1 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -221,9 +221,6 @@ async def list_images_with_metadata( is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), - board_id: Optional[str] = Query( - default=None, description="The board of images to include" - ), offset: int = Query(default=0, description="The page offset"), limit: int = Query(default=10, description="The number of images per page"), ) -> OffsetPaginatedResults[ImageDTO]: @@ -235,7 +232,6 @@ async def list_images_with_metadata( image_origin, categories, is_intermediate, - board_id ) return image_dtos diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index d00d92f763..50228edf7e 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -78,7 +78,7 @@ app.include_router(models.models_router, prefix="/api") app.include_router(images.images_router, prefix="/api") -app.include_router(boards.boards_router, prefix="/api") +# app.include_router(boards.boards_router, prefix="/api") # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py new file mode 100644 index 0000000000..b805087da8 --- /dev/null +++ b/invokeai/app/services/board_image_record_storage.py @@ -0,0 +1,253 @@ +from abc import ABC, abstractmethod +import sqlite3 +import threading +from typing import cast +from invokeai.app.services.board_record_storage import BoardRecord + +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.image_record import ( + ImageRecord, + deserialize_image_record, +) + + +class BoardImageRecordStorageBase(ABC): + """Abstract base class for board-image relationship record storage.""" + + @abstractmethod + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + pass + + @abstractmethod + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + pass + + @abstractmethod + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageRecord]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_boards_for_image( + self, + board_id: str, + ) -> OffsetPaginatedResults[BoardRecord]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_image_count_for_board( + self, + board_id: str, + ) -> int: + """Gets the number of images for a board.""" + pass + + +class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the `board_images` junction table.""" + + # Create the `board_images` junction table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS board_images ( + board_id TEXT NOT NULL, + image_name TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + PRIMARY KEY (board_id, image_name), + FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE, + FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE + ); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at + AFTER UPDATE + ON board_images FOR EACH ROW + BEGIN + UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE board_id = old.board_id AND image_name = old.image_name; + END; + """ + ) + + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT INTO board_images (board_id, image_name) + VALUES (?, ?); + """, + (board_id, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM board_images + WHERE board_id = ? AND image_name = ?; + """, + (board_id, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_images_for_board( + self, + board_id: str, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[ImageRecord]: + """Gets images for a board.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT images.* + FROM board_images + INNER JOIN images ON board_images.image_name = images.image_name + WHERE board_images.board_id = ? + ORDER BY board_images.updated_at DESC; + """, + (board_id,), + ) + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM images WHERE 1=1; + """ + ) + count = self._cursor.fetchone()[0] + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + return OffsetPaginatedResults( + items=images, offset=offset, limit=limit, total=count + ) + + def get_boards_for_image( + self, + board_id: str, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + """Gets boards for an image.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT boards.* + FROM board_images + INNER JOIN boards ON board_images.board_id = boards.board_id + WHERE board_images.image_name = ? + ORDER BY board_images.updated_at DESC; + """, + (board_id,), + ) + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + boards = list(map(lambda r: BoardRecord(**r), result)) + + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM boards WHERE 1=1; + """ + ) + count = self._cursor.fetchone()[0] + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + return OffsetPaginatedResults( + items=boards, offset=offset, limit=limit, total=count + ) + + def get_image_count_for_board(self, board_id: str) -> int: + """Gets the number of images for a board.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM board_images WHERE board_id = ?; + """, + (board_id,), + ) + count = self._cursor.fetchone()[0] + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + return count diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py new file mode 100644 index 0000000000..dd2e104180 --- /dev/null +++ b/invokeai/app/services/board_images.py @@ -0,0 +1,166 @@ +from abc import ABC, abstractmethod +from logging import Logger +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase +from invokeai.app.services.board_record_storage import ( + BoardDTO, + BoardRecord, + BoardRecordStorageBase, +) + +from invokeai.app.services.image_record_storage import ( + ImageRecordStorageBase, + OffsetPaginatedResults, +) +from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto +from invokeai.app.services.urls import UrlServiceBase + + +class BoardImagesServiceABC(ABC): + """High-level service for board-image relationship management.""" + + @abstractmethod + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + pass + + @abstractmethod + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + pass + + @abstractmethod + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageDTO]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_boards_for_image( + self, + image_name: str, + ) -> OffsetPaginatedResults[BoardDTO]: + """Gets boards for an image.""" + pass + + +class BoardImagesServiceDependencies: + """Service dependencies for the BoardImagesService.""" + + board_image_records: BoardImageRecordStorageBase + board_records: BoardRecordStorageBase + image_records: ImageRecordStorageBase + urls: UrlServiceBase + logger: Logger + + def __init__( + self, + board_image_record_storage: BoardImageRecordStorageBase, + image_record_storage: ImageRecordStorageBase, + board_record_storage: BoardRecordStorageBase, + url: UrlServiceBase, + logger: Logger, + ): + self.board_image_records = board_image_record_storage + self.image_records = image_record_storage + self.board_records = board_record_storage + self.urls = url + self.logger = logger + + +class BoardImagesService(BoardImagesServiceABC): + _services: BoardImagesServiceDependencies + + def __init__(self, services: BoardImagesServiceDependencies): + self._services = services + + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + self._services.board_image_records.add_image_to_board(board_id, image_name) + + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + self._services.board_image_records.remove_image_from_board(board_id, image_name) + + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageDTO]: + image_records = self._services.board_image_records.get_images_for_board( + board_id + ) + image_dtos = list( + map( + lambda r: image_record_to_dto( + r, + self._services.urls.get_image_url(r.image_name), + self._services.urls.get_image_url(r.image_name, True), + ), + image_records.items, + ) + ) + return OffsetPaginatedResults[ImageDTO]( + items=image_dtos, + offset=image_records.offset, + limit=image_records.limit, + total=image_records.total, + ) + + def get_boards_for_image( + self, + image_name: str, + ) -> OffsetPaginatedResults[BoardDTO]: + board_records = self._services.board_image_records.get_boards_for_image( + image_name + ) + board_dtos = [] + + for r in board_records.items: + cover_image_url = ( + self._services.urls.get_image_url(r.cover_image_name, True) + if r.cover_image_name + else None + ) + image_count = self._services.board_image_records.get_image_count_for_board( + r.board_id + ) + board_dtos.append( + board_record_to_dto( + r, + cover_image_url, + image_count, + ) + ) + + return OffsetPaginatedResults[BoardDTO]( + items=board_dtos, + offset=board_records.offset, + limit=board_records.limit, + total=board_records.total, + ) + + +def board_record_to_dto( + board_record: BoardRecord, cover_image_url: str | None, image_count: int +) -> BoardDTO: + """Converts a board record to a board DTO.""" + return BoardDTO( + **board_record.dict(), + cover_image_url=cover_image_url, + image_count=image_count, + ) diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py new file mode 100644 index 0000000000..a954fe7ac4 --- /dev/null +++ b/invokeai/app/services/board_record_storage.py @@ -0,0 +1,331 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Optional, cast +import sqlite3 +import threading +from typing import Optional, Union +import uuid +from invokeai.app.services.image_record_storage import OffsetPaginatedResults + +from pydantic import BaseModel, Field, Extra + + +class BoardRecord(BaseModel): + """Deserialized board record.""" + + board_id: str = Field(description="The unique ID of the board.") + """The unique ID of the board.""" + board_name: str = Field(description="The name of the board.") + """The name of the board.""" + created_at: Union[datetime, str] = Field( + description="The created timestamp of the board." + ) + """The created timestamp of the image.""" + updated_at: Union[datetime, str] = Field( + description="The updated timestamp of the board." + ) + """The updated timestamp of the image.""" + cover_image_name: Optional[str] = Field( + description="The name of the cover image of the board." + ) + """The name of the cover image of the board.""" + + +class BoardDTO(BoardRecord): + """Deserialized board record with cover image URL and image count.""" + + cover_image_url: Optional[str] = Field( + description="The URL of the thumbnail of the board's cover image." + ) + """The URL of the thumbnail of the most recent image in the board.""" + image_count: int = Field(description="The number of images in the board.") + """The number of images in the board.""" + + +class BoardChanges(BaseModel, extra=Extra.forbid): + board_name: Optional[str] = Field(description="The board's new name.") + cover_image_name: Optional[str] = Field( + description="The name of the board's new cover image." + ) + + +class BoardRecordNotFoundException(Exception): + """Raised when an board record is not found.""" + + def __init__(self, message="Board record not found"): + super().__init__(message) + + +class BoardRecordSaveException(Exception): + """Raised when an board record cannot be saved.""" + + def __init__(self, message="Board record not saved"): + super().__init__(message) + + +class BoardRecordDeleteException(Exception): + """Raised when an board record cannot be deleted.""" + + def __init__(self, message="Board record not deleted"): + super().__init__(message) + + +class BoardRecordStorageBase(ABC): + """Low-level service responsible for interfacing with the board record store.""" + + @abstractmethod + def delete(self, board_id: str) -> None: + """Deletes a board record.""" + pass + + @abstractmethod + def save( + self, + board_name: str, + ) -> BoardRecord: + """Saves a board record.""" + pass + + @abstractmethod + def get( + self, + board_id: str, + ) -> BoardRecord: + """Gets a board record.""" + pass + + @abstractmethod + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardRecord: + """Updates a board record.""" + pass + + @abstractmethod + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + """Gets many board records.""" + pass + + +class SqliteBoardRecordStorage(BoardRecordStorageBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the `boards` table and `board_images` junction table.""" + + # Create the `boards` table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS boards ( + board_id TEXT NOT NULL PRIMARY KEY, + board_name TEXT NOT NULL, + cover_image_name TEXT, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME + ); + """ + ) + + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at + AFTER UPDATE + ON boards FOR EACH ROW + BEGIN + UPDATE boards SET updated_at = current_timestamp + WHERE board_id = old.board_id; + END; + """ + ) + + def delete(self, board_id: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM boards + WHERE board_id = ?; + """, + (board_id), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordDeleteException from e + finally: + self._lock.release() + + def save( + self, + board_name: str, + ) -> BoardRecord: + try: + board_id = str(uuid.uuid4()) + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO boards (board_id, board_name) + VALUES (?, ?); + """, + (board_id, board_name), + ) + self._conn.commit() + + self._cursor.execute( + """--sql + SELECT * + FROM boards + WHERE board_id = ?; + """, + (board_id,), + ) + + result = self._cursor.fetchone() + return BoardRecord(**result) + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordSaveException from e + finally: + self._lock.release() + + def get( + self, + board_id: str, + ) -> BoardRecord: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM boards + WHERE board_id = ?; + """, + (board_id,), + ) + + result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordNotFoundException from e + finally: + self._lock.release() + if result is None: + raise BoardRecordNotFoundException + return BoardRecord(**dict(result)) + + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> None: + try: + self._lock.acquire() + + # Change the name of a board + if changes.board_name is not None: + self._cursor.execute( + f"""--sql + UPDATE boards + SET board_name = ? + WHERE board_id = ?; + """, + (changes.board_name, board_id), + ) + + # Change the cover image of a board + if changes.cover_image_name is not None: + self._cursor.execute( + f"""--sql + UPDATE boards + SET cover_image_name = ? + WHERE board_id = ?; + """, + (changes.cover_image_name, board_id), + ) + + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordSaveException from e + finally: + self._lock.release() + + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + try: + self._lock.acquire() + + # Get all the boards + self._cursor.execute( + """--sql + SELECT * + FROM boards + ORDER BY updated_at DESC + LIMIT ? OFFSET ?; + """, + (limit, offset), + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + boards = [BoardRecord(**dict(row)) for row in result] + + # Get the total number of boards + self._cursor.execute( + """--sql + SELECT COUNT(*) + FROM boards + WHERE 1=1; + """ + ) + + count = cast(int, self._cursor.fetchone()[0]) + + return OffsetPaginatedResults[BoardRecord]( + items=boards, offset=offset, limit=limit, total=count + ) + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() diff --git a/invokeai/app/services/boards.py b/invokeai/app/services/boards.py index 3cdadd6c22..07d64e655a 100644 --- a/invokeai/app/services/boards.py +++ b/invokeai/app/services/boards.py @@ -1,253 +1,153 @@ from abc import ABC, abstractmethod -from datetime import datetime -from typing import Generic, Optional, TypeVar, cast -import sqlite3 -import threading -from typing import Optional, Union -import uuid -from invokeai.app.services.image_record_storage import OffsetPaginatedResults -from pydantic import BaseModel, Field, Extra -from pydantic.generics import GenericModel +from logging import Logger +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase +from invokeai.app.services.board_images import board_record_to_dto -T = TypeVar("T", bound=BaseModel) - -class BoardRecord(BaseModel): - """Deserialized board record.""" - - id: str = Field(description="The unique ID of the board.") - name: str = Field(description="The name of the board.") - """The name of the board.""" - created_at: Union[datetime, str] = Field( - description="The created timestamp of the board." - ) - """The created timestamp of the image.""" - updated_at: Union[datetime, str] = Field( - description="The updated timestamp of the board." - ) - -class BoardRecordInList(BaseModel): - """Deserialized board record in a list.""" - - id: str = Field(description="The unique ID of the board.") - name: str = Field(description="The name of the board.") - most_recent_image_url: Optional[str] = Field( - description="The URL of the most recent image in the board." - ) - """The name of the board.""" - created_at: Union[datetime, str] = Field( - description="The created timestamp of the board." - ) - """The created timestamp of the image.""" - updated_at: Union[datetime, str] = Field( - description="The updated timestamp of the board." - ) - -class BoardRecordChanges(BaseModel, extra=Extra.forbid): - name: Optional[str] = Field( - description="The board's new name." - ) - -class BoardRecordNotFoundException(Exception): - """Raised when an board record is not found.""" - - def __init__(self, message="Board record not found"): - super().__init__(message) +from invokeai.app.services.board_record_storage import ( + BoardDTO, + BoardRecord, + BoardChanges, + BoardRecordStorageBase, +) +from invokeai.app.services.image_record_storage import ( + ImageRecordStorageBase, + OffsetPaginatedResults, +) +from invokeai.app.services.models.image_record import ImageDTO +from invokeai.app.services.urls import UrlServiceBase -class BoardRecordSaveException(Exception): - """Raised when an board record cannot be saved.""" - - def __init__(self, message="Board record not saved"): - super().__init__(message) - - -class BoardRecordDeleteException(Exception): - """Raised when an board record cannot be deleted.""" - - def __init__(self, message="Board record not deleted"): - super().__init__(message) - -class BoardStorageBase(ABC): - """Low-level service responsible for interfacing with the board record store.""" +class BoardServiceABC(ABC): + """High-level service for board management.""" @abstractmethod - def delete(self, board_id: str) -> None: - """Deletes a board record.""" + def create( + self, + board_name: str, + ) -> BoardDTO: + """Creates a board.""" pass @abstractmethod - def save( + def get_dto( self, - board_name: str, - ): - """Saves a board record.""" + board_id: str, + ) -> BoardDTO: + """Gets a board.""" pass - def get_cover_photo(self, board_id: str) -> Optional[str]: - """Gets the cover photo for a board.""" + @abstractmethod + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardDTO: + """Updates a board.""" pass + @abstractmethod + def delete( + self, + board_id: str, + ) -> None: + """Deletes a board.""" + pass + + @abstractmethod def get_many( self, - offset: int, - limit: int, - ): - """Gets many board records.""" + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardDTO]: + """Gets many boards.""" pass -class SqliteBoardStorage(BoardStorageBase): - _filename: str - _conn: sqlite3.Connection - _cursor: sqlite3.Cursor - _lock: threading.Lock +class BoardServiceDependencies: + """Service dependencies for the BoardService.""" - def __init__(self, filename: str) -> None: - super().__init__() - self._filename = filename - self._conn = sqlite3.connect(filename, check_same_thread=False) - # Enable row factory to get rows as dictionaries (must be done before making the cursor!) - self._conn.row_factory = sqlite3.Row - self._cursor = self._conn.cursor() - self._lock = threading.Lock() + board_image_records: BoardImageRecordStorageBase + board_records: BoardRecordStorageBase + image_records: ImageRecordStorageBase + urls: UrlServiceBase + logger: Logger - try: - self._lock.acquire() - # Enable foreign keys - self._conn.execute("PRAGMA foreign_keys = ON;") - self._create_tables() - self._conn.commit() - finally: - self._lock.release() + def __init__( + self, + board_image_record_storage: BoardImageRecordStorageBase, + image_record_storage: ImageRecordStorageBase, + board_record_storage: BoardRecordStorageBase, + url: UrlServiceBase, + logger: Logger, + ): + self.board_image_records = board_image_record_storage + self.image_records = image_record_storage + self.board_records = board_record_storage + self.urls = url + self.logger = logger - def _create_tables(self) -> None: - """Creates the `board` table.""" - # Create the `images` table. - self._cursor.execute( - """--sql - CREATE TABLE IF NOT EXISTS boards ( - id TEXT NOT NULL PRIMARY KEY, - name TEXT NOT NULL, - created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), - -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) - ); - """ +class BoardService(BoardServiceABC): + _services: BoardServiceDependencies + + def __init__(self, services: BoardServiceDependencies): + self._services = services + + def create( + self, + board_name: str, + ) -> BoardDTO: + board_record = self._services.board_records.save(board_name) + return board_record_to_dto(board_record, None, 0) + + def get_dto(self, board_id: str) -> BoardDTO: + board_record = self._services.board_records.get(board_id) + cover_image_url = ( + self._services.urls.get_image_url(board_record.cover_image_name, True) + if board_record.cover_image_name + else None ) - - self._cursor.execute( - """--sql - CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at); - """ + image_count = self._services.board_image_records.get_image_count_for_board( + board_id ) + return board_record_to_dto(board_record, cover_image_url, image_count) - # Add trigger for `updated_at`. - self._cursor.execute( - """--sql - CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at - AFTER UPDATE - ON boards FOR EACH ROW - BEGIN - UPDATE boards SET updated_at = current_timestamp - WHERE board_name = old.board_name; - END; - """ + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardDTO: + board_record = self._services.board_records.update(board_id, changes) + cover_image_url = ( + self._services.urls.get_image_url(board_record.cover_image_name, True) + if board_record.cover_image_name + else None ) - + image_count = self._services.board_image_records.get_image_count_for_board( + board_id + ) + return board_record_to_dto(board_record, cover_image_url, image_count) def delete(self, board_id: str) -> None: - try: - self._lock.acquire() - self._cursor.execute( - """--sql - DELETE FROM boards - WHERE id = ?; - """, - (board_id), - ) - self._conn.commit() - except sqlite3.Error as e: - self._conn.rollback() - raise BoardRecordDeleteException from e - finally: - self._lock.release() - - def save( - self, - board_name: str, - ): - try: - board_id = str(uuid.uuid4()) - self._lock.acquire() - self._cursor.execute( - """--sql - INSERT OR IGNORE INTO boards (id, name) - VALUES (?, ?); - """, - (board_id, board_name), - ) - self._conn.commit() - - self._cursor.execute( - """--sql - SELECT * - FROM boards - WHERE id = ?; - """, - (board_id,), - ) - - result = self._cursor.fetchone() - return result - except sqlite3.Error as e: - self._conn.rollback() - raise BoardRecordSaveException from e - finally: - self._lock.release() - + self._services.board_records.delete(board_id) def get_many( - self, - offset: int, - limit: int, - ) -> OffsetPaginatedResults[BoardRecord]: - try: + self, offset: int = 0, limit: int = 10 + ) -> OffsetPaginatedResults[BoardDTO]: + board_records = self._services.board_records.get_many(offset, limit) + board_dtos = [] + for r in board_records.items: + cover_image_url = ( + self._services.urls.get_image_url(r.cover_image_name, True) + if r.cover_image_name + else None + ) + image_count = self._services.board_image_records.get_image_count_for_board( + r.board_id + ) + board_dtos.append(board_record_to_dto(r, cover_image_url, image_count)) - self._lock.acquire() - - count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" - images_query = f"""SELECT * FROM images WHERE 1=1\n""" - - query_conditions = "" - query_params = [] - - query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" - - # Final images query with pagination - images_query += query_conditions + query_pagination + ";" - # Add all the parameters - images_params = query_params.copy() - images_params.append(limit) - images_params.append(offset) - # Build the list of images, deserializing each row - self._cursor.execute(images_query, images_params) - result = cast(list[sqlite3.Row], self._cursor.fetchall()) - boards = [BoardRecord(**dict(row)) for row in result] - - # Set up and execute the count query, without pagination - count_query += query_conditions + ";" - count_params = query_params.copy() - self._cursor.execute(count_query, count_params) - count = self._cursor.fetchone()[0] - - except sqlite3.Error as e: - self._conn.rollback() - raise BoardRecordSaveException from e - finally: - self._lock.release() - - return OffsetPaginatedResults( - items=boards, offset=offset, limit=limit, total=count - ) \ No newline at end of file + return OffsetPaginatedResults[BoardDTO]( + items=board_dtos, offset=offset, limit=limit, total=len(board_dtos) + ) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 96c6beea12..2ca9ad66ca 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -82,7 +82,6 @@ class ImageRecordStorageBase(ABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -94,11 +93,6 @@ class ImageRecordStorageBase(ABC): """Deletes an image record.""" pass - @abstractmethod - def get_board_cover_photo(self, board_id: str) -> Optional[ImageRecord]: - """Gets the cover photo for a board.""" - pass - @abstractmethod def save( self, @@ -197,7 +191,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): AFTER UPDATE ON images FOR EACH ROW BEGIN - UPDATE images SET updated_at = current_timestamp + UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') WHERE image_name = old.image_name; END; """ @@ -268,14 +262,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): ) # Change the image's `is_intermediate`` flag - if changes.board_id is not None: + if changes.is_intermediate is not None: self._cursor.execute( f"""--sql UPDATE images SET board_id = ? WHERE image_name = ?; """, - (changes.board_id, image_name), + (changes.is_intermediate, image_name), ) self._conn.commit() @@ -284,32 +278,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): raise ImageRecordSaveException from e finally: self._lock.release() - - def get_board_cover_photo(self, board_id: str) -> ImageRecord | None: - try: - self._lock.acquire() - self._cursor.execute( - """ - SELECT * - FROM images - WHERE board_id = ? - ORDER BY created_at DESC - LIMIT 1 - """, - (board_id), - ) - self._conn.commit() - result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) - except sqlite3.Error as e: - self._conn.rollback() - raise ImageRecordNotFoundException from e - finally: - self._lock.release() - - if not result: - raise ImageRecordNotFoundException - - return deserialize_image_record(dict(result)) def get_many( self, @@ -318,7 +286,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageRecord]: try: self._lock.acquire() @@ -350,10 +317,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): query_conditions += f"""AND is_intermediate = ?\n""" query_params.append(is_intermediate) - if board_id is not None: - query_conditions += f"""AND board_id = ?\n""" - query_params.append(board_id) - query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" # Final images query with pagination @@ -371,7 +334,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): count_query += query_conditions + ";" count_params = query_params.copy() self._cursor.execute(count_query, count_params) - count = self._cursor.fetchone()[0] + count = cast(int, self._cursor.fetchone()[0]) except sqlite3.Error as e: self._conn.rollback() raise e diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 173268563a..aa27e38d17 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -49,7 +49,7 @@ class ImageServiceABC(ABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - intermediate: bool = False, + is_intermediate: bool = False, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -79,7 +79,7 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_path(self, image_name: str) -> str: + def get_path(self, image_name: str, thumbnail: bool = False) -> str: """Gets an image's path.""" pass @@ -322,7 +322,6 @@ class ImageService(ImageServiceABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: try: results = self._services.records.get_many( @@ -331,7 +330,6 @@ class ImageService(ImageServiceABC): image_origin, categories, is_intermediate, - board_id ) image_dtos = list( diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index d69e0b294f..10d1d91920 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -4,7 +4,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from logging import Logger - from invokeai.app.services.images import ImageService + from invokeai.app.services.board_images import BoardImagesServiceABC + from invokeai.app.services.boards import BoardServiceABC + from invokeai.app.services.images import ImageServiceABC from invokeai.backend import ModelManager from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase @@ -14,7 +16,6 @@ if TYPE_CHECKING: from invokeai.app.services.config import InvokeAISettings from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.invoker import InvocationProcessorABC - from invokeai.app.services.boards import BoardStorageBase class InvocationServices: @@ -27,10 +28,9 @@ class InvocationServices: model_manager: "ModelManager" restoration: "RestorationServices" configuration: "InvokeAISettings" - images: "ImageService" - boards: "BoardStorageBase" - - # NOTE: we must forward-declare any types that include invocations, since invocations can use services + images: "ImageServiceABC" + boards: "BoardServiceABC" + board_images: "BoardImagesServiceABC" graph_library: "ItemStorageABC"["LibraryGraph"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] processor: "InvocationProcessorABC" @@ -41,20 +41,23 @@ class InvocationServices: events: "EventServiceBase", logger: "Logger", latents: "LatentsStorageBase", - images: "ImageService", + images: "ImageServiceABC", + boards: "BoardServiceABC", + board_images: "BoardImagesServiceABC", queue: "InvocationQueueABC", graph_library: "ItemStorageABC"["LibraryGraph"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], processor: "InvocationProcessorABC", restoration: "RestorationServices", configuration: "InvokeAISettings", - boards: "BoardStorageBase", ): self.model_manager = model_manager self.events = events self.logger = logger self.latents = latents self.images = images + self.boards = boards + self.board_images = board_images self.queue = queue self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 98f370f337..d971d65916 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -48,11 +48,6 @@ class ImageRecord(BaseModel): description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.", ) """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" - board_id: Optional[str] = Field( - default=None, - description="The board ID that this image belongs to.", - ) - """The board ID that this image belongs to.""" class ImageRecordChanges(BaseModel, extra=Extra.forbid): @@ -77,10 +72,6 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): default=None, description="The image's new `is_intermediate` flag." ) """The image's new `is_intermediate` flag.""" - board_id: Optional[StrictStr] = Field( - default=None, description="The image's new board ID." - ) - """The image's new board ID.""" class ImageUrlsDTO(BaseModel): @@ -131,7 +122,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) - board_id = image_dict.get("board_id", None) raw_metadata = image_dict.get("metadata") @@ -153,5 +143,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at=updated_at, deleted_at=deleted_at, is_intermediate=is_intermediate, - board_id=board_id, )