mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(item_storage): implement item_storage_memory max_size
Implemented with unordered dict and set.
This commit is contained in:
parent
d532073f5b
commit
a0eecaecd0
@ -1,3 +1,4 @@
|
||||
from contextlib import suppress
|
||||
from typing import Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -8,21 +9,43 @@ T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ItemStorageMemory(ItemStorageABC, Generic[T]):
|
||||
def __init__(self, id_field: str = "id") -> None:
|
||||
"""
|
||||
Provides a simple in-memory storage for items, with a maximum number of items to store.
|
||||
An item is deleted when the maximum number of items is reached and a new item is added.
|
||||
There is no guarantee about which item will be deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, id_field: str = "id", max_items: int = 10) -> None:
|
||||
super().__init__()
|
||||
if max_items < 1:
|
||||
raise ValueError("max_items must be at least 1")
|
||||
if not id_field:
|
||||
raise ValueError("id_field must not be empty")
|
||||
self._id_field = id_field
|
||||
self._items: dict[str, T] = {}
|
||||
self._item_ids: set[str] = set()
|
||||
self._max_items = max_items
|
||||
|
||||
def get(self, item_id: str) -> Optional[T]:
|
||||
return self._items.get(item_id)
|
||||
|
||||
def set(self, item: T) -> None:
|
||||
self._items[getattr(item, self._id_field)] = item
|
||||
item_id = getattr(item, self._id_field)
|
||||
assert isinstance(item_id, str)
|
||||
if item_id in self._items or len(self._items) < self._max_items:
|
||||
# If the item is already stored, or we have room for more items, we can just add it.
|
||||
self._items[item_id] = item
|
||||
self._item_ids.add(item_id)
|
||||
else:
|
||||
# Otherwise, we need to make room for it first.
|
||||
self._items.pop(self._item_ids.pop())
|
||||
self._items[item_id] = item
|
||||
self._item_ids.add(item_id)
|
||||
self._on_changed(item)
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
try:
|
||||
# Both of these are no-ops if the item doesn't exist.
|
||||
with suppress(KeyError):
|
||||
del self._items[item_id]
|
||||
self._item_ids.remove(item_id)
|
||||
self._on_deleted(item_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
111
tests/test_item_storage_memory.py
Normal file
111
tests/test_item_storage_memory.py
Normal file
@ -0,0 +1,111 @@
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
|
||||
class MockItemModel(BaseModel):
|
||||
id: str
|
||||
value: int
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def item_storage_memory():
|
||||
return ItemStorageMemory[MockItemModel]()
|
||||
|
||||
|
||||
def test_item_storage_memory_initializes():
|
||||
item_storage_memory = ItemStorageMemory()
|
||||
assert item_storage_memory._items == {}
|
||||
assert item_storage_memory._item_ids == set()
|
||||
assert item_storage_memory._id_field == "id"
|
||||
assert item_storage_memory._max_items == 10
|
||||
|
||||
item_storage_memory = ItemStorageMemory(id_field="bananas", max_items=20)
|
||||
assert item_storage_memory._id_field == "bananas"
|
||||
assert item_storage_memory._max_items == 20
|
||||
|
||||
with pytest.raises(ValueError, match=re.escape("max_items must be at least 1")):
|
||||
item_storage_memory = ItemStorageMemory(max_items=0)
|
||||
with pytest.raises(ValueError, match=re.escape("id_field must not be empty")):
|
||||
item_storage_memory = ItemStorageMemory(id_field="")
|
||||
|
||||
|
||||
def test_item_storage_memory_sets(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||
item_1 = MockItemModel(id="1", value=1)
|
||||
item_storage_memory.set(item_1)
|
||||
assert item_storage_memory._items == {"1": item_1}
|
||||
assert item_storage_memory._item_ids == {"1"}
|
||||
|
||||
item_2 = MockItemModel(id="2", value=2)
|
||||
item_storage_memory.set(item_2)
|
||||
assert item_storage_memory._items == {"1": item_1, "2": item_2}
|
||||
assert item_storage_memory._item_ids == {"1", "2"}
|
||||
|
||||
# Updating value of existing item
|
||||
item_2_updated = MockItemModel(id="2", value=9001)
|
||||
item_storage_memory.set(item_2_updated)
|
||||
assert item_storage_memory._items == {"1": item_1, "2": item_2_updated}
|
||||
assert item_storage_memory._item_ids == {"1", "2"}
|
||||
|
||||
|
||||
def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||
item_1 = MockItemModel(id="1", value=1)
|
||||
item_storage_memory.set(item_1)
|
||||
item = item_storage_memory.get("1")
|
||||
assert item == item_1
|
||||
|
||||
item_2 = MockItemModel(id="2", value=2)
|
||||
item_storage_memory.set(item_2)
|
||||
item = item_storage_memory.get("2")
|
||||
assert item == item_2
|
||||
|
||||
item = item_storage_memory.get("3")
|
||||
assert item is None
|
||||
|
||||
|
||||
def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||
item_1 = MockItemModel(id="1", value=1)
|
||||
item_2 = MockItemModel(id="2", value=2)
|
||||
item_storage_memory.set(item_1)
|
||||
item_storage_memory.set(item_2)
|
||||
|
||||
item_storage_memory.delete("2")
|
||||
assert item_storage_memory._items == {"1": item_1}
|
||||
assert item_storage_memory._item_ids == {"1"}
|
||||
|
||||
|
||||
def test_item_storage_memory_respects_max():
|
||||
item_storage_memory = ItemStorageMemory(max_items=3)
|
||||
for i in range(10):
|
||||
item_storage_memory.set(MockItemModel(id=str(i), value=i))
|
||||
assert len(item_storage_memory._items) == 3
|
||||
|
||||
|
||||
def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||
called_item = None
|
||||
item = MockItemModel(id="1", value=1)
|
||||
|
||||
def on_changed(item: MockItemModel):
|
||||
nonlocal called_item
|
||||
called_item = item
|
||||
|
||||
item_storage_memory.on_changed(on_changed)
|
||||
item_storage_memory.set(item)
|
||||
assert called_item == item
|
||||
|
||||
|
||||
def test_item_storage_memory_calls_delete_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||
called_item_id = None
|
||||
item = MockItemModel(id="1", value=1)
|
||||
|
||||
def on_deleted(item_id: str):
|
||||
nonlocal called_item_id
|
||||
called_item_id = item_id
|
||||
|
||||
item_storage_memory.on_deleted(on_deleted)
|
||||
item_storage_memory.set(item)
|
||||
item_storage_memory.delete("1")
|
||||
assert called_item_id == "1"
|
Loading…
Reference in New Issue
Block a user