feat(nodes): support custom exception in ephemeral disk storage

This commit is contained in:
psychedelicious 2024-02-07 22:54:52 +11:00
parent 54a67459bf
commit 8b6e322697

View File

@ -1,12 +1,12 @@
import typing import typing
from pathlib import Path from pathlib import Path
from typing import Optional, TypeVar from typing import Optional, Type, TypeVar
import torch import torch
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
T = TypeVar("T") T = TypeVar("T")
@ -18,6 +18,7 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
:param output_folder: The folder where the items will be stored :param output_folder: The folder where the items will be stored
:param save: The function to use to save the items to disk [torch.save] :param save: The function to use to save the items to disk [torch.save]
:param load: The function to use to load the items from disk [torch.load] :param load: The function to use to load the items from disk [torch.load]
:param load_exc: The exception that is raised when an item is not found [FileNotFoundError]
""" """
def __init__( def __init__(
@ -25,12 +26,14 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
output_folder: Path, output_folder: Path,
save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType]
load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType]
load_exc: Type[Exception] = FileNotFoundError,
): ):
super().__init__() super().__init__()
self._output_folder = output_folder self._output_folder = output_folder
self._output_folder.mkdir(parents=True, exist_ok=True) self._output_folder.mkdir(parents=True, exist_ok=True)
self._save = save self._save = save
self._load = load self._load = load
self._load_exc = load_exc
self.__item_class_name: Optional[str] = None self.__item_class_name: Optional[str] = None
@property @property
@ -46,7 +49,10 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
def get(self, item_id: str) -> T: def get(self, item_id: str) -> T:
file_path = self._get_path(item_id) file_path = self._get_path(item_id)
return self._load(file_path) try:
return self._load(file_path)
except self._load_exc as e:
raise ItemNotFoundError(item_id) from e
def set(self, item: T) -> str: def set(self, item: T) -> str:
self._output_folder.mkdir(parents=True, exist_ok=True) self._output_folder.mkdir(parents=True, exist_ok=True)