tidy(app): move sqlite-specific objects to sqlite file

This commit is contained in:
psychedelicious 2024-07-15 17:09:55 +10:00
parent a30d143c5a
commit 25107e427c
2 changed files with 93 additions and 96 deletions

View File

@ -1,103 +1,11 @@
from datetime import datetime from datetime import datetime
from typing import Any, Optional, Union from typing import Any, Optional, Union
from attr import dataclass
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
# This query is missing a GROUP BY clause, which is required for the query to be valid.
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
COUNT(
CASE
WHEN i.image_category in ('general')
AND i.is_intermediate = 0 THEN 1
END
) AS image_count,
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1
END
) AS asset_count,
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0
ORDER BY i.created_at DESC
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
WHERE b.board_id = ?;
"""
class BoardRecord(BaseModelExcludeNull): class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record.""" """Deserialized board record."""

View File

@ -1,5 +1,6 @@
import sqlite3 import sqlite3
import threading import threading
from dataclasses import dataclass
from typing import Union, cast from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@ -11,14 +12,102 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecordSaveException, BoardRecordSaveException,
UncategorizedImageCounts, UncategorizedImageCounts,
deserialize_board_record, deserialize_board_record,
get_board_record_query,
get_list_all_board_records_query,
get_paginated_list_board_records_query,
) )
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
# This query is missing a GROUP BY clause, which is required for the query to be valid.
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
COUNT(
CASE
WHEN i.image_category in ('general')
AND i.is_intermediate = 0 THEN 1
END
) AS image_count,
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1
END
) AS asset_count,
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0
ORDER BY i.created_at DESC
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
WHERE b.board_id = ?;
"""
class SqliteBoardRecordStorage(BoardRecordStorageBase): class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection _conn: sqlite3.Connection
@ -149,7 +238,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try: try:
self._lock.acquire() self._lock.acquire()
queries = get_paginated_list_board_records_query(include_archived=include_archived) queries = get_paginated_list_board_records_queries(include_archived=include_archived)
self._cursor.execute( self._cursor.execute(
queries.main_query, queries.main_query,