mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore(item_storage): excise SqliteItemStorage
This commit is contained in:
parent
9f793bdae8
commit
ee6fc4ab1d
@ -1,81 +0,0 @@
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .item_storage_base import ItemStorageABC
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._cursor = self._conn.cursor()
|
||||
self._validator: Optional[TypeAdapter[T]] = None
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
if self._validator is None:
|
||||
"""
|
||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||
we can create it when it is first needed instead.
|
||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||
"""
|
||||
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||
return self._validator.validate_json(item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.model_dump_json(warnings=False, exclude_none=True),),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
@ -2,6 +2,8 @@ import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
PromptCollectionTestInvocation,
|
||||
@ -19,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
@ -27,8 +28,6 @@ from invokeai.app.services.shared.graph import (
|
||||
GraphExecutionState,
|
||||
IterateInvocation,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
from .test_invoker import create_edge
|
||||
|
||||
@ -48,10 +47,8 @@ def simple_graph():
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(configuration, logger)
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||
return InvocationServices(
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
|
@ -3,8 +3,7 @@ import logging
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
@ -22,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import Memor
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
|
||||
@ -53,11 +51,6 @@ def graph_with_subgraph():
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(configuration, logger)
|
||||
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
return InvocationServices(
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
@ -65,7 +58,7 @@ def mock_services() -> InvocationServices:
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
|
@ -1,59 +0,0 @@
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
id: str = Field(description="ID")
|
||||
name: str = Field(description="Name")
|
||||
__test__ = False # not a pytest test case
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db() -> SqliteItemStorage[TestModel]:
|
||||
config = InvokeAIAppConfig(use_memory_db=True)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db_path = None if config.use_memory_db else config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
|
||||
return sqlite_item_storage
|
||||
|
||||
|
||||
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
assert db.get("1") == TestModel(id="1", name="Test")
|
||||
|
||||
|
||||
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.delete("1")
|
||||
assert db.get("1") is None
|
||||
|
||||
|
||||
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
|
||||
called = False
|
||||
|
||||
def on_changed(item: TestModel):
|
||||
nonlocal called
|
||||
called = True
|
||||
|
||||
db.on_changed(on_changed)
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
assert called
|
||||
|
||||
|
||||
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
|
||||
called = False
|
||||
|
||||
def on_deleted(item_id: str):
|
||||
nonlocal called
|
||||
called = True
|
||||
|
||||
db.on_deleted(on_deleted)
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.delete("1")
|
||||
assert called
|
Loading…
Reference in New Issue
Block a user