revert(nodes): revert making tensors/conditioning use item storage

Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning.

Refined the class a bit too.
This commit is contained in:
psychedelicious 2024-02-07 23:30:46 +11:00
parent 73d871116c
commit 9f382419dc
14 changed files with 243 additions and 205 deletions

View File

@ -4,9 +4,9 @@ from logging import Logger
import torch
from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -90,9 +90,9 @@ class ApiDependencies:
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors"))
conditioning = ItemStorageForwardCache(
ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors"))
conditioning = ObjectSerializerForwardCache(
ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)

View File

@ -304,11 +304,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet,
seed,
) -> ConditioningData:
positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name)
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name)
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.denoise_mask is None:
return None, None
mask = context.tensors.get(self.denoise_mask.mask_name)
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
else:
masked_latents = None
@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed = None
noise = None
if self.noise is not None:
noise = context.tensors.get(self.noise.latents_name)
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.tensors.get(self.latents.latents_name)
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.get(self.latents.latents_name)
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.get(self.latents.latents_name)
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.get(self.latents.latents_name)
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.get(self.latents_a.latents_name)
latents_b = context.tensors.get(self.latents_b.latents_name)
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.get(self.latents.latents_name)
latents = context.tensors.load(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR

View File

@ -344,7 +344,7 @@ class LatentsInvocation(BaseInvocation):
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.get(self.latents.latents_name)
latents = context.tensors.load(self.latents.latents_name)
return LatentsOutput.build(self.latents.latents_name, latents)

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
if TYPE_CHECKING:
from logging import Logger
@ -65,8 +67,8 @@ class InvocationServices:
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ItemStorageABC[torch.Tensor]",
conditioning: "ItemStorageABC[ConditioningFieldData]",
tensors: "ObjectSerializerBase[torch.Tensor]",
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
):
self.board_images = board_images
self.board_image_records = board_image_records

View File

@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar
T = TypeVar("T")
from pydantic import BaseModel
T = TypeVar("T", bound=BaseModel)
class ItemStorageABC(ABC, Generic[T]):
@ -26,9 +28,9 @@ class ItemStorageABC(ABC, Generic[T]):
pass
@abstractmethod
def set(self, item: T) -> str:
def set(self, item: T) -> None:
"""
Sets the item. The id will be extracted based on id_field.
Sets the item.
:param item: the item to set
"""
pass

View File

@ -1,15 +1,5 @@
from pathlib import Path
from typing import Callable, TypeAlias, TypeVar
class ItemNotFoundError(KeyError):
    """Signals that no item is stored under the requested id."""

    def __init__(self, item_id: str) -> None:
        # Build the message first so the exception carries a single, readable argument.
        message = f"Item with id {item_id} not found"
        super().__init__(message)
# Generic placeholder for the type of item being stored.
T = TypeVar("T")
# Signature of a callable that takes an item of type T and a destination path and returns nothing.
SaveFunc: TypeAlias = Callable[[T, Path], None]
# Signature of a callable that takes a path and returns an item of type T.
LoadFunc: TypeAlias = Callable[[Path], T]

View File

@ -1,97 +0,0 @@
import typing
from pathlib import Path
from typing import Optional, Type, TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc
from invokeai.app.util.misc import uuid_string
T = TypeVar("T")
class ItemStorageEphemeralDisk(ItemStorageABC[T]):
    """Provides a disk-backed ephemeral storage. The storage is cleared at startup.

    The class must be instantiated through a subscripted generic (e.g. `ItemStorageEphemeralDisk[torch.Tensor](...)`)
    because the type argument's name is used to build item ids and log messages.

    :param output_folder: The folder where the items will be stored
    :param save: The function to use to save the items to disk [torch.save]
    :param load: The function to use to load the items from disk [torch.load]
    :param load_exc: The exception that is raised when an item is not found [FileNotFoundError]
    """

    def __init__(
        self,
        output_folder: Path,
        save: SaveFunc[T] = torch.save,  # pyright: ignore [reportUnknownMemberType]
        load: LoadFunc[T] = torch.load,  # pyright: ignore [reportUnknownMemberType]
        load_exc: Type[Exception] = FileNotFoundError,
    ):
        super().__init__()
        self._output_folder = output_folder
        self._output_folder.mkdir(parents=True, exist_ok=True)
        self._save = save
        self._load = load
        self._load_exc = load_exc
        # Assigned in `start()`. Initialised here so `_delete_all_items()` can detect a missing invoker and raise
        # its documented ValueError instead of an AttributeError.
        self._invoker: Optional[Invoker] = None
        self.__item_class_name: Optional[str] = None

    @property
    def _item_class_name(self) -> str:
        """The name of the class this storage was parameterised with, resolved lazily on first access."""
        if not self.__item_class_name:
            # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
            self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
        return self.__item_class_name

    def start(self, invoker: Invoker) -> None:
        """Stores the invoker and clears every file left in the output folder."""
        self._invoker = invoker
        self._delete_all_items()

    def get(self, item_id: str) -> T:
        """
        Loads the item stored under `item_id`.

        :param item_id: The id of the item to load.
        :raises ItemNotFoundError: if loading raises the configured `load_exc`
        """
        file_path = self._get_path(item_id)
        try:
            return self._load(file_path)
        except self._load_exc as e:
            raise ItemNotFoundError(item_id) from e

    def set(self, item: T) -> str:
        """
        Saves the item under a freshly generated id and returns that id.

        :param item: The item to save.
        """
        # Re-create the folder in case it was removed since construction.
        self._output_folder.mkdir(parents=True, exist_ok=True)
        item_id = self._new_item_id()
        file_path = self._get_path(item_id)
        self._save(item, file_path)
        return item_id

    def delete(self, item_id: str) -> None:
        """
        Deletes the item's file. This is a no-op if the file doesn't exist.

        :param item_id: The id of the item to delete.
        """
        file_path = self._get_path(item_id)
        file_path.unlink(missing_ok=True)

    def _get_path(self, item_id: str) -> Path:
        """Returns the path of the file backing `item_id`."""
        return self._output_folder / item_id

    def _new_item_id(self) -> str:
        """Builds a new id of the form `<item class name>_<uuid>`."""
        return f"{self._item_class_name}_{uuid_string()}"

    def _delete_all_items(self) -> None:
        """
        Deletes all pickled items from disk.
        Must be called after we have access to `self._invoker` (e.g. in `start()`).

        :raises ValueError: if the invoker has not been set
        """

        # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
        # to manually clear them on startup anyways. This is a bit simpler and more reliable.

        if not self._invoker:
            raise ValueError("Invoker is not set. Must call `start()` first.")

        deleted_count = 0
        freed_space = 0
        for file in Path(self._output_folder).glob("*"):
            if file.is_file():
                # Read the size before unlinking; it is unavailable afterwards.
                freed_space += file.stat().st_size
                deleted_count += 1
                file.unlink()
        if deleted_count > 0:
            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
            self._invoker.services.logger.info(
                f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)"
            )

View File

@ -1,61 +0,0 @@
from queue import Queue
from typing import Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
T = TypeVar("T")
class ItemStorageForwardCache(ItemStorageABC[T]):
    """Provides a simple forward cache for an underlying storage.

    The cache holds at most `max_cache_size` items and evicts the oldest inserted item first (FIFO). Reading a
    cached item does not refresh its position.

    :param underlying_storage: The storage that items are written to and read from on a cache miss
    :param max_cache_size: The maximum number of items kept in memory
    """

    def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20):
        super().__init__()
        self._underlying_storage = underlying_storage
        # dicts preserve insertion order, so the first key is always the oldest cached item. Tracking order in the
        # dict itself (rather than in a separate queue) keeps eviction consistent after `delete()`.
        self._cache: dict[str, T] = {}
        self._max_cache_size = max_cache_size

    def start(self, invoker: Invoker) -> None:
        """Stores the invoker and forwards `start` to the underlying storage, if it defines one."""
        self._invoker = invoker
        start_op = getattr(self._underlying_storage, "start", None)
        if callable(start_op):
            start_op(invoker)

    def stop(self, invoker: Invoker) -> None:
        """Stores the invoker and forwards `stop` to the underlying storage, if it defines one."""
        self._invoker = invoker
        stop_op = getattr(self._underlying_storage, "stop", None)
        if callable(stop_op):
            stop_op(invoker)

    def get(self, item_id: str) -> T:
        """
        Returns the item from the cache, falling back to (and then caching from) the underlying storage.

        :param item_id: The id of the item to get.
        """
        cache_item = self._get_cache(item_id)
        if cache_item is not None:
            return cache_item

        item = self._underlying_storage.get(item_id)
        self._set_cache(item_id, item)
        return item

    def set(self, item: T) -> str:
        """
        Writes the item to the underlying storage, caches it, and returns its id.

        :param item: The item to set.
        """
        item_id = self._underlying_storage.set(item)
        self._set_cache(item_id, item)
        self._on_changed(item)
        return item_id

    def delete(self, item_id: str) -> None:
        """
        Deletes the item from the underlying storage and from the cache.

        :param item_id: The id of the item to delete.
        """
        self._underlying_storage.delete(item_id)
        self._cache.pop(item_id, None)
        self._on_deleted(item_id)

    def _get_cache(self, item_id: str) -> Optional[T]:
        """Returns the cached item, or None if it is not cached."""
        return self._cache.get(item_id)

    def _set_cache(self, item_id: str, data: T):
        """Caches `data` under `item_id` if it is not cached yet, evicting the oldest entry when over capacity."""
        if item_id not in self._cache:
            self._cache[item_id] = data
            if len(self._cache) > self._max_cache_size:
                # The first key in iteration order is the oldest insertion.
                self._cache.pop(next(iter(self._cache)))

View File

@ -34,7 +34,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
self._items[item_id] = item
return item
def set(self, item: T) -> str:
def set(self, item: T) -> None:
item_id = getattr(item, self._id_field)
if item_id in self._items:
# If item already exists, remove it and add it to the end
@ -44,7 +44,6 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
self._items.popitem(last=False)
self._items[item_id] = item
self._on_changed(item)
return item_id
def delete(self, item_id: str) -> None:
# This is a no-op if the item doesn't exist.

View File

@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar
T = TypeVar("T")


class ObjectSerializerBase(ABC, Generic[T]):
    """Abstract interface for services that persist and retrieve arbitrary python objects by name.

    Subclasses implement `load`, `save` and `delete`. Listeners may subscribe to save and delete events; it is up
    to the subclass to fire them via `_on_saved` and `_on_deleted`.
    """

    def __init__(self) -> None:
        # Listeners are invoked in registration order.
        self._on_saved_callbacks: list[Callable[[str, T], None]] = []
        self._on_deleted_callbacks: list[Callable[[str], None]] = []

    @abstractmethod
    def load(self, name: str) -> T:
        """
        Retrieves a previously saved object.

        :param name: The name of the object to load.
        :raises ObjectNotFoundError: if the object is not found
        """

    @abstractmethod
    def save(self, obj: T) -> str:
        """
        Persists the object and returns the name it was stored under.

        :param obj: The object to save.
        """

    @abstractmethod
    def delete(self, name: str) -> None:
        """
        Removes the object, if it exists.

        :param name: The name of the object to delete.
        """

    def on_saved(self, on_saved: Callable[[str, T], None]) -> None:
        """Subscribe a callback that receives `(name, obj)` whenever an object is saved"""
        self._on_saved_callbacks.append(on_saved)

    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
        """Subscribe a callback that receives `name` whenever an object is deleted"""
        self._on_deleted_callbacks.append(on_deleted)

    def _on_saved(self, name: str, obj: T) -> None:
        """Notify every save listener."""
        for notify in self._on_saved_callbacks:
            notify(name, obj)

    def _on_deleted(self, name: str) -> None:
        """Notify every delete listener."""
        for notify in self._on_deleted_callbacks:
            notify(name)

View File

@ -0,0 +1,5 @@
class ObjectNotFoundError(KeyError):
    """Signals that no object could be loaded under the requested name."""

    def __init__(self, name: str) -> None:
        # Build the message first so the exception carries a single, readable argument.
        message = f"Object with name {name} not found"
        super().__init__(message)

View File

@ -0,0 +1,84 @@
import typing
from pathlib import Path
from typing import Optional, TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
T = TypeVar("T")
class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]):
    """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup.

    Objects are written with `torch.save` and read back with `torch.load`. The class must be instantiated through a
    subscripted generic (e.g. `ObjectSerializerEphemeralDisk[torch.Tensor](...)`) because the type argument's name is
    used to build object names and log messages.

    :param output_dir: The folder where the objects will be stored
    """

    def __init__(self, output_dir: Path):
        super().__init__()
        self._output_dir = output_dir
        self._output_dir.mkdir(parents=True, exist_ok=True)
        # Assigned in `start()`. Initialised here so `_delete_all()` can detect a missing invoker and raise its
        # documented ValueError instead of an AttributeError.
        self._invoker: Optional[Invoker] = None
        self.__obj_class_name: Optional[str] = None

    def start(self, invoker: Invoker) -> None:
        """Stores the invoker and clears every file left in the output directory."""
        self._invoker = invoker
        self._delete_all()

    def load(self, name: str) -> T:
        """
        Loads the object stored under `name`.

        :param name: The name of the object to load.
        :raises ObjectNotFoundError: if no file exists for `name`
        """
        file_path = self._get_path(name)
        try:
            # NOTE(review): `torch.load` unpickles the file, so the output directory must only contain trusted data.
            return torch.load(file_path)  # pyright: ignore [reportUnknownMemberType]
        except FileNotFoundError as e:
            raise ObjectNotFoundError(name) from e

    def save(self, obj: T) -> str:
        """
        Saves the object under a freshly generated name and returns that name.

        :param obj: The object to save.
        """
        name = self._new_name()
        file_path = self._get_path(name)
        torch.save(obj, file_path)  # pyright: ignore [reportUnknownMemberType]
        return name

    def delete(self, name: str) -> None:
        """
        Deletes the object's file, if it exists.

        :param name: The name of the object to delete.
        """
        file_path = self._get_path(name)
        # The base-class contract is "deletes the object, if it exists", so a missing file is not an error.
        file_path.unlink(missing_ok=True)

    @property
    def _obj_class_name(self) -> str:
        """The name of the class this serializer was parameterised with, resolved lazily on first access."""
        if not self.__obj_class_name:
            # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
        return self.__obj_class_name

    def _get_path(self, name: str) -> Path:
        """Returns the path of the file backing `name`."""
        return self._output_dir / name

    def _new_name(self) -> str:
        """Builds a new name of the form `<object class name>_<uuid>`."""
        return f"{self._obj_class_name}_{uuid_string()}"

    def _delete_all(self) -> None:
        """
        Deletes all objects from disk.
        Must be called after we have access to `self._invoker` (e.g. in `start()`).

        :raises ValueError: if the invoker has not been set
        """

        # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
        # to manually clear them on startup anyways. This is a bit simpler and more reliable.

        if not self._invoker:
            raise ValueError("Invoker is not set. Must call `start()` first.")

        deleted_count = 0
        freed_space = 0
        for file in Path(self._output_dir).glob("*"):
            if file.is_file():
                # Read the size before unlinking; it is unavailable afterwards.
                freed_space += file.stat().st_size
                deleted_count += 1
                file.unlink()
        if deleted_count > 0:
            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
            self._invoker.services.logger.info(
                f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
            )

View File

@ -0,0 +1,61 @@
from queue import Queue
from typing import Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
T = TypeVar("T")
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
    """Provides a simple forward cache for an underlying storage.

    The cache holds at most `max_cache_size` objects and evicts the oldest inserted object first (FIFO). Loading a
    cached object does not refresh its position.

    :param underlying_storage: The serializer that objects are written to and read from on a cache miss
    :param max_cache_size: The maximum number of objects kept in memory
    """

    def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
        super().__init__()
        self._underlying_storage = underlying_storage
        # dicts preserve insertion order, so the first key is always the oldest cached object. Tracking order in the
        # dict itself (rather than in a separate queue) keeps eviction consistent after `delete()`: a separate queue
        # would retain the deleted name and later fail when trying to evict it.
        self._cache: dict[str, T] = {}
        self._max_cache_size = max_cache_size

    def start(self, invoker: Invoker) -> None:
        """Stores the invoker and forwards `start` to the underlying storage, if it defines one."""
        self._invoker = invoker
        start_op = getattr(self._underlying_storage, "start", None)
        if callable(start_op):
            start_op(invoker)

    def stop(self, invoker: Invoker) -> None:
        """Stores the invoker and forwards `stop` to the underlying storage, if it defines one."""
        self._invoker = invoker
        stop_op = getattr(self._underlying_storage, "stop", None)
        if callable(stop_op):
            stop_op(invoker)

    def load(self, name: str) -> T:
        """
        Returns the object from the cache, falling back to (and then caching from) the underlying storage.

        :param name: The name of the object to load.
        """
        cache_item = self._get_cache(name)
        if cache_item is not None:
            return cache_item

        obj = self._underlying_storage.load(name)
        self._set_cache(name, obj)
        return obj

    def save(self, obj: T) -> str:
        """
        Saves the object to the underlying storage, caches it, notifies save listeners, and returns its name.

        :param obj: The object to save.
        """
        name = self._underlying_storage.save(obj)
        self._set_cache(name, obj)
        self._on_saved(name, obj)
        return name

    def delete(self, name: str) -> None:
        """
        Deletes the object from the underlying storage and from the cache, then notifies delete listeners.

        :param name: The name of the object to delete.
        """
        self._underlying_storage.delete(name)
        self._cache.pop(name, None)
        self._on_deleted(name)

    def _get_cache(self, name: str) -> Optional[T]:
        """Returns the cached object, or None if it is not cached."""
        return self._cache.get(name)

    def _set_cache(self, name: str, data: T):
        """Caches `data` under `name` if it is not cached yet, evicting the oldest entry when over capacity."""
        if name not in self._cache:
            self._cache[name] = data
            if len(self._cache) > self._max_cache_size:
                # The first key in iteration order is the oldest insertion.
                self._cache.pop(next(iter(self._cache)))

View File

@ -223,16 +223,16 @@ class TensorsInterface(InvocationContextInterface):
:param tensor: The tensor to save.
"""
name = self._services.tensors.set(item=tensor)
return name
tensor_id = self._services.tensors.save(obj=tensor)
return tensor_id
def get(self, tensor_name: str) -> Tensor:
def load(self, name: str) -> Tensor:
"""
Gets a tensor by name.
Loads a tensor by name.
:param tensor_name: The name of the tensor to get.
:param name: The name of the tensor to load.
"""
return self._services.tensors.get(tensor_name)
return self._services.tensors.load(name)
class ConditioningInterface(InvocationContextInterface):
@ -243,17 +243,17 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_context_data: The conditioning data to save.
"""
name = self._services.conditioning.set(item=conditioning_data)
return name
conditioning_id = self._services.conditioning.save(obj=conditioning_data)
return conditioning_id
def get(self, conditioning_name: str) -> ConditioningFieldData:
def load(self, name: str) -> ConditioningFieldData:
"""
Gets conditioning data by name.
Loads conditioning data by name.
:param conditioning_name: The name of the conditioning data to get.
:param name: The name of the conditioning data to load.
"""
return self._services.conditioning.get(conditioning_name)
return self._services.conditioning.load(name)
class ModelsInterface(InvocationContextInterface):