feat(nodes): add boards and board_images services

This commit is contained in:
psychedelicious 2023-06-15 00:07:20 +10:00
parent 3833304f57
commit 72e9ced889
12 changed files with 993 additions and 368 deletions

View File

@ -2,7 +2,15 @@
from logging import Logger
import os
from invokeai.app.services import boards
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
@ -59,7 +67,7 @@ class ApiDependencies:
# TODO: build a file/path manager?
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
@ -73,7 +81,29 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
boards = SqliteBoardStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService(
image_record_storage=image_record_storage,
@ -90,6 +120,8 @@ class ApiDependencies:
events=events,
latents=latents,
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
@ -99,7 +131,6 @@ class ApiDependencies:
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
boards=boards
)
create_system_graphs(services.graph_library)

View File

@ -1,91 +1,86 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.boards import BoardRecord, BoardRecordChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
# from fastapi import Body, HTTPException, Path, Query
# from fastapi.routing import APIRouter
# from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
# from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from ..dependencies import ApiDependencies
# from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
# boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@boards_router.post(
"/",
operation_id="create_board",
responses={
201: {"description": "The board was created successfully"},
},
status_code=201,
)
async def create_board(
board_name: str = Body(description="The name of the board to create"),
):
"""Creates a board"""
try:
result = ApiDependencies.invoker.services.boards.save(board_name=board_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create board")
# @boards_router.post(
# "/",
# operation_id="create_board",
# responses={
# 201: {"description": "The board was created successfully"},
# },
# status_code=201,
# )
# async def create_board(
# board_name: str = Body(description="The name of the board to create"),
# ):
# """Creates a board"""
# try:
# result = ApiDependencies.invoker.services.boards.save(board_name=board_name)
# return result
# except Exception as e:
# raise HTTPException(status_code=500, detail="Failed to create board")
@boards_router.delete("/{board_id}", operation_id="delete_board")
async def delete_board(
board_id: str = Path(description="The id of board to delete"),
) -> None:
"""Deletes a board"""
# @boards_router.delete("/{board_id}", operation_id="delete_board")
# async def delete_board(
# board_id: str = Path(description="The id of board to delete"),
# ) -> None:
# """Deletes a board"""
try:
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
# try:
# ApiDependencies.invoker.services.boards.delete(board_id=board_id)
# except Exception as e:
# # TODO: Does this need any exception handling at all?
# pass
@boards_router.get(
"/",
operation_id="list_boards",
response_model=OffsetPaginatedResults[BoardRecord],
)
async def list_boards(
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets a list of boards"""
# @boards_router.get(
# "/",
# operation_id="list_boards",
# response_model=OffsetPaginatedResults[BoardRecord],
# )
# async def list_boards(
# offset: int = Query(default=0, description="The page offset"),
# limit: int = Query(default=10, description="The number of boards per page"),
# ) -> OffsetPaginatedResults[BoardRecord]:
# """Gets a list of boards"""
results = ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
# results = ApiDependencies.invoker.services.boards.get_many(
# offset,
# limit,
# )
boards = list(
map(
lambda r: board_record_to_dto(
r,
generate_cover_photo_url(r.id)
),
results.boards,
)
)
# boards = list(
# map(
# lambda r: board_record_to_dto(
# r,
# generate_cover_photo_url(r.id)
# ),
# results.boards,
# )
# )
return boards
# return boards
class BoardDTO(BaseModel):
"""A DTO for an image"""
id: str
name: str
cover_image_url: str
def board_record_to_dto(
board_record: BoardRecord, cover_image_url: str
) -> BoardDTO:
"""Converts an image record to an image DTO."""
return BoardDTO(
**board_record.dict(),
cover_image_url=cover_image_url,
)
# def board_record_to_dto(
# board_record: BoardRecord, cover_image_url: str
# ) -> BoardDTO:
# """Converts an image record to an image DTO."""
# return BoardDTO(
# **board_record.dict(),
# cover_image_url=cover_image_url,
# )
def generate_cover_photo_url(board_id: str) -> str | None:
cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id)
if cover_photo is not None:
url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name)
return url
# def generate_cover_photo_url(board_id: str) -> str | None:
# cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id)
# if cover_photo is not None:
# url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name)
# return url

View File

@ -221,9 +221,6 @@ async def list_images_with_metadata(
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
board_id: Optional[str] = Query(
default=None, description="The board of images to include"
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
@ -235,7 +232,6 @@ async def list_images_with_metadata(
image_origin,
categories,
is_intermediate,
board_id
)
return image_dtos

View File

@ -78,7 +78,7 @@ app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
# app.include_router(boards.boards_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?

View File

@ -0,0 +1,253 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for board-image relationship record storage."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_boards_for_image(
self,
board_id: str,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_image_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of images for a board."""
pass
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (board_id, image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?);
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_boards_for_image(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets boards for an image."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT boards.*
FROM board_images
INNER JOIN boards ON board_images.board_id = boards.board_id
WHERE board_images.image_name = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: BoardRecord(**r), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM boards WHERE 1=1;
"""
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
)
def get_image_count_for_board(self, board_id: str) -> int:
"""Gets the number of images for a board."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return count

View File

@ -0,0 +1,166 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardDTO,
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images for a board."""
pass
@abstractmethod
def get_boards_for_image(
self,
image_name: str,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets boards for an image."""
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
image_records = self._services.board_image_records.get_images_for_board(
board_id
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
),
image_records.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=image_records.offset,
limit=image_records.limit,
total=image_records.total,
)
def get_boards_for_image(
self,
image_name: str,
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_image_records.get_boards_for_image(
image_name
)
board_dtos = []
for r in board_records.items:
cover_image_url = (
self._services.urls.get_image_url(r.cover_image_name, True)
if r.cover_image_name
else None
)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(
board_record_to_dto(
r,
cover_image_url,
image_count,
)
)
return OffsetPaginatedResults[BoardDTO](
items=board_dtos,
offset=board_records.offset,
limit=board_records.limit,
total=board_records.total,
)
def board_record_to_dto(
board_record: BoardRecord, cover_image_url: str | None, image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(),
cover_image_url=cover_image_url,
image_count=image_count,
)

View File

@ -0,0 +1,331 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from pydantic import BaseModel, Field, Extra
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_url: Optional[str] = Field(
description="The URL of the thumbnail of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
pass
@abstractmethod
def save(
self,
board_name: str,
) -> BoardRecord:
"""Saves a board record."""
pass
@abstractmethod
def get(
self,
board_id: str,
) -> BoardRecord:
"""Gets a board record."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
"""Updates a board record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = self._cursor.fetchone()
return BoardRecord(**result)
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
def get(
self,
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
def update(
self,
board_id: str,
changes: BoardChanges,
) -> None:
try:
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY updated_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [BoardRecord(**dict(row)) for row in result]
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -1,253 +1,153 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Generic, Optional, TypeVar, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from pydantic import BaseModel, Field, Extra
from pydantic.generics import GenericModel
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
T = TypeVar("T", bound=BaseModel)
class BoardRecord(BaseModel):
"""Deserialized board record."""
id: str = Field(description="The unique ID of the board.")
name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
class BoardRecordInList(BaseModel):
"""Deserialized board record in a list."""
id: str = Field(description="The unique ID of the board.")
name: str = Field(description="The name of the board.")
most_recent_image_url: Optional[str] = Field(
description="The URL of the most recent image in the board."
)
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
class BoardRecordChanges(BaseModel, extra=Extra.forbid):
name: Optional[str] = Field(
description="The board's new name."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
from invokeai.app.services.board_record_storage import (
BoardDTO,
BoardRecord,
BoardChanges,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import ImageDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
class BoardServiceABC(ABC):
"""High-level service for board management."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
def create(
self,
board_name: str,
) -> BoardDTO:
"""Creates a board."""
pass
@abstractmethod
def save(
def get_dto(
self,
board_name: str,
):
"""Saves a board record."""
board_id: str,
) -> BoardDTO:
"""Gets a board."""
pass
def get_cover_photo(self, board_id: str) -> Optional[str]:
"""Gets the cover photo for a board."""
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
"""Updates a board."""
pass
@abstractmethod
def delete(
self,
board_id: str,
) -> None:
"""Deletes a board."""
pass
@abstractmethod
def get_many(
self,
offset: int,
limit: int,
):
"""Gets many board records."""
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
class SqliteBoardStorage(BoardStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
def _create_tables(self) -> None:
"""Creates the `board` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
class BoardService(BoardServiceABC):
_services: BoardServiceDependencies
def __init__(self, services: BoardServiceDependencies):
self._services = services
def create(
self,
board_name: str,
) -> BoardDTO:
board_record = self._services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image_url = (
self._services.urls.get_image_url(board_record.cover_image_name, True)
if board_record.cover_image_name
else None
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at);
"""
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_url, image_count)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_name = old.board_name;
END;
"""
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image_url = (
self._services.urls.get_image_url(board_record.cover_image_name, True)
if board_record.cover_image_name
else None
)
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_url, image_count)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE id = ?;
""",
(board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
):
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (id, name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE id = ?;
""",
(board_id,),
)
result = self._cursor.fetchone()
return result
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
self._services.board_records.delete(board_id)
def get_many(
self,
offset: int,
limit: int,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image_url = (
self._services.urls.get_image_url(r.cover_image_name, True)
if r.cover_image_name
else None
)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_url, image_count))
self._lock.acquire()
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
query_conditions = ""
query_params = []
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [BoardRecord(**dict(row)) for row in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
return OffsetPaginatedResults[BoardDTO](
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)

View File

@ -82,7 +82,6 @@ class ImageRecordStorageBase(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@ -94,11 +93,6 @@ class ImageRecordStorageBase(ABC):
"""Deletes an image record."""
pass
@abstractmethod
def get_board_cover_photo(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the cover photo for a board."""
pass
@abstractmethod
def save(
self,
@ -197,7 +191,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = current_timestamp
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
@ -268,14 +262,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
)
# Change the image's `is_intermediate`` flag
if changes.board_id is not None:
if changes.is_intermediate is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET board_id = ?
WHERE image_name = ?;
""",
(changes.board_id, image_name),
(changes.is_intermediate, image_name),
)
self._conn.commit()
@ -285,32 +279,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()
def get_board_cover_photo(self, board_id: str) -> ImageRecord | None:
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT *
FROM images
WHERE board_id = ?
ORDER BY created_at DESC
LIMIT 1
""",
(board_id),
)
self._conn.commit()
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
return deserialize_image_record(dict(result))
def get_many(
self,
offset: int = 0,
@ -318,7 +286,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@ -350,10 +317,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_conditions += f"""AND is_intermediate = ?\n"""
query_params.append(is_intermediate)
if board_id is not None:
query_conditions += f"""AND board_id = ?\n"""
query_params.append(board_id)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
@ -371,7 +334,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e

View File

@ -49,7 +49,7 @@ class ImageServiceABC(ABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
intermediate: bool = False,
is_intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -79,7 +79,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_path(self, image_name: str) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""
pass
@ -322,7 +322,6 @@ class ImageService(ImageServiceABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
@ -331,7 +330,6 @@ class ImageService(ImageServiceABC):
image_origin,
categories,
is_intermediate,
board_id
)
image_dtos = list(

View File

@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
@ -14,7 +16,6 @@ if TYPE_CHECKING:
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.boards import BoardStorageBase
class InvocationServices:
@ -27,10 +28,9 @@ class InvocationServices:
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageService"
boards: "BoardStorageBase"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"
@ -41,20 +41,23 @@ class InvocationServices:
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageService",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings",
boards: "BoardStorageBase",
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager

View File

@ -48,11 +48,6 @@ class ImageRecord(BaseModel):
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
board_id: Optional[str] = Field(
default=None,
description="The board ID that this image belongs to.",
)
"""The board ID that this image belongs to."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
@ -77,10 +72,6 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
board_id: Optional[StrictStr] = Field(
default=None, description="The image's new board ID."
)
"""The image's new board ID."""
class ImageUrlsDTO(BaseModel):
@ -131,7 +122,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
board_id = image_dict.get("board_id", None)
raw_metadata = image_dict.get("metadata")
@ -153,5 +143,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
board_id=board_id,
)