InvokeAI/invokeai/app/services/sqlite.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

124 lines
3.9 KiB
Python
Raw Normal View History

import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
2023-03-03 06:02:00 +00:00
from pydantic import BaseModel, parse_raw_as
2023-03-03 06:02:00 +00:00
from .item_storage import ItemStorageABC, PaginatedResults
2023-03-03 06:02:00 +00:00
# Special filename that makes sqlite3.connect open a private in-memory database.
sqlite_memory = ":memory:"

# Item type stored by SqliteItemStorage; restricted to pydantic models.
T = TypeVar("T", bound=BaseModel)
class SqliteItemStorage(ItemStorageABC, Generic[T]):
    """Persist pydantic models of type ``T`` as JSON rows in one SQLite table.

    Each row stores the model's JSON in an ``item`` column. A virtual ``id``
    column is generated from it with ``json_extract`` on ``$.<id_field>`` and
    carries a UNIQUE index, so ``set`` upserts by id.

    Instantiate with an explicit type argument
    (``SqliteItemStorage[MyModel](...)``): ``_parse_item`` reads the concrete
    model type from ``self.__orig_class__``, which only exists in that case.

    A single connection and cursor are shared across threads
    (``check_same_thread=False``); every database access is serialized by
    ``self._lock``.
    """

    _filename: str
    _table_name: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
        """Open the database and make sure the backing table exists.

        Args:
            filename: Path handed to ``sqlite3.connect``; ``":memory:"``
                gives a private in-memory database.
            table_name: Table to use. It is interpolated directly into SQL
                text, so it must be a trusted, valid SQLite identifier.
            id_field: Field of ``T`` (JSON path ``$.<id_field>``) whose value
                becomes the unique row id.
        """
        super().__init__()
        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field  # TODO: validate that T has this field
        self._lock = Lock()

        # TODO: figure out a better threading solution
        self._conn = sqlite3.connect(self._filename, check_same_thread=False)
        self._cursor = self._conn.cursor()
        self._create_table()

    def _create_table(self):
        """Create the table and its unique id index if they do not exist yet.

        ``id`` is declared NOT NULL, so storing an item whose JSON lacks
        ``id_field`` fails with a constraint error.
        """
        with self._lock:
            self._cursor.execute(
                f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
            )
            self._cursor.execute(
                f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
            )
            self._conn.commit()

    def _parse_item(self, item: str) -> T:
        """Deserialize one stored JSON string into the concrete model type."""
        item_type = get_args(self.__orig_class__)[0]
        return parse_raw_as(item_type, item)

    @staticmethod
    def _page_count(count: int, per_page: int) -> int:
        """Return how many pages of ``per_page`` items are needed for ``count``.

        Uses ceiling division, so an exact multiple does not add an empty
        trailing page. The result is never below 1, so page 0 always exists,
        even for an empty result set.
        """
        return max(1, -(-count // per_page))

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE wildcards so ``text`` matches literally.

        Backslash is the escape character: it is doubled first, then ``%``
        and ``_`` are prefixed with it. Use together with ``ESCAPE '\\'``.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def set(self, item: T):
        """Insert ``item``, replacing any stored item with the same id.

        The change hook (presumably provided by ``ItemStorageABC``) runs after
        the lock is released. ``Lock`` is not reentrant, so a listener that
        calls back into this storage cannot deadlock.
        """
        with self._lock:
            self._cursor.execute(
                f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
                (item.json(),),
            )
            self._conn.commit()
        self._on_changed(item)

    def get(self, id: str) -> Union[T, None]:
        """Return the item whose id equals ``str(id)``, or ``None`` if absent."""
        with self._lock:
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            row = self._cursor.fetchone()
        if not row:
            return None
        return self._parse_item(row[0])

    def delete(self, id: str):
        """Delete the item whose id equals ``str(id)`` (no-op if absent).

        The delete hook is invoked after the lock is released, whether or not
        a row was actually removed.
        """
        with self._lock:
            self._cursor.execute(
                f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            self._conn.commit()
        self._on_deleted(id)

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return page ``page`` (0-based) of all stored items.

        Rows are ordered by ``rowid``. Without an ORDER BY, SQLite does not
        guarantee row order, so LIMIT/OFFSET pages could overlap or skip items.
        """
        with self._lock:
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} ORDER BY rowid LIMIT ? OFFSET ?;""",
                (per_page, page * per_page),
            )
            rows = self._cursor.fetchall()
            self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
            count = self._cursor.fetchone()[0]

        # Parsing is pure CPU work; it does not need to hold the lock.
        items = [self._parse_item(row[0]) for row in rows]
        return PaginatedResults[T](
            items=items,
            page=page,
            pages=self._page_count(count, per_page),
            per_page=per_page,
            total=count,
        )

    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        """Return page ``page`` of items whose stored JSON contains ``query``.

        ``query`` is matched as a literal substring of the JSON text: any
        ``%``, ``_`` or backslash in it is escaped rather than treated as a
        LIKE wildcard. SQLite's LIKE is case-insensitive for ASCII letters
        only. Results are ordered by ``rowid`` for stable paging.
        """
        pattern = f"%{self._escape_like(query)}%"
        with self._lock:
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name}
                WHERE item LIKE ? ESCAPE '\\'
                ORDER BY rowid LIMIT ? OFFSET ?;""",
                (pattern, per_page, page * per_page),
            )
            rows = self._cursor.fetchall()
            self._cursor.execute(
                f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ? ESCAPE '\\';""",
                (pattern,),
            )
            count = self._cursor.fetchone()[0]

        items = [self._parse_item(row[0]) for row in rows]
        return PaginatedResults[T](
            items=items,
            page=page,
            pages=self._page_count(count, per_page),
            per_page=per_page,
            total=count,
        )