mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
402cf9b0ee
Refactor services folder/module structure.

**Motivation**

While working on our services I've repeatedly encountered circular imports and a general lack of clarity regarding where to put things. The structure introduced goes a long way towards resolving those issues, setting us up for a clean structure going forward.

**Services**

Services are now in their own folder with a few files:

- `services/{service_name}/__init__.py`: init as needed, mostly empty now
- `services/{service_name}/{service_name}_base.py`: the base class for the service
- `services/{service_name}/{service_name}_{impl_type}.py`: the default concrete implementation of the service - typically one of `sqlite`, `default`, or `memory`
- `services/{service_name}/{service_name}_common.py`: any common items - models, exceptions, utilities, etc.

Though it's a bit verbose to have the service name both as the folder name and the prefix for files, I found it is _extremely_ confusing to have all of the base classes just be named `base.py`. So, at the cost of some verbosity when importing things, I've included the service name in the filename.

There are some minor logic changes. For example, in `InvocationProcessor`, instead of assigning the model manager service to a variable to be used later in the file, the service is used directly via the `Invoker`.

**Shared**

Things that are used across disparate services are in `services/shared/`:

- `default_graphs.py`: previously in `services/`
- `graphs.py`: previously in `services/`
- `pagination`: generic pagination models used in a few services
- `sqlite`: the `SqliteDatabase` class, other sqlite-specific things
135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
import pytest
|
|
from pydantic import BaseModel, Field
|
|
|
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
|
|
|
|
class TestModel(BaseModel):
    """Minimal pydantic model used as the stored item type for ``SqliteItemStorage``.

    Despite its ``Test`` prefix this is a data model, not a test class. Setting
    ``__test__ = False`` stops pytest from trying to collect it. Otherwise pytest
    emits a PytestCollectionWarning, because the class defines ``__init__``.
    """

    # pytest collects classes whose names start with "Test" unless they opt out.
    # Pydantic does not treat dunder names as model fields, so this stays a plain
    # class attribute.
    __test__ = False

    id: str = Field(description="ID")  # used as the storage key (id_field="id" in the fixture)
    name: str = Field(description="Name")
|
|
|
|
|
|
@pytest.fixture
def db() -> SqliteItemStorage[TestModel]:
    """Build a ``SqliteItemStorage`` of ``TestModel`` items in table ``"test"``.

    The backing ``SqliteDatabase`` is configured with ``use_memory_db=True``.
    The fixture is function-scoped, so each test gets its own instance.
    NOTE(review): this presumably also means a fresh, empty database per test.
    Confirm against ``SqliteDatabase``.
    """
    app_config = InvokeAIAppConfig(use_memory_db=True)
    logger = InvokeAILogger.get_logger()
    database = SqliteDatabase(app_config, logger)
    storage = SqliteItemStorage[TestModel](db=database, table_name="test", id_field="id")
    return storage
|
|
|
|
|
|
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
    """An item written with ``set`` is returned, equal to the original, by ``get``."""
    item_id = "1"
    db.set(TestModel(id=item_id, name="Test"))

    retrieved = db.get(item_id)

    assert retrieved == TestModel(id=item_id, name="Test")
|
|
|
|
|
|
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
    """``list()`` with default arguments returns all three stored items on page 0.

    The expected defaults (one page, ``per_page`` of 10) are pinned by the asserts below.
    """
    stored_ids = ("1", "2", "3")
    for item_id in stored_ids:
        db.set(TestModel(id=item_id, name="Test"))

    results = db.list()

    assert results.page == 0
    assert results.pages == 1
    assert results.per_page == 10
    assert results.total == 3
    assert results.items == [TestModel(id=item_id, name="Test") for item_id in stored_ids]
|
|
|
|
|
|
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
    """After ``delete``, ``get`` returns ``None`` for an id that previously existed."""
    db.set(TestModel(id="1", name="Test"))

    # Precondition: without this, the test would pass vacuously if ``set`` had
    # failed to persist the item in the first place.
    assert db.get("1") is not None

    db.delete("1")

    assert db.get("1") is None
|
|
|
|
|
|
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
    """``set`` invokes registered ``on_changed`` callbacks with the stored item."""
    received: list[TestModel] = []

    def on_changed(item: TestModel):
        received.append(item)

    db.on_changed(on_changed)
    db.set(TestModel(id="1", name="Test"))

    # Check the payload and the call count, not merely that the callback fired.
    assert received == [TestModel(id="1", name="Test")]
|
|
|
|
|
|
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
    """``delete`` invokes registered ``on_deleted`` callbacks with the removed id."""
    deleted_ids: list[str] = []

    def on_deleted(item_id: str):
        deleted_ids.append(item_id)

    db.on_deleted(on_deleted)
    db.set(TestModel(id="1", name="Test"))
    db.delete("1")

    # Exactly one notification, carrying the deleted id. This also confirms
    # that the preceding ``set`` did not fire the delete callback.
    assert deleted_ids == ["1"]
|
|
|
|
|
|
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
    """Page 0 with ``per_page=2`` holds the first two of three items; two pages exist."""
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    first_page = db.list(page=0, per_page=2)

    assert first_page.page == 0
    assert first_page.pages == 2
    assert first_page.per_page == 2
    assert first_page.total == 3
    assert first_page.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
|
|
|
|
|
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
    """Page 1 with ``per_page=2`` holds only the third of three stored items."""
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    second_page = db.list(page=1, per_page=2)

    assert second_page.page == 1
    assert second_page.pages == 2
    assert second_page.per_page == 2
    assert second_page.total == 3
    assert second_page.items == [TestModel(id="3", name="Test")]
|
|
|
|
|
|
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
    """Searching for text present in every item returns all three on one page of 10."""
    stored_ids = ("1", "2", "3")
    for item_id in stored_ids:
        db.set(TestModel(id=item_id, name="Test"))

    results = db.search(query="Test")

    assert results.page == 0
    assert results.pages == 1
    assert results.per_page == 10
    assert results.total == 3
    assert results.items == [TestModel(id=item_id, name="Test") for item_id in stored_ids]
|
|
|
|
|
|
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
    """Search page 0 with ``per_page=2`` returns the first two of three matches."""
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    first_page = db.search(query="Test", page=0, per_page=2)

    assert first_page.page == 0
    assert first_page.pages == 2
    assert first_page.per_page == 2
    assert first_page.total == 3
    assert first_page.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
|
|
|
|
|
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
    """Search page 1 with ``per_page=2`` returns only the third of three matches."""
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    second_page = db.search(query="Test", page=1, per_page=2)

    assert second_page.page == 1
    assert second_page.pages == 2
    assert second_page.per_page == 2
    assert second_page.total == 3
    assert second_page.items == [TestModel(id="3", name="Test")]
|