mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
88c08bbfc7
- `ItemStorageMemory.get` now throws an `ItemNotFoundError` when the requested `item_id` is not found. - Update docstrings in ABC and tests. The new memory item storage implementation implemented the `get` method incorrectly, by returning `None` if the item didn't exist. The ABC typed `get` as returning `T`, while the SQLite implementation typed `get` as returning `Optional[T]`. The SQLite implementation was referenced when writing the memory implementation. This mismatched typing is a violation of the Liskov substitution principle, because the return type of `get` in the implementation is wider than the abstract class's definition. Using `pyright` in strict mode catches this. In `invocation_stats_default`, this introduced an error. The `_prune_stats` method calls `get`, expecting the method to throw if the item is not found. If the graph is no longer stored in the bounded item storage, we will call `is_complete()` on `None`, causing the error. Note: This error condition never arose in the SQLite implementation because it parsed the item with pydantic before returning it, which would throw if the item was not found. It implicitly threw, while the memory implementation did not.
112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
import re
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
|
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
|
|
|
|
|
# Minimal pydantic model used as the stored item type throughout these tests.
class MockItemModel(BaseModel):
    # String identifier; the tests use it as the key when checking the storage's `_items`.
    id: str
    # Integer payload, used to tell an updated item apart from the original.
    value: int
|
|
|
|
|
|
@pytest.fixture
def item_storage_memory():
    """Build an ``ItemStorageMemory`` parameterized with ``MockItemModel``, using default arguments."""
    storage = ItemStorageMemory[MockItemModel]()
    return storage
|
|
|
|
|
|
def test_item_storage_memory_initializes():
    """Check constructor defaults, explicit overrides, and rejection of invalid arguments."""
    # Defaults: no items, "id" as the id field, and a capacity of 10.
    default_storage = ItemStorageMemory()
    assert default_storage._items == {}
    assert default_storage._id_field == "id"
    assert default_storage._max_items == 10

    # Explicit arguments are stored as given.
    custom_storage = ItemStorageMemory(id_field="bananas", max_items=20)
    assert custom_storage._id_field == "bananas"
    assert custom_storage._max_items == 20

    # Each invalid argument must raise ValueError with its specific message.
    invalid_cases = (
        ({"max_items": 0}, "max_items must be at least 1"),
        ({"id_field": ""}, "id_field must not be empty"),
    )
    for bad_kwargs, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=re.escape(expected_message)):
            ItemStorageMemory(**bad_kwargs)
|
|
|
|
|
|
def test_item_storage_memory_sets(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """Check that ``set`` stores new items under their id and replaces an item with a matching id."""
    first = MockItemModel(id="1", value=1)
    item_storage_memory.set(first)
    assert item_storage_memory._items == {"1": first}

    second = MockItemModel(id="2", value=2)
    item_storage_memory.set(second)
    assert item_storage_memory._items == {"1": first, "2": second}

    # Setting an item whose id is already present swaps in the new item.
    replacement = MockItemModel(id="2", value=9001)
    item_storage_memory.set(replacement)
    assert item_storage_memory._items == {"1": first, "2": replacement}
|
|
|
|
|
|
def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """Check that ``get`` returns stored items and raises ``ItemNotFoundError`` for an unknown id."""
    # Store each item, then immediately read it back by id.
    for item_id, value in (("1", 1), ("2", 2)):
        stored = MockItemModel(id=item_id, value=value)
        item_storage_memory.set(stored)
        fetched = item_storage_memory.get(item_id)
        assert fetched == stored

    # An id that was never stored must raise rather than return a value.
    with pytest.raises(ItemNotFoundError, match=re.escape("Item with id 3 not found")):
        item_storage_memory.get("3")
|
|
|
|
|
|
def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """Check that ``delete`` removes only the item with the given id."""
    kept = MockItemModel(id="1", value=1)
    removed = MockItemModel(id="2", value=2)
    for entry in (kept, removed):
        item_storage_memory.set(entry)

    item_storage_memory.delete("2")

    # Only the item that was not deleted should remain.
    assert item_storage_memory._items == {"1": kept}
|
|
|
|
|
|
def test_item_storage_memory_respects_max():
    """Check that with ``max_items=3`` only the last three of ten inserted items remain."""
    bounded_storage = ItemStorageMemory(max_items=3)
    for index in range(10):
        bounded_storage.set(MockItemModel(id=str(index), value=index))

    # Ids "0" through "6" are expected to be gone, leaving the three most recent inserts.
    expected_items = {key: MockItemModel(id=key, value=int(key)) for key in ("7", "8", "9")}
    assert bounded_storage._items == expected_items
|
|
|
|
|
|
def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """Check that a callback registered via ``on_changed`` receives the item passed to ``set``."""
    item = MockItemModel(id="1", value=1)
    # Mutable holder so the callback can record its argument without `nonlocal`.
    captured = {"item": None}

    def on_changed(item: MockItemModel):
        captured["item"] = item

    item_storage_memory.on_changed(on_changed)
    item_storage_memory.set(item)

    assert captured["item"] == item
|
|
|
|
|
|
def test_item_storage_memory_calls_delete_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """Check that a callback registered via ``on_deleted`` receives the id passed to ``delete``."""
    item = MockItemModel(id="1", value=1)
    # Mutable holder so the callback can record its argument without `nonlocal`.
    captured = {"item_id": None}

    def on_deleted(item_id: str):
        captured["item_id"] = item_id

    item_storage_memory.on_deleted(on_deleted)
    item_storage_memory.set(item)
    item_storage_memory.delete("1")

    assert captured["item_id"] == "1"
|