From ee6fc4ab1d6276200a1a0e3dacda4a5af558d1ab Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 2 Feb 2024 08:57:32 +1100 Subject: [PATCH] chore(item_storage): excise SqliteItemStorage --- .../item_storage/item_storage_sqlite.py | 81 ------------------- tests/aa_nodes/test_graph_execution_state.py | 9 +-- tests/aa_nodes/test_invoker.py | 11 +-- tests/aa_nodes/test_sqlite.py | 59 -------------- 4 files changed, 5 insertions(+), 155 deletions(-) delete mode 100644 invokeai/app/services/item_storage/item_storage_sqlite.py delete mode 100644 tests/aa_nodes/test_sqlite.py diff --git a/invokeai/app/services/item_storage/item_storage_sqlite.py b/invokeai/app/services/item_storage/item_storage_sqlite.py deleted file mode 100644 index 114b1d9274..0000000000 --- a/invokeai/app/services/item_storage/item_storage_sqlite.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Generic, Optional, TypeVar, get_args - -from pydantic import BaseModel, TypeAdapter - -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase - -from .item_storage_base import ItemStorageABC - -T = TypeVar("T", bound=BaseModel) - - -class SqliteItemStorage(ItemStorageABC, Generic[T]): - def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"): - super().__init__() - - self._lock = db.lock - self._conn = db.conn - self._table_name = table_name - self._id_field = id_field # TODO: validate that T has this field - self._cursor = self._conn.cursor() - self._validator: Optional[TypeAdapter[T]] = None - - self._create_table() - - def _create_table(self): - try: - self._lock.acquire() - self._cursor.execute( - f"""CREATE TABLE IF NOT EXISTS {self._table_name} ( - item TEXT, - id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);""" - ) - self._cursor.execute( - f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);""" - ) - finally: - self._lock.release() - - def _parse_item(self, item: str) -> T: - if self._validator is None: - """ - We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so - we can create it when it is first needed instead. - __orig_class__ is technically an implementation detail of the typing module, not a supported API - """ - self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined] - return self._validator.validate_json(item) - - def set(self, item: T): - try: - self._lock.acquire() - self._cursor.execute( - f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", - (item.model_dump_json(warnings=False, exclude_none=True),), - ) - self._conn.commit() - finally: - self._lock.release() - self._on_changed(item) - - def get(self, id: str) -> Optional[T]: - try: - self._lock.acquire() - self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) - result = self._cursor.fetchone() - finally: - self._lock.release() - - if not result: - return None - - return self._parse_item(result[0]) - - def delete(self, id: str): - try: - self._lock.acquire() - self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)) - self._conn.commit() - finally: - self._lock.release() - self._on_deleted(id) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index bb31161426..fab1fa4598 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -2,6 +2,8 @@ import logging import pytest +from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory + # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split PromptCollectionTestInvocation, @@ -19,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import ( CollectInvocation, @@ -27,8 +28,6 @@ from invokeai.app.services.shared.graph import ( GraphExecutionState, IterateInvocation, ) -from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.sqlite_database import create_mock_sqlite_database from .test_invoker import create_edge @@ -48,10 +47,8 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(configuration, logger) # NOTE: none of these are actually called by the test invocations - graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") + graph_execution_manager = ItemStorageMemory[GraphExecutionState]() return InvocationServices( board_image_records=None, # type: ignore board_images=None, # type: ignore diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index d4959282a1..2ae4eab58a 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -3,8 +3,7 @@ import logging import pytest from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.sqlite_database import create_mock_sqlite_database +from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split @@ -22,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import Memor from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation @@ -53,11 +51,6 @@ def graph_with_subgraph(): @pytest.fixture def mock_services() -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(configuration, logger) - - # NOTE: none of these are actually called by the test invocations - graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") return InvocationServices( board_image_records=None, # type: ignore board_images=None, # type: ignore @@ -65,7 +58,7 @@ def mock_services() -> InvocationServices: boards=None, # type: ignore configuration=configuration, events=TestEventService(), - graph_execution_manager=graph_execution_manager, + graph_execution_manager=ItemStorageMemory[GraphExecutionState](), image_files=None, # type: ignore image_records=None, # type: ignore images=None, # type: ignore diff --git a/tests/aa_nodes/test_sqlite.py b/tests/aa_nodes/test_sqlite.py deleted file mode 100644 index e61657e3cd..0000000000 --- a/tests/aa_nodes/test_sqlite.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -from pydantic import BaseModel, Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage -from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.backend.util.logging import InvokeAILogger - - -class TestModel(BaseModel): - id: str = Field(description="ID") - name: str = Field(description="Name") - __test__ = False # not a pytest test case - - -@pytest.fixture -def db() -> SqliteItemStorage[TestModel]: - config = InvokeAIAppConfig(use_memory_db=True) - logger = InvokeAILogger.get_logger() - db_path = None if config.use_memory_db else config.db_path - db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql) - sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id") - return sqlite_item_storage - - -def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]): - db.set(TestModel(id="1", name="Test")) - assert db.get("1") == TestModel(id="1", name="Test") - - -def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]): - db.set(TestModel(id="1", name="Test")) - db.delete("1") - assert db.get("1") is None - - -def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]): - called = False - - def on_changed(item: TestModel): - nonlocal called - called = True - - db.on_changed(on_changed) - db.set(TestModel(id="1", name="Test")) - assert called - - -def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]): - called = False - - def on_deleted(item_id: str): - nonlocal called - called = True - - db.on_deleted(on_deleted) - db.set(TestModel(id="1", name="Test")) - db.delete("1") - assert called