mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
131 lines
4.2 KiB
Python
131 lines
4.2 KiB
Python
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|
from pydantic import BaseModel, Field
|
|
|
|
import pytest
|
|
import sqlite3
|
|
|
|
class TestModel(BaseModel):
    # Minimal pydantic model used as the stored item type in these tests.
    # Both fields are required strings; "id" is the field name handed to the
    # storage fixture as the item identifier.
    id: str = Field(..., description="ID")
    name: str = Field(..., description="Name")
|
|
|
|
|
|
@pytest.fixture
def db(request) -> SqliteItemStorage[TestModel]:
    """Provide a fresh ``SqliteItemStorage`` for ``TestModel`` items.

    A new connection to ``sqlite_memory`` is opened for every test and passed
    to the storage together with the table name ``"test"`` and the id field
    ``"id"``.

    Args:
        request: pytest's built-in ``request`` fixture, used only to register
            teardown of the connection.

    Returns:
        The storage instance backed by the newly opened connection.
    """
    conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
    # Close the connection when the test finishes so connections do not
    # accumulate across the test session.
    request.addfinalizer(conn.close)
    return SqliteItemStorage[TestModel](conn, "test", "id")
|
|
|
|
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
    """An item written with ``set`` is returned by ``get`` for the same id."""
    stored = TestModel(id="1", name="Test")
    db.set(stored)

    fetched = db.get("1")

    assert fetched == TestModel(id="1", name="Test")
|
|
|
|
|
|
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
    """``list()`` without arguments reports all three stored items.

    The expected paging metadata is page 0 of 1, 10 per page, total 3.
    """
    ids = ("1", "2", "3")
    for item_id in ids:
        db.set(TestModel(id=item_id, name="Test"))

    listing = db.list()

    expected_items = [TestModel(id=item_id, name="Test") for item_id in ids]
    assert listing.page == 0
    assert listing.pages == 1
    assert listing.per_page == 10
    assert listing.total == 3
    assert listing.items == expected_items
|
|
|
|
|
|
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
    """After ``delete``, ``get`` for the same id yields ``None``."""
    item_id = "1"
    db.set(TestModel(id=item_id, name="Test"))

    db.delete(item_id)

    assert db.get(item_id) is None
|
|
|
|
|
|
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
    """A callback registered through ``on_changed`` runs when an item is set."""
    # A mutable recorder avoids the need for ``nonlocal`` in the callback.
    invocations = []

    def record_change(item: TestModel):
        invocations.append(True)

    db.on_changed(record_change)
    db.set(TestModel(id="1", name="Test"))

    assert invocations
|
|
|
|
|
|
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
    """A callback registered through ``on_deleted`` runs when an item is deleted."""
    # A mutable recorder avoids the need for ``nonlocal`` in the callback.
    invocations = []

    def record_delete(item_id: str):
        invocations.append(True)

    db.on_deleted(record_delete)
    db.set(TestModel(id="1", name="Test"))
    db.delete("1")

    assert invocations
|
|
|
|
|
|
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
    """``list(page=0, per_page=2)`` over three items yields the first two.

    The expected paging metadata is page 0 of 2, 2 per page, total 3.
    """
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    first_page = db.list(page=0, per_page=2)

    expected_items = [
        TestModel(id="1", name="Test"),
        TestModel(id="2", name="Test"),
    ]
    assert first_page.page == 0
    assert first_page.pages == 2
    assert first_page.per_page == 2
    assert first_page.total == 3
    assert first_page.items == expected_items
|
|
|
|
|
|
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
    """``list(page=1, per_page=2)`` over three items yields only the third.

    The expected paging metadata is page 1 of 2, 2 per page, total 3.
    """
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    second_page = db.list(page=1, per_page=2)

    expected_items = [TestModel(id="3", name="Test")]
    assert second_page.page == 1
    assert second_page.pages == 2
    assert second_page.per_page == 2
    assert second_page.total == 3
    assert second_page.items == expected_items
|
|
|
|
|
|
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
    """``search(query="Test")`` without paging arguments finds all three items.

    The expected paging metadata is page 0 of 1, 10 per page, total 3.
    """
    ids = ("1", "2", "3")
    for item_id in ids:
        db.set(TestModel(id=item_id, name="Test"))

    found = db.search(query="Test")

    expected_items = [TestModel(id=item_id, name="Test") for item_id in ids]
    assert found.page == 0
    assert found.pages == 1
    assert found.per_page == 10
    assert found.total == 3
    assert found.items == expected_items
|
|
|
|
|
|
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
    """``search(query="Test", page=0, per_page=2)`` yields the first two matches.

    The expected paging metadata is page 0 of 2, 2 per page, total 3.
    """
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    first_page = db.search(query="Test", page=0, per_page=2)

    expected_items = [
        TestModel(id="1", name="Test"),
        TestModel(id="2", name="Test"),
    ]
    assert first_page.page == 0
    assert first_page.pages == 2
    assert first_page.per_page == 2
    assert first_page.total == 3
    assert first_page.items == expected_items
|
|
|
|
|
|
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
    """``search(query="Test", page=1, per_page=2)`` yields only the third match.

    The expected paging metadata is page 1 of 2, 2 per page, total 3.
    """
    for item_id in ("1", "2", "3"):
        db.set(TestModel(id=item_id, name="Test"))

    second_page = db.search(query="Test", page=1, per_page=2)

    expected_items = [TestModel(id="3", name="Test")]
    assert second_page.page == 1
    assert second_page.pages == 2
    assert second_page.per_page == 2
    assert second_page.total == 3
    assert second_page.items == expected_items
|