feat(nodes): support custom exception in ephemeral disk storage

This commit is contained in:
psychedelicious 2024-02-07 22:54:52 +11:00 committed by Brandon Rising
parent 723009e163
commit 614f0e8086

View File

@ -1,12 +1,12 @@
import typing
from pathlib import Path
from typing import Optional, TypeVar
from typing import Optional, Type, TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc
from invokeai.app.util.misc import uuid_string
T = TypeVar("T")
@ -18,6 +18,7 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
:param output_folder: The folder where the items will be stored
:param save: The function to use to save the items to disk [torch.save]
:param load: The function to use to load the items from disk [torch.load]
:param load_exc: The exception that is raised when an item is not found [FileNotFoundError]
"""
def __init__(
@ -25,12 +26,14 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
output_folder: Path,
save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType]
load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType]
load_exc: Type[Exception] = FileNotFoundError,
):
super().__init__()
self._output_folder = output_folder
self._output_folder.mkdir(parents=True, exist_ok=True)
self._save = save
self._load = load
self._load_exc = load_exc
self.__item_class_name: Optional[str] = None
@property
@ -46,7 +49,10 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
def get(self, item_id: str) -> T:
file_path = self._get_path(item_id)
try:
return self._load(file_path)
except self._load_exc as e:
raise ItemNotFoundError(item_id) from e
def set(self, item: T) -> str:
self._output_folder.mkdir(parents=True, exist_ok=True)