chore(item_storage): excise SqliteItemStorage

This commit is contained in:
psychedelicious 2024-02-02 08:57:32 +11:00
parent 9f793bdae8
commit ee6fc4ab1d
4 changed files with 5 additions and 155 deletions

View File

@ -1,81 +0,0 @@
from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, TypeAdapter
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .item_storage_base import ItemStorageABC
T = TypeVar("T", bound=BaseModel)
class SqliteItemStorage(ItemStorageABC, Generic[T]):
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._cursor = self._conn.cursor()
self._validator: Optional[TypeAdapter[T]] = None
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
if self._validator is None:
"""
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
we can create it when it is first needed instead.
__orig_class__ is technically an implementation detail of the typing module, not a supported API
"""
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
return self._validator.validate_json(item)
def set(self, item: T):
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.model_dump_json(warnings=False, exclude_none=True),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Optional[T]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return self._parse_item(result[0])
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)

View File

@ -2,6 +2,8 @@ import logging
import pytest
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
PromptCollectionTestInvocation,
@ -19,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import (
CollectInvocation,
@ -27,8 +28,6 @@ from invokeai.app.services.shared.graph import (
GraphExecutionState,
IterateInvocation,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from .test_invoker import create_edge
@ -48,10 +47,8 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore

View File

@ -3,8 +3,7 @@ import logging
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
@ -22,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import Memor
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
@ -53,11 +51,6 @@ def graph_with_subgraph():
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
@ -65,7 +58,7 @@ def mock_services() -> InvocationServices:
boards=None, # type: ignore
configuration=configuration,
events=TestEventService(),
graph_execution_manager=graph_execution_manager,
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore

View File

@ -1,59 +0,0 @@
import pytest
from pydantic import BaseModel, Field
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
class TestModel(BaseModel):
id: str = Field(description="ID")
name: str = Field(description="Name")
__test__ = False # not a pytest test case
@pytest.fixture
def db() -> SqliteItemStorage[TestModel]:
config = InvokeAIAppConfig(use_memory_db=True)
logger = InvokeAILogger.get_logger()
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
return sqlite_item_storage
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
assert db.get("1") == TestModel(id="1", name="Test")
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.delete("1")
assert db.get("1") is None
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
called = False
def on_changed(item: TestModel):
nonlocal called
called = True
db.on_changed(on_changed)
db.set(TestModel(id="1", name="Test"))
assert called
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
called = False
def on_deleted(item_id: str):
nonlocal called
called = True
db.on_deleted(on_deleted)
db.set(TestModel(id="1", name="Test"))
db.delete("1")
assert called