From 8b6e32269702a837d4f785cf56ae545f69df43ef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:54:52 +1100 Subject: [PATCH] feat(nodes): support custom exception in ephemeral disk storage --- .../item_storage/item_storage_ephemeral_disk.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 4dc67129da..97c767c87d 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -1,12 +1,12 @@ import typing from pathlib import Path -from typing import Optional, TypeVar +from typing import Optional, Type, TypeVar import torch from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc +from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") @@ -18,6 +18,7 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]): :param output_folder: The folder where the items will be stored :param save: The function to use to save the items to disk [torch.save] :param load: The function to use to load the items from disk [torch.load] + :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] """ def __init__( @@ -25,12 +26,14 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]): output_folder: Path, save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + load_exc: Type[Exception] = FileNotFoundError, ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._save = save self._load = load + self._load_exc = load_exc self.__item_class_name: Optional[str] = None @property @@ -46,7 +49,10 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]): def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return self._load(file_path) + try: + return self._load(file_path) + except self._load_exc as e: + raise ItemNotFoundError(item_id) from e def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True)