InvokeAI/ldm/invoke/app/services/sqlite.py
Kyle Schouviller 34e3aa1f88 parent 9eed1919c2
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800
committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800

Adding base node architecture

Fix type annotation errors

Runs and generates, but breaks in saving session

Fix default model value setting. Fix deprecation warning.

Fixed node api

Adding markdown docs

Simplifying Generate construction in apps

[nodes] A few minor changes (#2510)

* Pin api-related requirements

* Remove confusing extra CORS origins list

* Adds response models for HTTP 200

[nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest.

Minor typing fixes

[nodes] Fix some small output query hookups

[node] Fixing some additional typing issues

[nodes] Move and expand graph code. Add base item storage and sqlite implementation.

Update startup to match new code

[nodes] Add callbacks to item storage

[nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility

[nodes] New execution model that handles iteration

[nodes] Fixing the CLI

[nodes] Adding a note to the CLI

[nodes] Split processing thread into separate service

[node] Add error message on node processing failure

Removing old files and duplicated packages

Adding python-multipart
2023-02-24 18:57:02 -08:00

120 lines
4.0 KiB
Python

import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar('T', bound=BaseModel)
sqlite_memory = ':memory:'
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
try:
self._lock.acquire()
self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return self._parse_item(result[0])
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[T](
items = items,
page = page,
pages = pageCount,
per_page = per_page,
total = count
)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[T](
items = items,
page = page,
pages = pageCount,
per_page = per_page,
total = count
)