InvokeAI/tests/nodes/test_sqlite.py
2023-03-03 00:02:15 -05:00

113 lines
4.4 KiB
Python

from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field
class TestModel(BaseModel):
id: str = Field(description = "ID")
name: str = Field(description = "Name")
def test_sqlite_service_can_create_and_get():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
assert db.get('1') == TestModel(id = '1', name = 'Test')
def test_sqlite_service_can_list():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list()
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_delete():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.delete('1')
assert db.get('1') is None
def test_sqlite_service_calls_set_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
called = False
def on_changed(item: TestModel):
nonlocal called
called = True
db.on_changed(on_changed)
db.set(TestModel(id = '1', name = 'Test'))
assert called
def test_sqlite_service_calls_delete_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
called = False
def on_deleted(item_id: str):
nonlocal called
called = True
db.on_deleted(on_deleted)
db.set(TestModel(id = '1', name = 'Test'))
db.delete('1')
assert called
def test_sqlite_service_can_list_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list(page = 0, per_page = 2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')]
def test_sqlite_service_can_list_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list(page = 1, per_page = 2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_search():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test')
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_search_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test', page = 0, per_page = 2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')]
def test_sqlite_service_can_search_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test', page = 1, per_page = 2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '3', name = 'Test')]