mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(app): move sqlite-specific objects to sqlite file
This commit is contained in:
parent
a30d143c5a
commit
25107e427c
@ -1,103 +1,11 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from attr import dataclass
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
# This query is missing a GROUP BY clause, which is required for the query to be valid.
|
||||
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
|
||||
SELECT b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at,
|
||||
b.archived,
|
||||
COUNT(
|
||||
CASE
|
||||
WHEN i.image_category in ('general')
|
||||
AND i.is_intermediate = 0 THEN 1
|
||||
END
|
||||
) AS image_count,
|
||||
COUNT(
|
||||
CASE
|
||||
WHEN i.image_category in ('control', 'mask', 'user', 'other')
|
||||
AND i.is_intermediate = 0 THEN 1
|
||||
END
|
||||
) AS asset_count,
|
||||
(
|
||||
SELECT bi.image_name
|
||||
FROM board_images bi
|
||||
JOIN images i ON bi.image_name = i.image_name
|
||||
WHERE bi.board_id = b.board_id
|
||||
AND i.is_intermediate = 0
|
||||
ORDER BY i.created_at DESC
|
||||
LIMIT 1
|
||||
) AS cover_image_name
|
||||
FROM boards b
|
||||
LEFT JOIN board_images bi ON b.board_id = bi.board_id
|
||||
LEFT JOIN images i ON bi.image_name = i.image_name
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaginatedBoardRecordsQueries:
|
||||
main_query: str
|
||||
total_count_query: str
|
||||
|
||||
|
||||
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
|
||||
"""Gets a query to retrieve a paginated list of board records."""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
# The GROUP BY must be added _after_ the WHERE clause!
|
||||
main_query = f"""
|
||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||
{archived_condition}
|
||||
GROUP BY b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at
|
||||
ORDER BY b.created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
total_count_query = f"""
|
||||
SELECT COUNT(*)
|
||||
FROM boards b
|
||||
{archived_condition};
|
||||
"""
|
||||
|
||||
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
|
||||
|
||||
|
||||
def get_list_all_board_records_query(include_archived: bool) -> str:
|
||||
"""Gets a query to retrieve all board records."""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
# The GROUP BY must be added _after_ the WHERE clause!
|
||||
return f"""
|
||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||
{archived_condition}
|
||||
GROUP BY b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at
|
||||
ORDER BY b.created_at DESC;
|
||||
"""
|
||||
|
||||
|
||||
def get_board_record_query() -> str:
|
||||
"""Gets a query to retrieve a board record."""
|
||||
|
||||
return f"""
|
||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||
WHERE b.board_id = ?;
|
||||
"""
|
||||
|
||||
|
||||
class BoardRecord(BaseModelExcludeNull):
|
||||
"""Deserialized board record."""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Union, cast
|
||||
|
||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||
@ -11,14 +12,102 @@ from invokeai.app.services.board_records.board_records_common import (
|
||||
BoardRecordSaveException,
|
||||
UncategorizedImageCounts,
|
||||
deserialize_board_record,
|
||||
get_board_record_query,
|
||||
get_list_all_board_records_query,
|
||||
get_paginated_list_board_records_query,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# This query is missing a GROUP BY clause, which is required for the query to be valid.
|
||||
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
|
||||
SELECT b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at,
|
||||
b.archived,
|
||||
COUNT(
|
||||
CASE
|
||||
WHEN i.image_category in ('general')
|
||||
AND i.is_intermediate = 0 THEN 1
|
||||
END
|
||||
) AS image_count,
|
||||
COUNT(
|
||||
CASE
|
||||
WHEN i.image_category in ('control', 'mask', 'user', 'other')
|
||||
AND i.is_intermediate = 0 THEN 1
|
||||
END
|
||||
) AS asset_count,
|
||||
(
|
||||
SELECT bi.image_name
|
||||
FROM board_images bi
|
||||
JOIN images i ON bi.image_name = i.image_name
|
||||
WHERE bi.board_id = b.board_id
|
||||
AND i.is_intermediate = 0
|
||||
ORDER BY i.created_at DESC
|
||||
LIMIT 1
|
||||
) AS cover_image_name
|
||||
FROM boards b
|
||||
LEFT JOIN board_images bi ON b.board_id = bi.board_id
|
||||
LEFT JOIN images i ON bi.image_name = i.image_name
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaginatedBoardRecordsQueries:
|
||||
main_query: str
|
||||
total_count_query: str
|
||||
|
||||
|
||||
def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries:
|
||||
"""Gets a query to retrieve a paginated list of board records."""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
# The GROUP BY must be added _after_ the WHERE clause!
|
||||
main_query = f"""
|
||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||
{archived_condition}
|
||||
GROUP BY b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at
|
||||
ORDER BY b.created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
total_count_query = f"""
|
||||
SELECT COUNT(*)
|
||||
FROM boards b
|
||||
{archived_condition};
|
||||
"""
|
||||
|
||||
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
|
||||
|
||||
|
||||
def get_list_all_board_records_query(include_archived: bool) -> str:
|
||||
"""Gets a query to retrieve all board records."""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
# The GROUP BY must be added _after_ the WHERE clause!
|
||||
return f"""
|
||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||
{archived_condition}
|
||||
GROUP BY b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at
|
||||
ORDER BY b.created_at DESC;
|
||||
"""
|
||||
|
||||
|
||||
def get_board_record_query() -> str:
|
||||
"""Gets a query to retrieve a board record."""
|
||||
|
||||
return f"""
|
||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||
WHERE b.board_id = ?;
|
||||
"""
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
@ -149,7 +238,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
queries = get_paginated_list_board_records_query(include_archived=include_archived)
|
||||
queries = get_paginated_list_board_records_queries(include_archived=include_archived)
|
||||
|
||||
self._cursor.execute(
|
||||
queries.main_query,
|
||||
|
Loading…
Reference in New Issue
Block a user