fix(item-storage-memory): throw when requested item does not exist

- `ItemStorageMemory.get` now throws an `ItemNotFoundError` when the requested `item_id` is not found.
- Update docstrings in ABC and tests.

The new memory item storage implementation implemented the `get` method incorrectly, by returning `None` if the item didn't exist.

The ABC typed `get` as returning `T`, while the SQLite implementation typed `get` as returning `Optional[T]`. The SQLite implementation was referenced when writing the memory implementation.

This mismatched typing is a violation of the Liskov substitution principle, because the signature of the implementation of `get` in the implementation is wider than the abstract class's definition. Using `pyright` in strict mode catches this.

In `invocation_stats_default`, this introduced an error. The `_prune_stats` method calls `get`, expecting the method to throw if the item is not found. If the graph is no longer stored in the bounded item storage, we will call `is_complete()` on `None`, causing the error.

Note: This error condition never arose the SQLite implementation because it parsed the item with pydantic before returning it, which would throw if the item was not found. It implicitly threw, while the memory implementation did not.
This commit is contained in:
psychedelicious
2024-02-03 18:30:41 +11:00
committed by Kent Keirsey
parent c2af124622
commit 88c08bbfc7
5 changed files with 29 additions and 11 deletions

View File

@ -9,6 +9,7 @@ import torch
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.backend.model_management.model_cache import CacheStats from invokeai.backend.model_management.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_base import InvocationStatsServiceBase
@ -82,7 +83,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
for graph_execution_state_id in self._stats: for graph_execution_state_id in self._stats:
try: try:
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id) graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
except Exception: except ItemNotFoundError:
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate? # TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.") logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
continue continue

View File

@ -20,17 +20,26 @@ class ItemStorageABC(ABC, Generic[T]):
@abstractmethod @abstractmethod
def get(self, item_id: str) -> T: def get(self, item_id: str) -> T:
"""Gets the item, parsing it into a Pydantic model""" """
Gets the item.
:param item_id: the id of the item to get
:raises ItemNotFoundError: if the item is not found
"""
pass pass
@abstractmethod @abstractmethod
def set(self, item: T) -> None: def set(self, item: T) -> None:
"""Sets the item""" """
Sets the item. The id will be extracted based on id_field.
:param item: the item to set
"""
pass pass
@abstractmethod @abstractmethod
def delete(self, item_id: str) -> None: def delete(self, item_id: str) -> None:
"""Deletes the item""" """
Deletes the item, if it exists.
"""
pass pass
def on_changed(self, on_changed: Callable[[T], None]) -> None: def on_changed(self, on_changed: Callable[[T], None]) -> None:

View File

@ -0,0 +1,5 @@
class ItemNotFoundError(KeyError):
"""Raised when an item is not found in storage"""
def __init__(self, item_id: str) -> None:
super().__init__(f"Item with id {item_id} not found")

View File

@ -1,10 +1,11 @@
from collections import OrderedDict from collections import OrderedDict
from contextlib import suppress from contextlib import suppress
from typing import Generic, Optional, TypeVar from typing import Generic, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
@ -25,12 +26,13 @@ class ItemStorageMemory(ItemStorageABC, Generic[T]):
self._items: OrderedDict[str, T] = OrderedDict() self._items: OrderedDict[str, T] = OrderedDict()
self._max_items = max_items self._max_items = max_items
def get(self, item_id: str) -> Optional[T]: def get(self, item_id: str) -> T:
# If the item exists, move it to the end of the OrderedDict. # If the item exists, move it to the end of the OrderedDict.
item = self._items.pop(item_id, None) item = self._items.pop(item_id, None)
if item is not None: if item is None:
self._items[item_id] = item raise ItemNotFoundError(item_id)
return self._items.get(item_id) self._items[item_id] = item
return item
def set(self, item: T) -> None: def set(self, item: T) -> None:
item_id = getattr(item, self._id_field) item_id = getattr(item, self._id_field)

View File

@ -3,6 +3,7 @@ import re
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
@ -58,8 +59,8 @@ def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockIte
item = item_storage_memory.get("2") item = item_storage_memory.get("2")
assert item == item_2 assert item == item_2
item = item_storage_memory.get("3") with pytest.raises(ItemNotFoundError, match=re.escape("Item with id 3 not found")):
assert item is None item_storage_memory.get("3")
def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]): def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]):