mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore(item_storage): excise SqliteItemStorage
This commit is contained in:
parent
9f793bdae8
commit
ee6fc4ab1d
@ -1,81 +0,0 @@
|
|||||||
from typing import Generic, Optional, TypeVar, get_args
|
|
||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
||||||
|
|
||||||
from .item_storage_base import ItemStorageABC
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|
||||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self._lock = db.lock
|
|
||||||
self._conn = db.conn
|
|
||||||
self._table_name = table_name
|
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
|
||||||
self._cursor = self._conn.cursor()
|
|
||||||
self._validator: Optional[TypeAdapter[T]] = None
|
|
||||||
|
|
||||||
self._create_table()
|
|
||||||
|
|
||||||
def _create_table(self):
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
|
||||||
item TEXT,
|
|
||||||
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
|
|
||||||
)
|
|
||||||
self._cursor.execute(
|
|
||||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
|
||||||
if self._validator is None:
|
|
||||||
"""
|
|
||||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
|
||||||
we can create it when it is first needed instead.
|
|
||||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
|
||||||
"""
|
|
||||||
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
|
||||||
return self._validator.validate_json(item)
|
|
||||||
|
|
||||||
def set(self, item: T):
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
|
||||||
(item.model_dump_json(warnings=False, exclude_none=True),),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
self._on_changed(item)
|
|
||||||
|
|
||||||
def get(self, id: str) -> Optional[T]:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
|
||||||
result = self._cursor.fetchone()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return self._parse_item(result[0])
|
|
||||||
|
|
||||||
def delete(self, id: str):
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
|
||||||
self._conn.commit()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
self._on_deleted(id)
|
|
@ -2,6 +2,8 @@ import logging
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||||
|
|
||||||
# This import must happen before other invoke imports or test in other files(!!) break
|
# This import must happen before other invoke imports or test in other files(!!) break
|
||||||
from .test_nodes import ( # isort: split
|
from .test_nodes import ( # isort: split
|
||||||
PromptCollectionTestInvocation,
|
PromptCollectionTestInvocation,
|
||||||
@ -19,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
|
|||||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
|
||||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||||
from invokeai.app.services.shared.graph import (
|
from invokeai.app.services.shared.graph import (
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
@ -27,8 +28,6 @@ from invokeai.app.services.shared.graph import (
|
|||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
IterateInvocation,
|
IterateInvocation,
|
||||||
)
|
)
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
|
||||||
|
|
||||||
from .test_invoker import create_edge
|
from .test_invoker import create_edge
|
||||||
|
|
||||||
@ -48,10 +47,8 @@ def simple_graph():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||||
logger = InvokeAILogger.get_logger()
|
|
||||||
db = create_mock_sqlite_database(configuration, logger)
|
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
board_image_records=None, # type: ignore
|
board_image_records=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
|
@ -3,8 +3,7 @@ import logging
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
|
||||||
|
|
||||||
# This import must happen before other invoke imports or test in other files(!!) break
|
# This import must happen before other invoke imports or test in other files(!!) break
|
||||||
from .test_nodes import ( # isort: split
|
from .test_nodes import ( # isort: split
|
||||||
@ -22,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import Memor
|
|||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
|
||||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||||
|
|
||||||
@ -53,11 +51,6 @@ def graph_with_subgraph():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||||
logger = InvokeAILogger.get_logger()
|
|
||||||
db = create_mock_sqlite_database(configuration, logger)
|
|
||||||
|
|
||||||
# NOTE: none of these are actually called by the test invocations
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
board_image_records=None, # type: ignore
|
board_image_records=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
@ -65,7 +58,7 @@ def mock_services() -> InvocationServices:
|
|||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
configuration=configuration,
|
configuration=configuration,
|
||||||
events=TestEventService(),
|
events=TestEventService(),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
|
||||||
image_files=None, # type: ignore
|
image_files=None, # type: ignore
|
||||||
image_records=None, # type: ignore
|
image_records=None, # type: ignore
|
||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
||||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
|
|
||||||
class TestModel(BaseModel):
|
|
||||||
id: str = Field(description="ID")
|
|
||||||
name: str = Field(description="Name")
|
|
||||||
__test__ = False # not a pytest test case
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db() -> SqliteItemStorage[TestModel]:
|
|
||||||
config = InvokeAIAppConfig(use_memory_db=True)
|
|
||||||
logger = InvokeAILogger.get_logger()
|
|
||||||
db_path = None if config.use_memory_db else config.db_path
|
|
||||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
|
||||||
sqlite_item_storage = SqliteItemStorage[TestModel](db=db, table_name="test", id_field="id")
|
|
||||||
return sqlite_item_storage
|
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
|
||||||
assert db.get("1") == TestModel(id="1", name="Test")
|
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
|
||||||
db.delete("1")
|
|
||||||
assert db.get("1") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
|
|
||||||
called = False
|
|
||||||
|
|
||||||
def on_changed(item: TestModel):
|
|
||||||
nonlocal called
|
|
||||||
called = True
|
|
||||||
|
|
||||||
db.on_changed(on_changed)
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
|
||||||
assert called
|
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
|
|
||||||
called = False
|
|
||||||
|
|
||||||
def on_deleted(item_id: str):
|
|
||||||
nonlocal called
|
|
||||||
called = True
|
|
||||||
|
|
||||||
db.on_deleted(on_deleted)
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
|
||||||
db.delete("1")
|
|
||||||
assert called
|
|
Loading…
x
Reference in New Issue
Block a user