From c05f97d8cac3df938685bac1b45fe33486b7e775 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:13:26 +1000 Subject: [PATCH] feat(app): refactor board record to include image & asset counts and cover image This _substantially_ reduces the number of queries required to list all boards. A single query now gets one, all, or a page of boards, including counts and cover image name. - Add helpers to build the queries, which share a common base with some joins. - Update `BoardRecord` to include the counts. - Update `BoardDTO`, which is now identical to `BoardRecord`. I opted to not remove `BoardDTO` because it is used in many places. - Update boards high-level service and board records services accordingly. --- .../board_records/board_records_common.py | 110 +++++++++++++++++- .../board_records/board_records_sqlite.py | 70 ++--------- invokeai/app/services/boards/boards_common.py | 21 +--- .../app/services/boards/boards_default.py | 45 +------ 4 files changed, 125 insertions(+), 121 deletions(-) diff --git a/invokeai/app/services/board_records/board_records_common.py b/invokeai/app/services/board_records/board_records_common.py index 1c25aab565..80a4c3313a 100644 --- a/invokeai/app/services/board_records/board_records_common.py +++ b/invokeai/app/services/board_records/board_records_common.py @@ -1,11 +1,103 @@ from datetime import datetime -from typing import Optional, Union +from typing import Any, Optional, Union +from attr import dataclass from pydantic import BaseModel, Field from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +# This query is missing a GROUP BY clause, which is required for the query to be valid. +BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """ + SELECT b.board_id, + b.board_name, + b.created_at, + b.updated_at, + b.archived, + COUNT( + CASE + WHEN i.image_category in ('general') + AND i.is_intermediate = 0 THEN 1 + END + ) AS image_count, + COUNT( + CASE + WHEN i.image_category in ('control', 'mask', 'user', 'other') + AND i.is_intermediate = 0 THEN 1 + END + ) AS asset_count, + ( + SELECT bi.image_name + FROM board_images bi + JOIN images i ON bi.image_name = i.image_name + WHERE bi.board_id = b.board_id + AND i.is_intermediate = 0 + ORDER BY i.created_at DESC + LIMIT 1 + ) AS cover_image_name + FROM boards b + LEFT JOIN board_images bi ON b.board_id = bi.board_id + LEFT JOIN images i ON bi.image_name = i.image_name + """ + + +@dataclass +class PaginatedBoardRecordsQueries: + main_query: str + total_count_query: str + + +def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries: + """Gets a query to retrieve a paginated list of board records.""" + + archived_condition = "WHERE b.archived = 0" if not include_archived else "" + + # The GROUP BY must be added _after_ the WHERE clause! + main_query = f""" + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} + {archived_condition} + GROUP BY b.board_id, + b.board_name, + b.created_at, + b.updated_at + ORDER BY b.created_at DESC + LIMIT ? OFFSET ?; + """ + + total_count_query = f""" + SELECT COUNT(*) + FROM boards b + {archived_condition}; + """ + + return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query) + + +def get_list_all_board_records_query(include_archived: bool) -> str: + """Gets a query to retrieve all board records.""" + + archived_condition = "WHERE b.archived = 0" if not include_archived else "" + + # The GROUP BY must be added _after_ the WHERE clause! + return f""" + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} + {archived_condition} + GROUP BY b.board_id, + b.board_name, + b.created_at, + b.updated_at + ORDER BY b.created_at DESC; + """ + + +def get_board_record_query() -> str: + """Gets a query to retrieve a board record.""" + + return f""" + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} + WHERE b.board_id = ?; + """ + class BoardRecord(BaseModelExcludeNull): """Deserialized board record.""" @@ -26,21 +118,25 @@ class BoardRecord(BaseModelExcludeNull): """Whether or not the board is archived.""" is_private: Optional[bool] = Field(default=None, description="Whether the board is private.") """Whether the board is private.""" + image_count: int = Field(description="The number of images in the board.") + asset_count: int = Field(description="The number of assets in the board.") -def deserialize_board_record(board_dict: dict) -> BoardRecord: +def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord: """Deserializes a board record.""" # Retrieve all the values, setting "reasonable" defaults if they are not present. board_id = board_dict.get("board_id", "unknown") board_name = board_dict.get("board_name", "unknown") - cover_image_name = board_dict.get("cover_image_name", "unknown") + cover_image_name = board_dict.get("cover_image_name", None) created_at = board_dict.get("created_at", get_iso_timestamp()) updated_at = board_dict.get("updated_at", get_iso_timestamp()) deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) archived = board_dict.get("archived", False) is_private = board_dict.get("is_private", False) + image_count = board_dict.get("image_count", 0) + asset_count = board_dict.get("asset_count", 0) return BoardRecord( board_id=board_id, @@ -51,6 +147,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: deleted_at=deleted_at, archived=archived, is_private=is_private, + image_count=image_count, + asset_count=asset_count, ) @@ -63,21 +161,21 @@ class BoardChanges(BaseModel, extra="forbid"): class BoardRecordNotFoundException(Exception): """Raised when an board record is not found.""" - def __init__(self, message="Board record not found"): + def __init__(self, message: str = "Board record not found"): super().__init__(message) class BoardRecordSaveException(Exception): """Raised when an board record cannot be saved.""" - def __init__(self, message="Board record not saved"): + def __init__(self, message: str = "Board record not saved"): super().__init__(message) class BoardRecordDeleteException(Exception): """Raised when an board record cannot be deleted.""" - def __init__(self, message="Board record not deleted"): + def __init__(self, message: str = "Board record not deleted"): super().__init__(message) diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index c5167824cd..27b47ea57d 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -11,6 +11,9 @@ from invokeai.app.services.board_records.board_records_common import ( BoardRecordSaveException, UncategorizedImageCounts, deserialize_board_record, + get_board_record_query, + get_list_all_board_records_query, + get_paginated_list_board_records_query, ) from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -77,11 +80,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): try: self._lock.acquire() self._cursor.execute( - """--sql - SELECT * - FROM boards - WHERE board_id = ?; - """, + get_board_record_query(), (board_id,), ) @@ -93,7 +92,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): self._lock.release() if result is None: raise BoardRecordNotFoundException - return BoardRecord(**dict(result)) + return deserialize_board_record(dict(result)) def update( self, @@ -150,45 +149,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): try: self._lock.acquire() - # Build base query - base_query = """ - SELECT * - FROM boards - {archived_filter} - ORDER BY created_at DESC - LIMIT ? OFFSET ?; - """ + queries = get_paginated_list_board_records_query(include_archived=include_archived) - # Determine archived filter condition - if include_archived: - archived_filter = "" - else: - archived_filter = "WHERE archived = 0" - - final_query = base_query.format(archived_filter=archived_filter) - - # Execute query to fetch boards - self._cursor.execute(final_query, (limit, offset)) + self._cursor.execute( + queries.main_query, + (limit, offset), + ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) boards = [deserialize_board_record(dict(r)) for r in result] - # Determine count query - if include_archived: - count_query = """ - SELECT COUNT(*) - FROM boards; - """ - else: - count_query = """ - SELECT COUNT(*) - FROM boards - WHERE archived = 0; - """ - - # Execute count query - self._cursor.execute(count_query) - + self._cursor.execute(queries.total_count_query) count = cast(int, self._cursor.fetchone()[0]) return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count) @@ -202,26 +173,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): def get_all(self, include_archived: bool = False) -> list[BoardRecord]: try: self._lock.acquire() - - base_query = """ - SELECT * - FROM boards - {archived_filter} - ORDER BY created_at DESC - """ - - if include_archived: - archived_filter = "" - else: - archived_filter = "WHERE archived = 0" - - final_query = base_query.format(archived_filter=archived_filter) - - self._cursor.execute(final_query) - + self._cursor.execute(get_list_all_board_records_query(include_archived=include_archived)) result = cast(list[sqlite3.Row], self._cursor.fetchall()) boards = [deserialize_board_record(dict(r)) for r in result] - return boards except sqlite3.Error as e: diff --git a/invokeai/app/services/boards/boards_common.py b/invokeai/app/services/boards/boards_common.py index 15d0b3c37f..1e9337a3ed 100644 --- a/invokeai/app/services/boards/boards_common.py +++ b/invokeai/app/services/boards/boards_common.py @@ -1,23 +1,8 @@ -from typing import Optional - -from pydantic import Field - from invokeai.app.services.board_records.board_records_common import BoardRecord +# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it. class BoardDTO(BoardRecord): - """Deserialized board record with cover image URL and image count.""" + """Deserialized board record.""" - cover_image_name: Optional[str] = Field(description="The name of the board's cover image.") - """The URL of the thumbnail of the most recent image in the board.""" - image_count: int = Field(description="The number of images in the board.") - """The number of images in the board.""" - - -def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO: - """Converts a board record to a board DTO.""" - return BoardDTO( - **board_record.model_dump(exclude={"cover_image_name"}), - cover_image_name=cover_image_name, - image_count=image_count, - ) + pass diff --git a/invokeai/app/services/boards/boards_default.py b/invokeai/app/services/boards/boards_default.py index 97fd3059a9..abf38e8ea7 100644 --- a/invokeai/app/services/boards/boards_default.py +++ b/invokeai/app/services/boards/boards_default.py @@ -1,6 +1,6 @@ from invokeai.app.services.board_records.board_records_common import BoardChanges from invokeai.app.services.boards.boards_base import BoardServiceABC -from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto +from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults @@ -16,17 +16,11 @@ class BoardService(BoardServiceABC): board_name: str, ) -> BoardDTO: board_record = self.__invoker.services.board_records.save(board_name) - return board_record_to_dto(board_record, None, 0) + return BoardDTO.model_validate(board_record.model_dump()) def get_dto(self, board_id: str) -> BoardDTO: board_record = self.__invoker.services.board_records.get(board_id) - cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id) - if cover_image: - cover_image_name = cover_image.image_name - else: - cover_image_name = None - image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id) - return board_record_to_dto(board_record, cover_image_name, image_count) + return BoardDTO.model_validate(board_record.model_dump()) def update( self, @@ -34,14 +28,7 @@ class BoardService(BoardServiceABC): changes: BoardChanges, ) -> BoardDTO: board_record = self.__invoker.services.board_records.update(board_id, changes) - cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id) - if cover_image: - cover_image_name = cover_image.image_name - else: - cover_image_name = None - - image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id) - return board_record_to_dto(board_record, cover_image_name, image_count) + return BoardDTO.model_validate(board_record.model_dump()) def delete(self, board_id: str) -> None: self.__invoker.services.board_records.delete(board_id) @@ -50,30 +37,10 @@ class BoardService(BoardServiceABC): self, offset: int = 0, limit: int = 10, include_archived: bool = False ) -> OffsetPaginatedResults[BoardDTO]: board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived) - board_dtos = [] - for r in board_records.items: - cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id) - if cover_image: - cover_image_name = cover_image.image_name - else: - cover_image_name = None - - image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id) - board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) - + board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items] return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)) def get_all(self, include_archived: bool = False) -> list[BoardDTO]: board_records = self.__invoker.services.board_records.get_all(include_archived) - board_dtos = [] - for r in board_records: - cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id) - if cover_image: - cover_image_name = cover_image.image_name - else: - cover_image_name = None - - image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id) - board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) - + board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records] return board_dtos