mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(item_storage): implement item_storage_memory with LRU eviction strategy
Implemented with OrderedDict.
This commit is contained in:
parent
a0eecaecd0
commit
9f793bdae8
@ -1,3 +1,4 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import Generic, Optional, TypeVar
|
from typing import Generic, Optional, TypeVar
|
||||||
|
|
||||||
@ -11,8 +12,7 @@ T = TypeVar("T", bound=BaseModel)
|
|||||||
class ItemStorageMemory(ItemStorageABC, Generic[T]):
|
class ItemStorageMemory(ItemStorageABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
Provides a simple in-memory storage for items, with a maximum number of items to store.
|
Provides a simple in-memory storage for items, with a maximum number of items to store.
|
||||||
An item is deleted when the maximum number of items is reached and a new item is added.
|
The storage uses the LRU strategy to evict items from storage when the max has been reached.
|
||||||
There is no guarantee about which item will be deleted.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id_field: str = "id", max_items: int = 10) -> None:
|
def __init__(self, id_field: str = "id", max_items: int = 10) -> None:
|
||||||
@ -22,30 +22,29 @@ class ItemStorageMemory(ItemStorageABC, Generic[T]):
|
|||||||
if not id_field:
|
if not id_field:
|
||||||
raise ValueError("id_field must not be empty")
|
raise ValueError("id_field must not be empty")
|
||||||
self._id_field = id_field
|
self._id_field = id_field
|
||||||
self._items: dict[str, T] = {}
|
self._items: OrderedDict[str, T] = OrderedDict()
|
||||||
self._item_ids: set[str] = set()
|
|
||||||
self._max_items = max_items
|
self._max_items = max_items
|
||||||
|
|
||||||
def get(self, item_id: str) -> Optional[T]:
|
def get(self, item_id: str) -> Optional[T]:
|
||||||
|
# If the item exists, move it to the end of the OrderedDict.
|
||||||
|
item = self._items.pop(item_id, None)
|
||||||
|
if item is not None:
|
||||||
|
self._items[item_id] = item
|
||||||
return self._items.get(item_id)
|
return self._items.get(item_id)
|
||||||
|
|
||||||
def set(self, item: T) -> None:
|
def set(self, item: T) -> None:
|
||||||
item_id = getattr(item, self._id_field)
|
item_id = getattr(item, self._id_field)
|
||||||
assert isinstance(item_id, str)
|
if item_id in self._items:
|
||||||
if item_id in self._items or len(self._items) < self._max_items:
|
# If item already exists, remove it and add it to the end
|
||||||
# If the item is already stored, or we have room for more items, we can just add it.
|
self._items.pop(item_id)
|
||||||
|
elif len(self._items) >= self._max_items:
|
||||||
|
# If cache is full, evict the least recently used item
|
||||||
|
self._items.popitem(last=False)
|
||||||
self._items[item_id] = item
|
self._items[item_id] = item
|
||||||
self._item_ids.add(item_id)
|
|
||||||
else:
|
|
||||||
# Otherwise, we need to make room for it first.
|
|
||||||
self._items.pop(self._item_ids.pop())
|
|
||||||
self._items[item_id] = item
|
|
||||||
self._item_ids.add(item_id)
|
|
||||||
self._on_changed(item)
|
self._on_changed(item)
|
||||||
|
|
||||||
def delete(self, item_id: str) -> None:
|
def delete(self, item_id: str) -> None:
|
||||||
# Both of these are no-ops if the item doesn't exist.
|
# This is a no-op if the item doesn't exist.
|
||||||
with suppress(KeyError):
|
with suppress(KeyError):
|
||||||
del self._items[item_id]
|
del self._items[item_id]
|
||||||
self._item_ids.remove(item_id)
|
|
||||||
self._on_deleted(item_id)
|
self._on_deleted(item_id)
|
||||||
|
@ -19,7 +19,6 @@ def item_storage_memory():
|
|||||||
def test_item_storage_memory_initializes():
|
def test_item_storage_memory_initializes():
|
||||||
item_storage_memory = ItemStorageMemory()
|
item_storage_memory = ItemStorageMemory()
|
||||||
assert item_storage_memory._items == {}
|
assert item_storage_memory._items == {}
|
||||||
assert item_storage_memory._item_ids == set()
|
|
||||||
assert item_storage_memory._id_field == "id"
|
assert item_storage_memory._id_field == "id"
|
||||||
assert item_storage_memory._max_items == 10
|
assert item_storage_memory._max_items == 10
|
||||||
|
|
||||||
@ -37,18 +36,15 @@ def test_item_storage_memory_sets(item_storage_memory: ItemStorageMemory[MockIte
|
|||||||
item_1 = MockItemModel(id="1", value=1)
|
item_1 = MockItemModel(id="1", value=1)
|
||||||
item_storage_memory.set(item_1)
|
item_storage_memory.set(item_1)
|
||||||
assert item_storage_memory._items == {"1": item_1}
|
assert item_storage_memory._items == {"1": item_1}
|
||||||
assert item_storage_memory._item_ids == {"1"}
|
|
||||||
|
|
||||||
item_2 = MockItemModel(id="2", value=2)
|
item_2 = MockItemModel(id="2", value=2)
|
||||||
item_storage_memory.set(item_2)
|
item_storage_memory.set(item_2)
|
||||||
assert item_storage_memory._items == {"1": item_1, "2": item_2}
|
assert item_storage_memory._items == {"1": item_1, "2": item_2}
|
||||||
assert item_storage_memory._item_ids == {"1", "2"}
|
|
||||||
|
|
||||||
# Updating value of existing item
|
# Updating value of existing item
|
||||||
item_2_updated = MockItemModel(id="2", value=9001)
|
item_2_updated = MockItemModel(id="2", value=9001)
|
||||||
item_storage_memory.set(item_2_updated)
|
item_storage_memory.set(item_2_updated)
|
||||||
assert item_storage_memory._items == {"1": item_1, "2": item_2_updated}
|
assert item_storage_memory._items == {"1": item_1, "2": item_2_updated}
|
||||||
assert item_storage_memory._item_ids == {"1", "2"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||||
@ -74,14 +70,17 @@ def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[Mock
|
|||||||
|
|
||||||
item_storage_memory.delete("2")
|
item_storage_memory.delete("2")
|
||||||
assert item_storage_memory._items == {"1": item_1}
|
assert item_storage_memory._items == {"1": item_1}
|
||||||
assert item_storage_memory._item_ids == {"1"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_item_storage_memory_respects_max():
|
def test_item_storage_memory_respects_max():
|
||||||
item_storage_memory = ItemStorageMemory(max_items=3)
|
item_storage_memory = ItemStorageMemory(max_items=3)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
item_storage_memory.set(MockItemModel(id=str(i), value=i))
|
item_storage_memory.set(MockItemModel(id=str(i), value=i))
|
||||||
assert len(item_storage_memory._items) == 3
|
assert item_storage_memory._items == {
|
||||||
|
"7": MockItemModel(id="7", value=7),
|
||||||
|
"8": MockItemModel(id="8", value=8),
|
||||||
|
"9": MockItemModel(id="9", value=9),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user