diff --git a/invokeai/app/services/board_records/board_records_common.py b/invokeai/app/services/board_records/board_records_common.py index 80a4c3313a..3478746536 100644 --- a/invokeai/app/services/board_records/board_records_common.py +++ b/invokeai/app/services/board_records/board_records_common.py @@ -1,103 +1,11 @@ from datetime import datetime from typing import Any, Optional, Union -from attr import dataclass from pydantic import BaseModel, Field from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -# This query is missing a GROUP BY clause, which is required for the query to be valid. -BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """ - SELECT b.board_id, - b.board_name, - b.created_at, - b.updated_at, - b.archived, - COUNT( - CASE - WHEN i.image_category in ('general') - AND i.is_intermediate = 0 THEN 1 - END - ) AS image_count, - COUNT( - CASE - WHEN i.image_category in ('control', 'mask', 'user', 'other') - AND i.is_intermediate = 0 THEN 1 - END - ) AS asset_count, - ( - SELECT bi.image_name - FROM board_images bi - JOIN images i ON bi.image_name = i.image_name - WHERE bi.board_id = b.board_id - AND i.is_intermediate = 0 - ORDER BY i.created_at DESC - LIMIT 1 - ) AS cover_image_name - FROM boards b - LEFT JOIN board_images bi ON b.board_id = bi.board_id - LEFT JOIN images i ON bi.image_name = i.image_name - """ - - -@dataclass -class PaginatedBoardRecordsQueries: - main_query: str - total_count_query: str - - -def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries: - """Gets a query to retrieve a paginated list of board records.""" - - archived_condition = "WHERE b.archived = 0" if not include_archived else "" - - # The GROUP BY must be added _after_ the WHERE clause! - main_query = f""" - {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} - {archived_condition} - GROUP BY b.board_id, - b.board_name, - b.created_at, - b.updated_at - ORDER BY b.created_at DESC - LIMIT ? OFFSET ?; - """ - - total_count_query = f""" - SELECT COUNT(*) - FROM boards b - {archived_condition}; - """ - - return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query) - - -def get_list_all_board_records_query(include_archived: bool) -> str: - """Gets a query to retrieve all board records.""" - - archived_condition = "WHERE b.archived = 0" if not include_archived else "" - - # The GROUP BY must be added _after_ the WHERE clause! - return f""" - {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} - {archived_condition} - GROUP BY b.board_id, - b.board_name, - b.created_at, - b.updated_at - ORDER BY b.created_at DESC; - """ - - -def get_board_record_query() -> str: - """Gets a query to retrieve a board record.""" - - return f""" - {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} - WHERE b.board_id = ?; - """ - class BoardRecord(BaseModelExcludeNull): """Deserialized board record.""" diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index 27b47ea57d..2a1fcd35d7 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -1,5 +1,6 @@ import sqlite3 import threading +from dataclasses import dataclass from typing import Union, cast from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase @@ -11,14 +12,102 @@ from invokeai.app.services.board_records.board_records_common import ( BoardRecordSaveException, UncategorizedImageCounts, deserialize_board_record, - get_board_record_query, - get_list_all_board_records_query, - get_paginated_list_board_records_query, ) from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.util.misc import uuid_string +# This query is missing a GROUP BY clause, which is required for the query to be valid. +BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """ + SELECT b.board_id, + b.board_name, + b.created_at, + b.updated_at, + b.archived, + COUNT( + CASE + WHEN i.image_category in ('general') + AND i.is_intermediate = 0 THEN 1 + END + ) AS image_count, + COUNT( + CASE + WHEN i.image_category in ('control', 'mask', 'user', 'other') + AND i.is_intermediate = 0 THEN 1 + END + ) AS asset_count, + ( + SELECT bi.image_name + FROM board_images bi + JOIN images i ON bi.image_name = i.image_name + WHERE bi.board_id = b.board_id + AND i.is_intermediate = 0 + ORDER BY i.created_at DESC + LIMIT 1 + ) AS cover_image_name + FROM boards b + LEFT JOIN board_images bi ON b.board_id = bi.board_id + LEFT JOIN images i ON bi.image_name = i.image_name + """ + + +@dataclass +class PaginatedBoardRecordsQueries: + main_query: str + total_count_query: str + + +def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries: + """Gets a query to retrieve a paginated list of board records.""" + + archived_condition = "WHERE b.archived = 0" if not include_archived else "" + + # The GROUP BY must be added _after_ the WHERE clause! + main_query = f""" + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} + {archived_condition} + GROUP BY b.board_id, + b.board_name, + b.created_at, + b.updated_at + ORDER BY b.created_at DESC + LIMIT ? OFFSET ?; + """ + + total_count_query = f""" + SELECT COUNT(*) + FROM boards b + {archived_condition}; + """ + + return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query) + + +def get_list_all_board_records_query(include_archived: bool) -> str: + """Gets a query to retrieve all board records.""" + + archived_condition = "WHERE b.archived = 0" if not include_archived else "" + + # The GROUP BY must be added _after_ the WHERE clause! + return f""" + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} + {archived_condition} + GROUP BY b.board_id, + b.board_name, + b.created_at, + b.updated_at + ORDER BY b.created_at DESC; + """ + + +def get_board_record_query() -> str: + """Gets a query to retrieve a board record.""" + + return f""" + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} + WHERE b.board_id = ?; + """ + class SqliteBoardRecordStorage(BoardRecordStorageBase): _conn: sqlite3.Connection @@ -149,7 +238,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): try: self._lock.acquire() - queries = get_paginated_list_board_records_query(include_archived=include_archived) + queries = get_paginated_list_board_records_queries(include_archived=include_archived) self._cursor.execute( queries.main_query,