mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
120 lines
4.0 KiB
Python
120 lines
4.0 KiB
Python
|
import sqlite3
|
||
|
from threading import Lock
|
||
|
from typing import Generic, TypeVar, Union, get_args
|
||
|
from pydantic import BaseModel, parse_raw_as
|
||
|
from .item_storage import ItemStorageABC, PaginatedResults
|
||
|
|
||
|
# Model type stored by a SqliteItemStorage; must be a pydantic model.
T = TypeVar("T", bound=BaseModel)

# SQLite's special filename for a private in-memory database; pass it as
# `filename` to avoid touching disk.
sqlite_memory = ":memory:"
|
||
|
|
||
|
class SqliteItemStorage(ItemStorageABC, Generic[T]):
    """Store pydantic models of type ``T`` as JSON rows in a SQLite table.

    Each row keeps the model's JSON in an ``item`` column. A virtual generated
    ``id`` column extracts the field named by ``id_field`` from that JSON and
    is covered by a unique index, so ``set`` replaces any existing row that
    has the same id.

    Instances must be created with a concrete type argument, e.g.
    ``SqliteItemStorage[MyModel](...)``, because ``_parse_item`` reads the
    model class from ``self.__orig_class__``.

    One connection and one cursor are shared by every method; every use of
    them is serialized with ``self._lock``.
    """

    _filename: str
    _table_name: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
        """Open (or create) the database and make sure the items table exists.

        Args:
            filename: SQLite database path, or ``':memory:'`` for an in-memory
                database.
            table_name: Table holding the items. It is interpolated directly
                into SQL, so it must come from trusted code, never user input.
            id_field: Name of the model field used as the unique row id.
        """
        super().__init__()

        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field  # TODO: validate that T has this field
        self._lock = Lock()

        # check_same_thread=False lets threads other than the creating one use
        # this connection; self._lock serializes all such use.
        self._conn = sqlite3.connect(self._filename, check_same_thread=False)  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()

        self._create_table()

    def _create_table(self):
        """Create the items table and its unique ``id`` index if they are missing."""
        with self._lock:
            self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL
            );''')
            self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
            self._conn.commit()

    def _parse_item(self, item: str) -> T:
        """Deserialize one stored JSON string into the model type ``T``.

        Raises:
            AttributeError: if the instance was created without a type
                argument, so ``self.__orig_class__`` does not exist.
        """
        item_type = get_args(self.__orig_class__)[0]
        return parse_raw_as(item_type, item)

    def set(self, item: T):
        """Insert ``item``, replacing any stored item with the same id.

        Calls the inherited ``_on_changed`` hook after the write is committed.
        """
        with self._lock:
            self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
            # The sqlite3 module implicitly opens a transaction before
            # INSERT/REPLACE; without an explicit commit the write is not
            # persisted and is rolled back when the connection closes.
            self._conn.commit()
        self._on_changed(item)

    def get(self, id: str) -> Union[T, None]:
        """Return the item whose id equals ``id``, or ``None`` if none is stored."""
        with self._lock:
            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
            row = self._cursor.fetchone()

        if not row:
            return None
        return self._parse_item(row[0])

    def delete(self, id: str):
        """Delete the item whose id equals ``id`` (a no-op if none is stored).

        Calls the inherited ``_on_deleted`` hook after the delete is committed.
        """
        with self._lock:
            self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
            # DELETE also opens an implicit transaction; commit to persist it.
            self._conn.commit()
        self._on_deleted(id)

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return page ``page`` (zero-based) of all stored items."""
        return self._fetch_page('', (), page, per_page)

    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return page ``page`` of items whose stored JSON contains ``query``.

        ``%`` and ``_`` in ``query`` are matched literally rather than being
        treated as LIKE wildcards.
        """
        return self._fetch_page(
            " WHERE item LIKE ? ESCAPE '\\'",
            (self._like_pattern(query),),
            page,
            per_page,
        )

    def _fetch_page(self, where: str, params: tuple, page: int, per_page: int) -> PaginatedResults[T]:
        """Run a paginated SELECT plus a matching count(*) and wrap the results.

        Args:
            where: SQL suffix applied to both queries (``''`` or ``' WHERE ...'``).
            params: Bound parameters for the placeholders in ``where``.
            page: Zero-based page index.
            per_page: Page size; also used as the LIMIT.
        """
        with self._lock:
            self._cursor.execute(
                f'''SELECT item FROM {self._table_name}{where} LIMIT ? OFFSET ?;''',
                (*params, per_page, page * per_page),
            )
            rows = self._cursor.fetchall()
            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name}{where};''', params)
            count = self._cursor.fetchone()[0]

        # Deserialization does not touch the connection, so do it unlocked.
        items = [self._parse_item(row[0]) for row in rows]

        return PaginatedResults[T](
            items=items,
            page=page,
            pages=self._page_count(count, per_page),
            per_page=per_page,
            total=count,
        )

    @staticmethod
    def _page_count(count: int, per_page: int) -> int:
        """Return how many pages of ``per_page`` items hold ``count`` items.

        Uses ceiling division, so an exact multiple of ``per_page`` no longer
        yields an extra empty trailing page (the old ``int(count / per_page) + 1``
        did). An empty result still reports one page, as before.
        """
        return max(1, -(-count // per_page))

    @staticmethod
    def _like_pattern(query: str) -> str:
        """Build a LIKE pattern matching ``query`` anywhere, with wildcards escaped.

        Backslash is the escape character; the SQL pairs this pattern with an
        ``ESCAPE`` clause naming the backslash.
        """
        escaped = query.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
        return f'%{escaped}%'
|