feat(app): refactor board record to include image & asset counts and cover image

This _substantially_ reduces the number of queries required to list all boards. A single query now gets one, all, or a page of boards, including counts and cover image name.

- Add helpers to build the queries, which share a common base with some joins.
- Update `BoardRecord` to include the counts.
- Update `BoardDTO`, which is now identical to `BoardRecord`. I opted to not remove `BoardDTO` because it is used in many places.
- Update boards high-level service and board records services accordingly.
This commit is contained in:
psychedelicious 2024-07-12 14:13:26 +10:00
parent a95aa6cc16
commit c05f97d8ca
4 changed files with 125 additions and 121 deletions

View File

@ -1,11 +1,103 @@
from datetime import datetime from datetime import datetime
from typing import Optional, Union from typing import Any, Optional, Union
from attr import dataclass
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
# This query is missing a GROUP BY clause, which is required for the query to be valid.
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
COUNT(
CASE
WHEN i.image_category in ('general')
AND i.is_intermediate = 0 THEN 1
END
) AS image_count,
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1
END
) AS asset_count,
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0
ORDER BY i.created_at DESC
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
WHERE b.board_id = ?;
"""
class BoardRecord(BaseModelExcludeNull): class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record.""" """Deserialized board record."""
@ -26,21 +118,25 @@ class BoardRecord(BaseModelExcludeNull):
"""Whether or not the board is archived.""" """Whether or not the board is archived."""
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.") is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
"""Whether the board is private.""" """Whether the board is private."""
image_count: int = Field(description="The number of images in the board.")
asset_count: int = Field(description="The number of assets in the board.")
def deserialize_board_record(board_dict: dict) -> BoardRecord: def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
"""Deserializes a board record.""" """Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown") board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown") board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown") cover_image_name = board_dict.get("cover_image_name", None)
created_at = board_dict.get("created_at", get_iso_timestamp()) created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp()) updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False) archived = board_dict.get("archived", False)
is_private = board_dict.get("is_private", False) is_private = board_dict.get("is_private", False)
image_count = board_dict.get("image_count", 0)
asset_count = board_dict.get("asset_count", 0)
return BoardRecord( return BoardRecord(
board_id=board_id, board_id=board_id,
@ -51,6 +147,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
deleted_at=deleted_at, deleted_at=deleted_at,
archived=archived, archived=archived,
is_private=is_private, is_private=is_private,
image_count=image_count,
asset_count=asset_count,
) )
@ -63,21 +161,21 @@ class BoardChanges(BaseModel, extra="forbid"):
class BoardRecordNotFoundException(Exception): class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found.""" """Raised when an board record is not found."""
def __init__(self, message="Board record not found"): def __init__(self, message: str = "Board record not found"):
super().__init__(message) super().__init__(message)
class BoardRecordSaveException(Exception): class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved.""" """Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"): def __init__(self, message: str = "Board record not saved"):
super().__init__(message) super().__init__(message)
class BoardRecordDeleteException(Exception): class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted.""" """Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"): def __init__(self, message: str = "Board record not deleted"):
super().__init__(message) super().__init__(message)

View File

@ -11,6 +11,9 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecordSaveException, BoardRecordSaveException,
UncategorizedImageCounts, UncategorizedImageCounts,
deserialize_board_record, deserialize_board_record,
get_board_record_query,
get_list_all_board_records_query,
get_paginated_list_board_records_query,
) )
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@ -77,11 +80,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql get_board_record_query(),
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,), (board_id,),
) )
@ -93,7 +92,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._lock.release() self._lock.release()
if result is None: if result is None:
raise BoardRecordNotFoundException raise BoardRecordNotFoundException
return BoardRecord(**dict(result)) return deserialize_board_record(dict(result))
def update( def update(
self, self,
@ -150,45 +149,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try: try:
self._lock.acquire() self._lock.acquire()
# Build base query queries = get_paginated_list_board_records_query(include_archived=include_archived)
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition self._cursor.execute(
if include_archived: queries.main_query,
archived_filter = "" (limit, offset),
else: )
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result] boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query self._cursor.execute(queries.total_count_query)
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
count = cast(int, self._cursor.fetchone()[0]) count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count) return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
@ -202,26 +173,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def get_all(self, include_archived: bool = False) -> list[BoardRecord]: def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute(get_list_all_board_records_query(include_archived=include_archived))
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result] boards = [deserialize_board_record(dict(r)) for r in result]
return boards return boards
except sqlite3.Error as e: except sqlite3.Error as e:

View File

@ -1,23 +1,8 @@
from typing import Optional
from pydantic import Field
from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.board_records.board_records_common import BoardRecord
# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
class BoardDTO(BoardRecord): class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count.""" """Deserialized board record."""
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.") pass
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.model_dump(exclude={"cover_image_name"}),
cover_image_name=cover_image_name,
image_count=image_count,
)

View File

@ -1,6 +1,6 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_base import BoardServiceABC from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@ -16,17 +16,11 @@ class BoardService(BoardServiceABC):
board_name: str, board_name: str,
) -> BoardDTO: ) -> BoardDTO:
board_record = self.__invoker.services.board_records.save(board_name) board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0) return BoardDTO.model_validate(board_record.model_dump())
def get_dto(self, board_id: str) -> BoardDTO: def get_dto(self, board_id: str) -> BoardDTO:
board_record = self.__invoker.services.board_records.get(board_id) board_record = self.__invoker.services.board_records.get(board_id)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id) return BoardDTO.model_validate(board_record.model_dump())
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update( def update(
self, self,
@ -34,14 +28,7 @@ class BoardService(BoardServiceABC):
changes: BoardChanges, changes: BoardChanges,
) -> BoardDTO: ) -> BoardDTO:
board_record = self.__invoker.services.board_records.update(board_id, changes) board_record = self.__invoker.services.board_records.update(board_id, changes)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id) return BoardDTO.model_validate(board_record.model_dump())
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None: def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id) self.__invoker.services.board_records.delete(board_id)
@ -50,30 +37,10 @@ class BoardService(BoardServiceABC):
self, offset: int = 0, limit: int = 10, include_archived: bool = False self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]: ) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived) board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = [] board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items]
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)) return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self, include_archived: bool = False) -> list[BoardDTO]: def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived) board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = [] board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records]
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos return board_dtos