revert(nodes): revert making tensors/conditioning use item storage

Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning, and refined the class a bit too.

Parent: 8b6e322697
Commit: 06429028c8
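The net effect of the diff: tensors and conditioning move off the id-centric ItemStorageABC and onto a name-centric ObjectSerializerBase. A condensed, non-runnable sketch of the two contracts as they stand after this commit (both appear in full in the hunks below):

    class ItemStorageABC(ABC, Generic[T]):        # T now bound to pydantic BaseModel
        def get(self, item_id: str) -> T: ...
        def set(self, item: T) -> None: ...       # id comes from the item's id_field
        def delete(self, item_id: str) -> None: ...

    class ObjectSerializerBase(ABC, Generic[T]):  # arbitrary python objects
        def load(self, name: str) -> T: ...
        def save(self, obj: T) -> str: ...        # returns the generated name
        def delete(self, name: str) -> None: ...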
@@ -4,9 +4,9 @@ from logging import Logger
 import torch
-from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
-from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache
 from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
+from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
+from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
 from invokeai.backend.model_manager.metadata import ModelMetadataStore
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -90,9 +90,9 @@ class ApiDependencies:
         image_records = SqliteImageRecordStorage(db=db)
         images = ImageService()
         invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
-        tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors"))
-        conditioning = ItemStorageForwardCache(
-            ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
+        tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors"))
+        conditioning = ObjectSerializerForwardCache(
+            ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
         )
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
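Both services keep the same two-layer composition as before: a disk-backed serializer wrapped in an in-memory forward cache. A minimal standalone sketch of that layering, assuming a hypothetical output folder (max_cache_size shown explicitly; 20 is the default per the new class added below):

    from pathlib import Path
    import torch

    tensors = ObjectSerializerForwardCache(
        ObjectSerializerEphemeralDisk[torch.Tensor](Path("/tmp/outputs") / "tensors"),
        max_cache_size=20,  # default; shown here for clarity
    )
    name = tensors.save(torch.zeros(4))  # writes to disk, then caches in memory
    loaded = tensors.load(name)          # served from the cache on a hit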
@@ -304,11 +304,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name)
+        positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
         c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = c.extra_conditioning
 
-        negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name)
+        negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
@@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.denoise_mask is None:
             return None, None
 
-        mask = context.tensors.get(self.denoise_mask.mask_name)
+        mask = context.tensors.load(self.denoise_mask.mask_name)
         mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
         if self.denoise_mask.masked_latents_name is not None:
-            masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
+            masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
         else:
             masked_latents = None
 
@@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
         seed = None
         noise = None
         if self.noise is not None:
-            noise = context.tensors.get(self.noise.latents_name)
+            noise = context.tensors.load(self.noise.latents_name)
             seed = self.noise.seed
 
         if self.latents is not None:
-            latents = context.tensors.get(self.latents.latents_name)
+            latents = context.tensors.load(self.latents.latents_name)
             if seed is None:
                 seed = self.latents.seed
 
@@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.tensors.get(self.latents.latents_name)
+        latents = context.tensors.load(self.latents.latents_name)
 
         vae_info = context.models.load(**self.vae.vae.model_dump())
 
@@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
     antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.tensors.get(self.latents.latents_name)
+        latents = context.tensors.load(self.latents.latents_name)
 
         # TODO:
         device = choose_torch_device()
@@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.tensors.get(self.latents.latents_name)
+        latents = context.tensors.load(self.latents.latents_name)
 
         # TODO:
         device = choose_torch_device()
@@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
     alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents_a = context.tensors.get(self.latents_a.latents_name)
-        latents_b = context.tensors.get(self.latents_b.latents_name)
+        latents_a = context.tensors.load(self.latents_a.latents_name)
+        latents_b = context.tensors.load(self.latents_b.latents_name)
 
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")
@@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
     )
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.tensors.get(self.latents.latents_name)
+        latents = context.tensors.load(self.latents.latents_name)
 
         x1 = self.x // LATENT_SCALE_FACTOR
         y1 = self.y // LATENT_SCALE_FACTOR
@@ -344,7 +344,7 @@ class LatentsInvocation(BaseInvocation):
     latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.tensors.get(self.latents.latents_name)
+        latents = context.tensors.load(self.latents.latents_name)
 
         return LatentsOutput.build(self.latents.latents_name, latents)
 
@@ -3,6 +3,8 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
+from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
+
 if TYPE_CHECKING:
     from logging import Logger
 
@@ -65,8 +67,8 @@ class InvocationServices:
         names: "NameServiceBase",
         urls: "UrlServiceBase",
         workflow_records: "WorkflowRecordsStorageBase",
-        tensors: "ItemStorageABC[torch.Tensor]",
-        conditioning: "ItemStorageABC[ConditioningFieldData]",
+        tensors: "ObjectSerializerBase[torch.Tensor]",
+        conditioning: "ObjectSerializerBase[ConditioningFieldData]",
     ):
         self.board_images = board_images
         self.board_image_records = board_image_records
@@ -1,7 +1,9 @@
 from abc import ABC, abstractmethod
 from typing import Callable, Generic, TypeVar
 
-T = TypeVar("T")
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
 
 
 class ItemStorageABC(ABC, Generic[T]):
@@ -26,9 +28,9 @@ class ItemStorageABC(ABC, Generic[T]):
         pass
 
     @abstractmethod
-    def set(self, item: T) -> str:
+    def set(self, item: T) -> None:
         """
-        Sets the item. The id will be extracted based on id_field.
+        Sets the item.
         :param item: the item to set
         """
         pass
@@ -1,15 +1,5 @@
-from pathlib import Path
-from typing import Callable, TypeAlias, TypeVar
-
-
 class ItemNotFoundError(KeyError):
     """Raised when an item is not found in storage"""
 
     def __init__(self, item_id: str) -> None:
         super().__init__(f"Item with id {item_id} not found")
-
-
-T = TypeVar("T")
-
-SaveFunc: TypeAlias = Callable[[T, Path], None]
-LoadFunc: TypeAlias = Callable[[Path], T]
@@ -1,97 +0,0 @@
-import typing
-from pathlib import Path
-from typing import Optional, Type, TypeVar
-
-import torch
-
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
-from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc
-from invokeai.app.util.misc import uuid_string
-
-T = TypeVar("T")
-
-
-class ItemStorageEphemeralDisk(ItemStorageABC[T]):
-    """Provides a disk-backed ephemeral storage. The storage is cleared at startup.
-
-    :param output_folder: The folder where the items will be stored
-    :param save: The function to use to save the items to disk [torch.save]
-    :param load: The function to use to load the items from disk [torch.load]
-    :param load_exc: The exception that is raised when an item is not found [FileNotFoundError]
-    """
-
-    def __init__(
-        self,
-        output_folder: Path,
-        save: SaveFunc[T] = torch.save,  # pyright: ignore [reportUnknownMemberType]
-        load: LoadFunc[T] = torch.load,  # pyright: ignore [reportUnknownMemberType]
-        load_exc: Type[Exception] = FileNotFoundError,
-    ):
-        super().__init__()
-        self._output_folder = output_folder
-        self._output_folder.mkdir(parents=True, exist_ok=True)
-        self._save = save
-        self._load = load
-        self._load_exc = load_exc
-        self.__item_class_name: Optional[str] = None
-
-    @property
-    def _item_class_name(self) -> str:
-        if not self.__item_class_name:
-            # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
-            self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
-        return self.__item_class_name
-
-    def start(self, invoker: Invoker) -> None:
-        self._invoker = invoker
-        self._delete_all_items()
-
-    def get(self, item_id: str) -> T:
-        file_path = self._get_path(item_id)
-        try:
-            return self._load(file_path)
-        except self._load_exc as e:
-            raise ItemNotFoundError(item_id) from e
-
-    def set(self, item: T) -> str:
-        self._output_folder.mkdir(parents=True, exist_ok=True)
-        item_id = self._new_item_id()
-        file_path = self._get_path(item_id)
-        self._save(item, file_path)
-        return item_id
-
-    def delete(self, item_id: str) -> None:
-        file_path = self._get_path(item_id)
-        file_path.unlink()
-
-    def _get_path(self, item_id: str) -> Path:
-        return self._output_folder / item_id
-
-    def _new_item_id(self) -> str:
-        return f"{self._item_class_name}_{uuid_string()}"
-
-    def _delete_all_items(self) -> None:
-        """
-        Deletes all pickled items from disk.
-        Must be called after we have access to `self._invoker` (e.g. in `start()`).
-        """
-
-        # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
-        # to manually clear them on startup anyways. This is a bit simpler and more reliable.
-
-        if not self._invoker:
-            raise ValueError("Invoker is not set. Must call `start()` first.")
-
-        deleted_count = 0
-        freed_space = 0
-        for file in Path(self._output_folder).glob("*"):
-            if file.is_file():
-                freed_space += file.stat().st_size
-                deleted_count += 1
-                file.unlink()
-        if deleted_count > 0:
-            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
-            self._invoker.services.logger.info(
-                f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)"
-            )
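One capability is dropped in the move: the deleted class above accepted injectable save/load callables (SaveFunc/LoadFunc) and a configurable not-found exception, while the ObjectSerializerEphemeralDisk added later in this diff hard-codes torch.save/torch.load. A sketch of what the old injection points allowed, using pickle as a stand-in (illustrative only; this class no longer exists after this commit):

    import pickle
    from pathlib import Path

    storage = ItemStorageEphemeralDisk[dict](
        output_folder=Path("/tmp/items"),
        save=lambda obj, path: path.write_bytes(pickle.dumps(obj)),  # SaveFunc[T]
        load=lambda path: pickle.loads(path.read_bytes()),           # LoadFunc[T]
        load_exc=FileNotFoundError,
    )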
@@ -1,61 +0,0 @@
-from queue import Queue
-from typing import Optional, TypeVar
-
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
-
-T = TypeVar("T")
-
-
-class ItemStorageForwardCache(ItemStorageABC[T]):
-    """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
-
-    def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20):
-        super().__init__()
-        self._underlying_storage = underlying_storage
-        self._cache: dict[str, T] = {}
-        self._cache_ids = Queue[str]()
-        self._max_cache_size = max_cache_size
-
-    def start(self, invoker: Invoker) -> None:
-        self._invoker = invoker
-        start_op = getattr(self._underlying_storage, "start", None)
-        if callable(start_op):
-            start_op(invoker)
-
-    def stop(self, invoker: Invoker) -> None:
-        self._invoker = invoker
-        stop_op = getattr(self._underlying_storage, "stop", None)
-        if callable(stop_op):
-            stop_op(invoker)
-
-    def get(self, item_id: str) -> T:
-        cache_item = self._get_cache(item_id)
-        if cache_item is not None:
-            return cache_item
-
-        latent = self._underlying_storage.get(item_id)
-        self._set_cache(item_id, latent)
-        return latent
-
-    def set(self, item: T) -> str:
-        item_id = self._underlying_storage.set(item)
-        self._set_cache(item_id, item)
-        self._on_changed(item)
-        return item_id
-
-    def delete(self, item_id: str) -> None:
-        self._underlying_storage.delete(item_id)
-        if item_id in self._cache:
-            del self._cache[item_id]
-        self._on_deleted(item_id)
-
-    def _get_cache(self, item_id: str) -> Optional[T]:
-        return None if item_id not in self._cache else self._cache[item_id]
-
-    def _set_cache(self, item_id: str, data: T):
-        if item_id not in self._cache:
-            self._cache[item_id] = data
-            self._cache_ids.put(item_id)
-            if self._cache_ids.qsize() > self._max_cache_size:
-                self._cache.pop(self._cache_ids.get())
@@ -34,7 +34,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
             self._items[item_id] = item
         return item
 
-    def set(self, item: T) -> str:
+    def set(self, item: T) -> None:
         item_id = getattr(item, self._id_field)
         if item_id in self._items:
             # If item already exists, remove it and add it to the end
@@ -44,7 +44,6 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
             self._items.popitem(last=False)
         self._items[item_id] = item
         self._on_changed(item)
-        return item_id
 
     def delete(self, item_id: str) -> None:
         # This is a no-op if the item doesn't exist.
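With `set` returning None, callers now read the id off the pydantic item itself rather than from the return value. A sketch under assumed constructor arguments (ItemStorageMemory's real signature is not shown in this diff):

    from pydantic import BaseModel

    class Session(BaseModel):
        id: str

    storage = ItemStorageMemory[Session](id_field="id")  # assumed constructor arg
    storage.set(Session(id="abc"))  # returns None after this change
    session = storage.get("abc")    # the id is known from the item, not from set()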
@@ -0,0 +1,53 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class ObjectSerializerBase(ABC, Generic[T]):
+    """Saves and loads arbitrary python objects."""
+
+    def __init__(self) -> None:
+        self._on_saved_callbacks: list[Callable[[str, T], None]] = []
+        self._on_deleted_callbacks: list[Callable[[str], None]] = []
+
+    @abstractmethod
+    def load(self, name: str) -> T:
+        """
+        Loads the object.
+        :param name: The name of the object to load.
+        :raises ObjectNotFoundError: if the object is not found
+        """
+        pass
+
+    @abstractmethod
+    def save(self, obj: T) -> str:
+        """
+        Saves the object, returning its name.
+        :param obj: The object to save.
+        """
+        pass
+
+    @abstractmethod
+    def delete(self, name: str) -> None:
+        """
+        Deletes the object, if it exists.
+        :param name: The name of the object to delete.
+        """
+        pass
+
+    def on_saved(self, on_saved: Callable[[str, T], None]) -> None:
+        """Register a callback for when an object is saved"""
+        self._on_saved_callbacks.append(on_saved)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an object is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_saved(self, name: str, obj: T) -> None:
+        for callback in self._on_saved_callbacks:
+            callback(name, obj)
+
+    def _on_deleted(self, name: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(name)
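The base class carries the observer hooks that replace ItemStorageABC's _on_changed/_on_deleted. A small sketch of wiring a subscriber; note that callbacks only fire where an implementation calls _on_saved/_on_deleted (the forward cache below does; the ephemeral disk serializer does not):

    def log_saved(name: str, obj: object) -> None:
        print(f"saved {name}")

    serializer.on_saved(log_saved)  # works on any ObjectSerializerBase[T] instance
    serializer.on_deleted(lambda name: print(f"deleted {name}"))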
@@ -0,0 +1,5 @@
+class ObjectNotFoundError(KeyError):
+    """Raised when an object is not found while loading"""
+
+    def __init__(self, name: str) -> None:
+        super().__init__(f"Object with name {name} not found")
@@ -0,0 +1,84 @@
+import typing
+from pathlib import Path
+from typing import Optional, TypeVar
+
+import torch
+
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
+from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
+from invokeai.app.util.misc import uuid_string
+
+T = TypeVar("T")
+
+
+class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]):
+    """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup.
+
+    :param output_folder: The folder where the objects will be stored
+    """
+
+    def __init__(self, output_dir: Path):
+        super().__init__()
+        self._output_dir = output_dir
+        self._output_dir.mkdir(parents=True, exist_ok=True)
+        self.__obj_class_name: Optional[str] = None
+
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+        self._delete_all()
+
+    def load(self, name: str) -> T:
+        file_path = self._get_path(name)
+        try:
+            return torch.load(file_path)  # pyright: ignore [reportUnknownMemberType]
+        except FileNotFoundError as e:
+            raise ObjectNotFoundError(name) from e
+
+    def save(self, obj: T) -> str:
+        name = self._new_name()
+        file_path = self._get_path(name)
+        torch.save(obj, file_path)  # pyright: ignore [reportUnknownMemberType]
+        return name
+
+    def delete(self, name: str) -> None:
+        file_path = self._get_path(name)
+        file_path.unlink()
+
+    @property
+    def _obj_class_name(self) -> str:
+        if not self.__obj_class_name:
+            # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
+            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
+        return self.__obj_class_name
+
+    def _get_path(self, name: str) -> Path:
+        return self._output_dir / name
+
+    def _new_name(self) -> str:
+        return f"{self._obj_class_name}_{uuid_string()}"
+
+    def _delete_all(self) -> None:
+        """
+        Deletes all objects from disk.
+        Must be called after we have access to `self._invoker` (e.g. in `start()`).
+        """
+
+        # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
+        # to manually clear them on startup anyways. This is a bit simpler and more reliable.
+
+        if not self._invoker:
+            raise ValueError("Invoker is not set. Must call `start()` first.")
+
+        deleted_count = 0
+        freed_space = 0
+        for file in Path(self._output_dir).glob("*"):
+            if file.is_file():
+                freed_space += file.stat().st_size
+                deleted_count += 1
+                file.unlink()
+        if deleted_count > 0:
+            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
+            self._invoker.services.logger.info(
+                f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
+            )
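`save` derives names from the generic type argument via `__orig_class__`, so the on-disk prefix reflects the stored type. A sketch of the resulting naming (this only works when the class is instantiated through the subscripted form, as it is in ApiDependencies):

    serializer = ObjectSerializerEphemeralDisk[torch.Tensor](Path("/tmp/tensors"))
    name = serializer.save(torch.zeros(2, 2))
    # name is e.g. "Tensor_<uuid>": typing.get_args(__orig_class__)[0].__name__ plus uuid_string()
    assert name.startswith("Tensor_")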
@@ -0,0 +1,61 @@
+from queue import Queue
+from typing import Optional, TypeVar
+
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
+
+T = TypeVar("T")
+
+
+class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
+    """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
+
+    def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
+        super().__init__()
+        self._underlying_storage = underlying_storage
+        self._cache: dict[str, T] = {}
+        self._cache_ids = Queue[str]()
+        self._max_cache_size = max_cache_size
+
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+        start_op = getattr(self._underlying_storage, "start", None)
+        if callable(start_op):
+            start_op(invoker)
+
+    def stop(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+        stop_op = getattr(self._underlying_storage, "stop", None)
+        if callable(stop_op):
+            stop_op(invoker)
+
+    def load(self, name: str) -> T:
+        cache_item = self._get_cache(name)
+        if cache_item is not None:
+            return cache_item
+
+        latent = self._underlying_storage.load(name)
+        self._set_cache(name, latent)
+        return latent
+
+    def save(self, obj: T) -> str:
+        name = self._underlying_storage.save(obj)
+        self._set_cache(name, obj)
+        self._on_saved(name, obj)
+        return name
+
+    def delete(self, name: str) -> None:
+        self._underlying_storage.delete(name)
+        if name in self._cache:
+            del self._cache[name]
+        self._on_deleted(name)
+
+    def _get_cache(self, name: str) -> Optional[T]:
+        return None if name not in self._cache else self._cache[name]
+
+    def _set_cache(self, name: str, data: T):
+        if name not in self._cache:
+            self._cache[name] = data
+            self._cache_ids.put(name)
+            if self._cache_ids.qsize() > self._max_cache_size:
+                self._cache.pop(self._cache_ids.get())
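Note the eviction policy in _set_cache: names are queued on first insert and evicted in insertion order, with no reordering on load, so the behavior is effectively FIFO despite the docstring's "LRU" wording. A sketch of the consequence (harmless, since evicted objects fall through to the underlying storage):

    cache = ObjectSerializerForwardCache(
        ObjectSerializerEphemeralDisk[torch.Tensor](Path("/tmp/t")), max_cache_size=2
    )
    a = cache.save(torch.ones(1))
    b = cache.save(torch.ones(2))
    c = cache.save(torch.ones(3))  # `a` is evicted from memory (oldest insert)
    cache.load(a)                  # still succeeds: loaded from disk, re-cached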
@@ -223,16 +223,16 @@ class TensorsInterface(InvocationContextInterface):
         :param tensor: The tensor to save.
         """
 
-        name = self._services.tensors.set(item=tensor)
-        return name
+        tensor_id = self._services.tensors.save(obj=tensor)
+        return tensor_id
 
-    def get(self, tensor_name: str) -> Tensor:
+    def load(self, name: str) -> Tensor:
         """
-        Gets a tensor by name.
+        Loads a tensor by name.
 
-        :param tensor_name: The name of the tensor to get.
+        :param name: The name of the tensor to load.
         """
-        return self._services.tensors.get(tensor_name)
+        return self._services.tensors.load(name)
 
 
 class ConditioningInterface(InvocationContextInterface):
@@ -243,17 +243,17 @@ class ConditioningInterface(InvocationContextInterface):
         :param conditioning_context_data: The conditioning data to save.
         """
 
-        name = self._services.conditioning.set(item=conditioning_data)
-        return name
+        conditioning_id = self._services.conditioning.save(obj=conditioning_data)
+        return conditioning_id
 
-    def get(self, conditioning_name: str) -> ConditioningFieldData:
+    def load(self, name: str) -> ConditioningFieldData:
         """
-        Gets conditioning data by name.
+        Loads conditioning data by name.
 
-        :param conditioning_name: The name of the conditioning data to get.
+        :param name: The name of the conditioning data to load.
         """
 
-        return self._services.conditioning.get(conditioning_name)
+        return self._services.conditioning.load(name)
 
 
 class ModelsInterface(InvocationContextInterface):
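From a node's perspective, the public context API is now save/load. A sketch of the round trip inside an invocation, matching the renames above (hypothetical node body):

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.tensors.load(self.latents.latents_name)  # was .get(...)
        scaled = latents * 0.5
        name = context.tensors.save(scaled)                        # was .set(item=...)
        ...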