tidy(app): move sqlite-specific objects to sqlite file

This commit is contained in:
psychedelicious 2024-07-15 17:09:55 +10:00
parent a30d143c5a
commit 25107e427c
2 changed files with 93 additions and 96 deletions

View File

@ -1,103 +1,11 @@
from datetime import datetime
from typing import Any, Optional, Union
from attr import dataclass
from pydantic import BaseModel, Field
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
# This query is missing a GROUP BY clause, which is required for the query to be valid.
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
COUNT(
CASE
WHEN i.image_category in ('general')
AND i.is_intermediate = 0 THEN 1
END
) AS image_count,
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1
END
) AS asset_count,
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0
ORDER BY i.created_at DESC
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
WHERE b.board_id = ?;
"""
class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record."""

View File

@ -1,5 +1,6 @@
import sqlite3
import threading
from dataclasses import dataclass
from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@ -11,14 +12,102 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecordSaveException,
UncategorizedImageCounts,
deserialize_board_record,
get_board_record_query,
get_list_all_board_records_query,
get_paginated_list_board_records_query,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
# This query is missing a GROUP BY clause, which is required for the query to be valid.
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
COUNT(
CASE
WHEN i.image_category in ('general')
AND i.is_intermediate = 0 THEN 1
END
) AS image_count,
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1
END
) AS asset_count,
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0
ORDER BY i.created_at DESC
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
WHERE b.board_id = ?;
"""
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
@ -149,7 +238,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try:
self._lock.acquire()
queries = get_paginated_list_board_records_query(include_archived=include_archived)
queries = get_paginated_list_board_records_queries(include_archived=include_archived)
self._cursor.execute(
queries.main_query,