feat(app): refactor board record to include image & asset counts and cover image

This _substantially_ reduces the number of queries required to list all boards. A single query now gets one, all, or a page of boards, including counts and cover image name.

- Add helpers to build the queries, which share a common base with some joins.
- Update `BoardRecord` to include the counts.
- Update `BoardDTO`, which is now identical to `BoardRecord`. I opted to not remove `BoardDTO` because it is used in many places.
- Update boards high-level service and board records services accordingly.
This commit is contained in:
psychedelicious 2024-07-12 14:13:26 +10:00
parent a95aa6cc16
commit c05f97d8ca
4 changed files with 125 additions and 121 deletions

View File

@ -1,11 +1,103 @@
from datetime import datetime
from typing import Optional, Union
from typing import Any, Optional, Union
from attr import dataclass
from pydantic import BaseModel, Field
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
# This query is missing a GROUP BY clause, which is required for the query to be valid.
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
COUNT(
CASE
WHEN i.image_category in ('general')
AND i.is_intermediate = 0 THEN 1
END
) AS image_count,
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1
END
) AS asset_count,
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0
ORDER BY i.created_at DESC
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
WHERE b.board_id = ?;
"""
class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record."""
@ -26,21 +118,25 @@ class BoardRecord(BaseModelExcludeNull):
"""Whether or not the board is archived."""
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
"""Whether the board is private."""
image_count: int = Field(description="The number of images in the board.")
asset_count: int = Field(description="The number of assets in the board.")
def deserialize_board_record(board_dict: dict) -> BoardRecord:
def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
"""Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", None)
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
is_private = board_dict.get("is_private", False)
image_count = board_dict.get("image_count", 0)
asset_count = board_dict.get("asset_count", 0)
return BoardRecord(
board_id=board_id,
@ -51,6 +147,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
deleted_at=deleted_at,
archived=archived,
is_private=is_private,
image_count=image_count,
asset_count=asset_count,
)
@ -63,21 +161,21 @@ class BoardChanges(BaseModel, extra="forbid"):
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
def __init__(self, message: str = "Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
def __init__(self, message: str = "Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
def __init__(self, message: str = "Board record not deleted"):
super().__init__(message)

View File

@ -11,6 +11,9 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecordSaveException,
UncategorizedImageCounts,
deserialize_board_record,
get_board_record_query,
get_list_all_board_records_query,
get_paginated_list_board_records_query,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@ -77,11 +80,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
get_board_record_query(),
(board_id,),
)
@ -93,7 +92,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
return deserialize_board_record(dict(result))
def update(
self,
@ -150,45 +149,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try:
self._lock.acquire()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""
queries = get_paginated_list_board_records_query(include_archived=include_archived)
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
self._cursor.execute(
queries.main_query,
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
self._cursor.execute(queries.total_count_query)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
@ -202,26 +173,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
self._cursor.execute(get_list_all_board_records_query(include_archived=include_archived))
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
except sqlite3.Error as e:

View File

@ -1,23 +1,8 @@
from typing import Optional
from pydantic import Field
from invokeai.app.services.board_records.board_records_common import BoardRecord
# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
"""Deserialized board record."""
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.model_dump(exclude={"cover_image_name"}),
cover_image_name=cover_image_name,
image_count=image_count,
)
pass

View File

@ -1,6 +1,6 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@ -16,17 +16,11 @@ class BoardService(BoardServiceABC):
board_name: str,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
return BoardDTO.model_validate(board_record.model_dump())
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self.__invoker.services.board_records.get(board_id)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
return BoardDTO.model_validate(board_record.model_dump())
def update(
self,
@ -34,14 +28,7 @@ class BoardService(BoardServiceABC):
changes: BoardChanges,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.update(board_id, changes)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
return BoardDTO.model_validate(board_record.model_dump())
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
@ -50,30 +37,10 @@ class BoardService(BoardServiceABC):
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items]
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records]
return board_dtos