feat(nodes): add boards and board_images services

This commit is contained in:
psychedelicious 2023-06-15 00:07:20 +10:00
parent 3833304f57
commit 72e9ced889
12 changed files with 993 additions and 368 deletions

View File

@ -2,7 +2,15 @@
from logging import Logger from logging import Logger
import os import os
from invokeai.app.services import boards from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
@ -59,7 +67,7 @@ class ApiDependencies:
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = config.db_path db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True) db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
@ -73,7 +81,29 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage( latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents") DiskLatentsStorage(f"{output_folder}/latents")
) )
boards = SqliteBoardStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService( images = ImageService(
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
@ -90,6 +120,8 @@ class ApiDependencies:
events=events, events=events,
latents=latents, latents=latents,
images=images, images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs" filename=db_location, table_name="graphs"
@ -99,7 +131,6 @@ class ApiDependencies:
restoration=RestorationServices(config, logger), restoration=RestorationServices(config, logger),
configuration=config, configuration=config,
logger=logger, logger=logger,
boards=boards
) )
create_system_graphs(services.graph_library) create_system_graphs(services.graph_library)

View File

@ -1,91 +1,86 @@
from fastapi import Body, HTTPException, Path, Query # from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter # from fastapi.routing import APIRouter
from invokeai.app.services.boards import BoardRecord, BoardRecordChanges # from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults # from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from ..dependencies import ApiDependencies # from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"]) # boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@boards_router.post( # @boards_router.post(
"/", # "/",
operation_id="create_board", # operation_id="create_board",
responses={ # responses={
201: {"description": "The board was created successfully"}, # 201: {"description": "The board was created successfully"},
}, # },
status_code=201, # status_code=201,
) # )
async def create_board( # async def create_board(
board_name: str = Body(description="The name of the board to create"), # board_name: str = Body(description="The name of the board to create"),
): # ):
"""Creates a board""" # """Creates a board"""
try: # try:
result = ApiDependencies.invoker.services.boards.save(board_name=board_name) # result = ApiDependencies.invoker.services.boards.save(board_name=board_name)
return result # return result
except Exception as e: # except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create board") # raise HTTPException(status_code=500, detail="Failed to create board")
@boards_router.delete("/{board_id}", operation_id="delete_board") # @boards_router.delete("/{board_id}", operation_id="delete_board")
async def delete_board( # async def delete_board(
board_id: str = Path(description="The id of board to delete"), # board_id: str = Path(description="The id of board to delete"),
) -> None: # ) -> None:
"""Deletes a board""" # """Deletes a board"""
try: # try:
ApiDependencies.invoker.services.boards.delete(board_id=board_id) # ApiDependencies.invoker.services.boards.delete(board_id=board_id)
except Exception as e: # except Exception as e:
# TODO: Does this need any exception handling at all? # # TODO: Does this need any exception handling at all?
pass # pass
@boards_router.get( # @boards_router.get(
"/", # "/",
operation_id="list_boards", # operation_id="list_boards",
response_model=OffsetPaginatedResults[BoardRecord], # response_model=OffsetPaginatedResults[BoardRecord],
) # )
async def list_boards( # async def list_boards(
offset: int = Query(default=0, description="The page offset"), # offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"), # limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[BoardRecord]: # ) -> OffsetPaginatedResults[BoardRecord]:
"""Gets a list of boards""" # """Gets a list of boards"""
results = ApiDependencies.invoker.services.boards.get_many( # results = ApiDependencies.invoker.services.boards.get_many(
offset, # offset,
limit, # limit,
) # )
boards = list( # boards = list(
map( # map(
lambda r: board_record_to_dto( # lambda r: board_record_to_dto(
r, # r,
generate_cover_photo_url(r.id) # generate_cover_photo_url(r.id)
), # ),
results.boards, # results.boards,
) # )
) # )
return boards # return boards
class BoardDTO(BaseModel):
"""A DTO for an image"""
id: str
name: str
cover_image_url: str
def board_record_to_dto( # def board_record_to_dto(
board_record: BoardRecord, cover_image_url: str # board_record: BoardRecord, cover_image_url: str
) -> BoardDTO: # ) -> BoardDTO:
"""Converts an image record to an image DTO.""" # """Converts an image record to an image DTO."""
return BoardDTO( # return BoardDTO(
**board_record.dict(), # **board_record.dict(),
cover_image_url=cover_image_url, # cover_image_url=cover_image_url,
) # )
def generate_cover_photo_url(board_id: str) -> str | None: # def generate_cover_photo_url(board_id: str) -> str | None:
cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id) # cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id)
if cover_photo is not None: # if cover_photo is not None:
url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name) # url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name)
return url # return url

View File

@ -221,9 +221,6 @@ async def list_images_with_metadata(
is_intermediate: Optional[bool] = Query( is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images" default=None, description="Whether to list intermediate images"
), ),
board_id: Optional[str] = Query(
default=None, description="The board of images to include"
),
offset: int = Query(default=0, description="The page offset"), offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"), limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
@ -235,7 +232,6 @@ async def list_images_with_metadata(
image_origin, image_origin,
categories, categories,
is_intermediate, is_intermediate,
board_id
) )
return image_dtos return image_dtos

View File

@ -78,7 +78,7 @@ app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api") app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api") # app.include_router(boards.boards_router, prefix="/api")
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?

View File

@ -0,0 +1,253 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for board-image relationship record storage."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_boards_for_image(
self,
board_id: str,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_image_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of images for a board."""
pass
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (board_id, image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?);
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_boards_for_image(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets boards for an image."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT boards.*
FROM board_images
INNER JOIN boards ON board_images.board_id = boards.board_id
WHERE board_images.image_name = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: BoardRecord(**r), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM boards WHERE 1=1;
"""
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
)
def get_image_count_for_board(self, board_id: str) -> int:
"""Gets the number of images for a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return count

View File

@ -0,0 +1,166 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardDTO,
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images for a board."""
pass
@abstractmethod
def get_boards_for_image(
self,
image_name: str,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets boards for an image."""
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
image_records = self._services.board_image_records.get_images_for_board(
board_id
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
),
image_records.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=image_records.offset,
limit=image_records.limit,
total=image_records.total,
)
def get_boards_for_image(
self,
image_name: str,
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_image_records.get_boards_for_image(
image_name
)
board_dtos = []
for r in board_records.items:
cover_image_url = (
self._services.urls.get_image_url(r.cover_image_name, True)
if r.cover_image_name
else None
)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(
board_record_to_dto(
r,
cover_image_url,
image_count,
)
)
return OffsetPaginatedResults[BoardDTO](
items=board_dtos,
offset=board_records.offset,
limit=board_records.limit,
total=board_records.total,
)
def board_record_to_dto(
board_record: BoardRecord, cover_image_url: str | None, image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(),
cover_image_url=cover_image_url,
image_count=image_count,
)

View File

@ -0,0 +1,331 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from pydantic import BaseModel, Field, Extra
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_url: Optional[str] = Field(
description="The URL of the thumbnail of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
pass
@abstractmethod
def save(
self,
board_name: str,
) -> BoardRecord:
"""Saves a board record."""
pass
@abstractmethod
def get(
self,
board_id: str,
) -> BoardRecord:
"""Gets a board record."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
"""Updates a board record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = self._cursor.fetchone()
return BoardRecord(**result)
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
def get(
self,
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
def update(
self,
board_id: str,
changes: BoardChanges,
) -> None:
try:
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY updated_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [BoardRecord(**dict(row)) for row in result]
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -1,253 +1,153 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime
from typing import Generic, Optional, TypeVar, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from pydantic import BaseModel, Field, Extra from logging import Logger
from pydantic.generics import GenericModel from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
T = TypeVar("T", bound=BaseModel) from invokeai.app.services.board_record_storage import (
BoardDTO,
class BoardRecord(BaseModel): BoardRecord,
"""Deserialized board record.""" BoardChanges,
BoardRecordStorageBase,
id: str = Field(description="The unique ID of the board.") )
name: str = Field(description="The name of the board.") from invokeai.app.services.image_record_storage import (
"""The name of the board.""" ImageRecordStorageBase,
created_at: Union[datetime, str] = Field( OffsetPaginatedResults,
description="The created timestamp of the board." )
) from invokeai.app.services.models.image_record import ImageDTO
"""The created timestamp of the image.""" from invokeai.app.services.urls import UrlServiceBase
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
class BoardRecordInList(BaseModel):
"""Deserialized board record in a list."""
id: str = Field(description="The unique ID of the board.")
name: str = Field(description="The name of the board.")
most_recent_image_url: Optional[str] = Field(
description="The URL of the most recent image in the board."
)
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
class BoardRecordChanges(BaseModel, extra=Extra.forbid):
name: Optional[str] = Field(
description="The board's new name."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception): class BoardServiceABC(ABC):
"""Raised when an board record cannot be saved.""" """High-level service for board management."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod @abstractmethod
def delete(self, board_id: str) -> None: def create(
"""Deletes a board record.""" self,
board_name: str,
) -> BoardDTO:
"""Creates a board."""
pass pass
@abstractmethod @abstractmethod
def save( def get_dto(
self, self,
board_name: str, board_id: str,
): ) -> BoardDTO:
"""Saves a board record.""" """Gets a board."""
pass pass
def get_cover_photo(self, board_id: str) -> Optional[str]: @abstractmethod
"""Gets the cover photo for a board.""" def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
"""Updates a board."""
pass pass
@abstractmethod
def delete(
self,
board_id: str,
) -> None:
"""Deletes a board."""
pass
@abstractmethod
def get_many( def get_many(
self, self,
offset: int, offset: int = 0,
limit: int, limit: int = 10,
): ) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many board records.""" """Gets many boards."""
pass pass
class SqliteBoardStorage(BoardStorageBase): class BoardServiceDependencies:
_filename: str """Service dependencies for the BoardService."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None: board_image_records: BoardImageRecordStorageBase
super().__init__() board_records: BoardRecordStorageBase
self._filename = filename image_records: ImageRecordStorageBase
self._conn = sqlite3.connect(filename, check_same_thread=False) urls: UrlServiceBase
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) logger: Logger
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try: def __init__(
self._lock.acquire() self,
# Enable foreign keys board_image_record_storage: BoardImageRecordStorageBase,
self._conn.execute("PRAGMA foreign_keys = ON;") image_record_storage: ImageRecordStorageBase,
self._create_tables() board_record_storage: BoardRecordStorageBase,
self._conn.commit() url: UrlServiceBase,
finally: logger: Logger,
self._lock.release() ):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
def _create_tables(self) -> None:
"""Creates the `board` table."""
# Create the `images` table. class BoardService(BoardServiceABC):
self._cursor.execute( _services: BoardServiceDependencies
"""--sql
CREATE TABLE IF NOT EXISTS boards ( def __init__(self, services: BoardServiceDependencies):
id TEXT NOT NULL PRIMARY KEY, self._services = services
name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), def create(
-- Updated via trigger self,
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) board_name: str,
); ) -> BoardDTO:
""" board_record = self._services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image_url = (
self._services.urls.get_image_url(board_record.cover_image_name, True)
if board_record.cover_image_name
else None
) )
image_count = self._services.board_image_records.get_image_count_for_board(
self._cursor.execute( board_id
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at);
"""
) )
return board_record_to_dto(board_record, cover_image_url, image_count)
# Add trigger for `updated_at`. def update(
self._cursor.execute( self,
"""--sql board_id: str,
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at changes: BoardChanges,
AFTER UPDATE ) -> BoardDTO:
ON boards FOR EACH ROW board_record = self._services.board_records.update(board_id, changes)
BEGIN cover_image_url = (
UPDATE boards SET updated_at = current_timestamp self._services.urls.get_image_url(board_record.cover_image_name, True)
WHERE board_name = old.board_name; if board_record.cover_image_name
END; else None
"""
) )
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_url, image_count)
def delete(self, board_id: str) -> None: def delete(self, board_id: str) -> None:
try: self._services.board_records.delete(board_id)
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE id = ?;
""",
(board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
):
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (id, name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE id = ?;
""",
(board_id,),
)
result = self._cursor.fetchone()
return result
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
def get_many( def get_many(
self, self, offset: int = 0, limit: int = 10
offset: int, ) -> OffsetPaginatedResults[BoardDTO]:
limit: int, board_records = self._services.board_records.get_many(offset, limit)
) -> OffsetPaginatedResults[BoardRecord]: board_dtos = []
try: for r in board_records.items:
cover_image_url = (
self._lock.acquire() self._services.urls.get_image_url(r.cover_image_name, True)
if r.cover_image_name
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" else None
images_query = f"""SELECT * FROM images WHERE 1=1\n""" )
image_count = self._services.board_image_records.get_image_count_for_board(
query_conditions = "" r.board_id
query_params = [] )
board_dtos.append(board_record_to_dto(r, cover_image_url, image_count))
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
return OffsetPaginatedResults[BoardDTO](
# Final images query with pagination items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [BoardRecord(**dict(row)) for row in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
) )

View File

@ -82,7 +82,6 @@ class ImageRecordStorageBase(ABC):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]: ) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records.""" """Gets a page of image records."""
pass pass
@ -94,11 +93,6 @@ class ImageRecordStorageBase(ABC):
"""Deletes an image record.""" """Deletes an image record."""
pass pass
@abstractmethod
def get_board_cover_photo(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the cover photo for a board."""
pass
@abstractmethod @abstractmethod
def save( def save(
self, self,
@ -197,7 +191,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
AFTER UPDATE AFTER UPDATE
ON images FOR EACH ROW ON images FOR EACH ROW
BEGIN BEGIN
UPDATE images SET updated_at = current_timestamp UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name; WHERE image_name = old.image_name;
END; END;
""" """
@ -268,14 +262,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
) )
# Change the image's `is_intermediate`` flag # Change the image's `is_intermediate`` flag
if changes.board_id is not None: if changes.is_intermediate is not None:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
UPDATE images UPDATE images
SET board_id = ? SET board_id = ?
WHERE image_name = ?; WHERE image_name = ?;
""", """,
(changes.board_id, image_name), (changes.is_intermediate, image_name),
) )
self._conn.commit() self._conn.commit()
@ -285,32 +279,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
def get_board_cover_photo(self, board_id: str) -> ImageRecord | None:
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT *
FROM images
WHERE board_id = ?
ORDER BY created_at DESC
LIMIT 1
""",
(board_id),
)
self._conn.commit()
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
return deserialize_image_record(dict(result))
def get_many( def get_many(
self, self,
offset: int = 0, offset: int = 0,
@ -318,7 +286,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]: ) -> OffsetPaginatedResults[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -350,10 +317,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_conditions += f"""AND is_intermediate = ?\n""" query_conditions += f"""AND is_intermediate = ?\n"""
query_params.append(is_intermediate) query_params.append(is_intermediate)
if board_id is not None:
query_conditions += f"""AND board_id = ?\n"""
query_params.append(board_id)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination # Final images query with pagination
@ -371,7 +334,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
count_query += query_conditions + ";" count_query += query_conditions + ";"
count_params = query_params.copy() count_params = query_params.copy()
self._cursor.execute(count_query, count_params) self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0] count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise e raise e

View File

@ -49,7 +49,7 @@ class ImageServiceABC(ABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
intermediate: bool = False, is_intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -79,7 +79,7 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -322,7 +322,6 @@ class ImageService(ImageServiceABC):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
try: try:
results = self._services.records.get_many( results = self._services.records.get_many(
@ -331,7 +330,6 @@ class ImageService(ImageServiceABC):
image_origin, image_origin,
categories, categories,
is_intermediate, is_intermediate,
board_id
) )
image_dtos = list( image_dtos = list(

View File

@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from logging import Logger from logging import Logger
from invokeai.app.services.images import ImageService from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
@ -14,7 +16,6 @@ if TYPE_CHECKING:
from invokeai.app.services.config import InvokeAISettings from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.boards import BoardStorageBase
class InvocationServices: class InvocationServices:
@ -27,10 +28,9 @@ class InvocationServices:
model_manager: "ModelManager" model_manager: "ModelManager"
restoration: "RestorationServices" restoration: "RestorationServices"
configuration: "InvokeAISettings" configuration: "InvokeAISettings"
images: "ImageService" images: "ImageServiceABC"
boards: "BoardStorageBase" boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: "ItemStorageABC"["LibraryGraph"] graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
@ -41,20 +41,23 @@ class InvocationServices:
events: "EventServiceBase", events: "EventServiceBase",
logger: "Logger", logger: "Logger",
latents: "LatentsStorageBase", latents: "LatentsStorageBase",
images: "ImageService", images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC", queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"], graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
restoration: "RestorationServices", restoration: "RestorationServices",
configuration: "InvokeAISettings", configuration: "InvokeAISettings",
boards: "BoardStorageBase",
): ):
self.model_manager = model_manager self.model_manager = model_manager
self.events = events self.events = events
self.logger = logger self.logger = logger
self.latents = latents self.latents = latents
self.images = images self.images = images
self.boards = boards
self.board_images = board_images
self.queue = queue self.queue = queue
self.graph_library = graph_library self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager

View File

@ -48,11 +48,6 @@ class ImageRecord(BaseModel):
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.", description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
) )
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
board_id: Optional[str] = Field(
default=None,
description="The board ID that this image belongs to.",
)
"""The board ID that this image belongs to."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid): class ImageRecordChanges(BaseModel, extra=Extra.forbid):
@ -77,10 +72,6 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
default=None, description="The image's new `is_intermediate` flag." default=None, description="The image's new `is_intermediate` flag."
) )
"""The image's new `is_intermediate` flag.""" """The image's new `is_intermediate` flag."""
board_id: Optional[StrictStr] = Field(
default=None, description="The image's new board ID."
)
"""The image's new board ID."""
class ImageUrlsDTO(BaseModel): class ImageUrlsDTO(BaseModel):
@ -131,7 +122,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
board_id = image_dict.get("board_id", None)
raw_metadata = image_dict.get("metadata") raw_metadata = image_dict.get("metadata")
@ -153,5 +143,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
board_id=board_id,
) )