mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): support custom save and load functions in ItemStorageEphemeralDisk
This commit is contained in:
parent
7fe5283e74
commit
54a67459bf
@ -1,5 +1,15 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, TypeAlias, TypeVar
|
||||||
|
|
||||||
|
|
||||||
class ItemNotFoundError(KeyError):
|
class ItemNotFoundError(KeyError):
|
||||||
"""Raised when an item is not found in storage"""
|
"""Raised when an item is not found in storage"""
|
||||||
|
|
||||||
def __init__(self, item_id: str) -> None:
|
def __init__(self, item_id: str) -> None:
|
||||||
super().__init__(f"Item with id {item_id} not found")
|
super().__init__(f"Item with id {item_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
SaveFunc: TypeAlias = Callable[[T, Path], None]
|
||||||
|
LoadFunc: TypeAlias = Callable[[Path], T]
|
||||||
|
@ -6,18 +6,31 @@ import torch
|
|||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||||
|
from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class ItemStorageEphemeralDisk(ItemStorageABC[T]):
|
class ItemStorageEphemeralDisk(ItemStorageABC[T]):
|
||||||
"""Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup."""
|
"""Provides a disk-backed ephemeral storage. The storage is cleared at startup.
|
||||||
|
|
||||||
def __init__(self, output_folder: Path):
|
:param output_folder: The folder where the items will be stored
|
||||||
|
:param save: The function to use to save the items to disk [torch.save]
|
||||||
|
:param load: The function to use to load the items from disk [torch.load]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_folder: Path,
|
||||||
|
save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType]
|
||||||
|
load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType]
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._output_folder = output_folder
|
self._output_folder = output_folder
|
||||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._save = save
|
||||||
|
self._load = load
|
||||||
self.__item_class_name: Optional[str] = None
|
self.__item_class_name: Optional[str] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -33,13 +46,13 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
|
|||||||
|
|
||||||
def get(self, item_id: str) -> T:
|
def get(self, item_id: str) -> T:
|
||||||
file_path = self._get_path(item_id)
|
file_path = self._get_path(item_id)
|
||||||
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
return self._load(file_path)
|
||||||
|
|
||||||
def set(self, item: T) -> str:
|
def set(self, item: T) -> str:
|
||||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
item_id = self._new_item_id()
|
item_id = self._new_item_id()
|
||||||
file_path = self._get_path(item_id)
|
file_path = self._get_path(item_id)
|
||||||
torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType]
|
self._save(item, file_path)
|
||||||
return item_id
|
return item_id
|
||||||
|
|
||||||
def delete(self, item_id: str) -> None:
|
def delete(self, item_id: str) -> None:
|
||||||
@ -58,6 +71,9 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]):
|
|||||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
|
||||||
|
# to manually clear them on startup anyways. This is a bit simpler and more reliable.
|
||||||
|
|
||||||
if not self._invoker:
|
if not self._invoker:
|
||||||
raise ValueError("Invoker is not set. Must call `start()` first.")
|
raise ValueError("Invoker is not set. Must call `start()` first.")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user