mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(app): move sqlite-specific objects to sqlite file
This commit is contained in:
parent
a30d143c5a
commit
25107e427c
@ -1,103 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from attr import dataclass
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
# This query is missing a GROUP BY clause, which is required for the query to be valid.
|
|
||||||
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
|
|
||||||
SELECT b.board_id,
|
|
||||||
b.board_name,
|
|
||||||
b.created_at,
|
|
||||||
b.updated_at,
|
|
||||||
b.archived,
|
|
||||||
COUNT(
|
|
||||||
CASE
|
|
||||||
WHEN i.image_category in ('general')
|
|
||||||
AND i.is_intermediate = 0 THEN 1
|
|
||||||
END
|
|
||||||
) AS image_count,
|
|
||||||
COUNT(
|
|
||||||
CASE
|
|
||||||
WHEN i.image_category in ('control', 'mask', 'user', 'other')
|
|
||||||
AND i.is_intermediate = 0 THEN 1
|
|
||||||
END
|
|
||||||
) AS asset_count,
|
|
||||||
(
|
|
||||||
SELECT bi.image_name
|
|
||||||
FROM board_images bi
|
|
||||||
JOIN images i ON bi.image_name = i.image_name
|
|
||||||
WHERE bi.board_id = b.board_id
|
|
||||||
AND i.is_intermediate = 0
|
|
||||||
ORDER BY i.created_at DESC
|
|
||||||
LIMIT 1
|
|
||||||
) AS cover_image_name
|
|
||||||
FROM boards b
|
|
||||||
LEFT JOIN board_images bi ON b.board_id = bi.board_id
|
|
||||||
LEFT JOIN images i ON bi.image_name = i.image_name
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PaginatedBoardRecordsQueries:
|
|
||||||
main_query: str
|
|
||||||
total_count_query: str
|
|
||||||
|
|
||||||
|
|
||||||
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
|
|
||||||
"""Gets a query to retrieve a paginated list of board records."""
|
|
||||||
|
|
||||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
|
||||||
|
|
||||||
# The GROUP BY must be added _after_ the WHERE clause!
|
|
||||||
main_query = f"""
|
|
||||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
|
||||||
{archived_condition}
|
|
||||||
GROUP BY b.board_id,
|
|
||||||
b.board_name,
|
|
||||||
b.created_at,
|
|
||||||
b.updated_at
|
|
||||||
ORDER BY b.created_at DESC
|
|
||||||
LIMIT ? OFFSET ?;
|
|
||||||
"""
|
|
||||||
|
|
||||||
total_count_query = f"""
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM boards b
|
|
||||||
{archived_condition};
|
|
||||||
"""
|
|
||||||
|
|
||||||
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
|
|
||||||
|
|
||||||
|
|
||||||
def get_list_all_board_records_query(include_archived: bool) -> str:
|
|
||||||
"""Gets a query to retrieve all board records."""
|
|
||||||
|
|
||||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
|
||||||
|
|
||||||
# The GROUP BY must be added _after_ the WHERE clause!
|
|
||||||
return f"""
|
|
||||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
|
||||||
{archived_condition}
|
|
||||||
GROUP BY b.board_id,
|
|
||||||
b.board_name,
|
|
||||||
b.created_at,
|
|
||||||
b.updated_at
|
|
||||||
ORDER BY b.created_at DESC;
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_board_record_query() -> str:
|
|
||||||
"""Gets a query to retrieve a board record."""
|
|
||||||
|
|
||||||
return f"""
|
|
||||||
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
|
||||||
WHERE b.board_id = ?;
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecord(BaseModelExcludeNull):
|
class BoardRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized board record."""
|
"""Deserialized board record."""
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Union, cast
|
from typing import Union, cast
|
||||||
|
|
||||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||||
@ -11,14 +12,102 @@ from invokeai.app.services.board_records.board_records_common import (
|
|||||||
BoardRecordSaveException,
|
BoardRecordSaveException,
|
||||||
UncategorizedImageCounts,
|
UncategorizedImageCounts,
|
||||||
deserialize_board_record,
|
deserialize_board_record,
|
||||||
get_board_record_query,
|
|
||||||
get_list_all_board_records_query,
|
|
||||||
get_paginated_list_board_records_query,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
# This query is missing a GROUP BY clause, which is required for the query to be valid.
|
||||||
|
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
|
||||||
|
SELECT b.board_id,
|
||||||
|
b.board_name,
|
||||||
|
b.created_at,
|
||||||
|
b.updated_at,
|
||||||
|
b.archived,
|
||||||
|
COUNT(
|
||||||
|
CASE
|
||||||
|
WHEN i.image_category in ('general')
|
||||||
|
AND i.is_intermediate = 0 THEN 1
|
||||||
|
END
|
||||||
|
) AS image_count,
|
||||||
|
COUNT(
|
||||||
|
CASE
|
||||||
|
WHEN i.image_category in ('control', 'mask', 'user', 'other')
|
||||||
|
AND i.is_intermediate = 0 THEN 1
|
||||||
|
END
|
||||||
|
) AS asset_count,
|
||||||
|
(
|
||||||
|
SELECT bi.image_name
|
||||||
|
FROM board_images bi
|
||||||
|
JOIN images i ON bi.image_name = i.image_name
|
||||||
|
WHERE bi.board_id = b.board_id
|
||||||
|
AND i.is_intermediate = 0
|
||||||
|
ORDER BY i.created_at DESC
|
||||||
|
LIMIT 1
|
||||||
|
) AS cover_image_name
|
||||||
|
FROM boards b
|
||||||
|
LEFT JOIN board_images bi ON b.board_id = bi.board_id
|
||||||
|
LEFT JOIN images i ON bi.image_name = i.image_name
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PaginatedBoardRecordsQueries:
|
||||||
|
main_query: str
|
||||||
|
total_count_query: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries:
|
||||||
|
"""Gets a query to retrieve a paginated list of board records."""
|
||||||
|
|
||||||
|
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||||
|
|
||||||
|
# The GROUP BY must be added _after_ the WHERE clause!
|
||||||
|
main_query = f"""
|
||||||
|
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||||
|
{archived_condition}
|
||||||
|
GROUP BY b.board_id,
|
||||||
|
b.board_name,
|
||||||
|
b.created_at,
|
||||||
|
b.updated_at
|
||||||
|
ORDER BY b.created_at DESC
|
||||||
|
LIMIT ? OFFSET ?;
|
||||||
|
"""
|
||||||
|
|
||||||
|
total_count_query = f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM boards b
|
||||||
|
{archived_condition};
|
||||||
|
"""
|
||||||
|
|
||||||
|
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
|
||||||
|
|
||||||
|
|
||||||
|
def get_list_all_board_records_query(include_archived: bool) -> str:
|
||||||
|
"""Gets a query to retrieve all board records."""
|
||||||
|
|
||||||
|
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||||
|
|
||||||
|
# The GROUP BY must be added _after_ the WHERE clause!
|
||||||
|
return f"""
|
||||||
|
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||||
|
{archived_condition}
|
||||||
|
GROUP BY b.board_id,
|
||||||
|
b.board_name,
|
||||||
|
b.created_at,
|
||||||
|
b.updated_at
|
||||||
|
ORDER BY b.created_at DESC;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_board_record_query() -> str:
|
||||||
|
"""Gets a query to retrieve a board record."""
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
|
||||||
|
WHERE b.board_id = ?;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
@ -149,7 +238,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
queries = get_paginated_list_board_records_query(include_archived=include_archived)
|
queries = get_paginated_list_board_records_queries(include_archived=include_archived)
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
queries.main_query,
|
queries.main_query,
|
||||||
|
Loading…
Reference in New Issue
Block a user