From 7cb8e297268b41379b75c1197f9725bac3992227 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:39:03 +1100 Subject: [PATCH] feat(nodes): use `ItemStorageABC` for tensors and conditioning Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both. There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution. The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items. --- invokeai/app/api/dependencies.py | 10 +-- invokeai/app/services/invocation_services.py | 5 +- .../item_storage/item_storage_base.py | 2 +- .../item_storage_ephemeral_disk.py | 72 +++++++++++++++++++ .../item_storage_forward_cache.py | 61 ++++++++++++++++ .../item_storage/item_storage_memory.py | 3 +- .../pickle_storage/pickle_storage_base.py | 45 ------------ .../pickle_storage_forward_cache.py | 58 --------------- .../pickle_storage/pickle_storage_torch.py | 63 ---------------- .../app/services/shared/invocation_context.py | 30 +------- 10 files changed, 145 insertions(+), 204 deletions(-) create mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py create mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_base.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6bb0915cb6..d6fd970a22 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ from logging import Logger import torch +from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk +from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache -from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ class ApiDependencies: image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) - conditioning = PickleStorageForwardCache( - PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ItemStorageForwardCache( + ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 81885781ac..69599d83a4 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase - from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -66,8 +65,8 @@ class InvocationServices: names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", - tensors: "PickleStorageBase[torch.Tensor]", - conditioning: "PickleStorageBase[ConditioningFieldData]", + tensors: "ItemStorageABC[torch.Tensor]", + conditioning: "ItemStorageABC[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index d736679159..f2d62ea45f 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -26,7 +26,7 @@ class ItemStorageABC(ABC, Generic[T]): pass @abstractmethod - def set(self, item: T) -> None: + def set(self, item: T) -> str: """ Sets the item. The id will be extracted based on id_field. :param item: the item to set diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py new file mode 100644 index 0000000000..9843d1e54b --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -0,0 +1,72 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ItemStorageEphemeralDisk(ItemStorageABC[T]): + """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" + + def __init__(self, output_folder: Path): + super().__init__() + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self.__item_class_name: Optional[str] = None + + @property + def _item_class_name(self) -> str: + if not self.__item_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__item_class_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, item_id: str) -> T: + file_path = self._get_path(item_id) + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + + def set(self, item: T) -> str: + self._output_folder.mkdir(parents=True, exist_ok=True) + item_id = f"{self._item_class_name}_{uuid_string()}" + file_path = self._get_path(item_id) + torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + return item_id + + def delete(self, item_id: str) -> None: + file_path = self._get_path(item_id) + file_path.unlink() + + def _get_path(self, item_id: str) -> Path: + return self._output_folder / item_id + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py new file mode 100644 index 0000000000..d1fe8e13fa --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC + +T = TypeVar("T") + + +class ItemStorageForwardCache(ItemStorageABC[T]): + """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, item_id: str) -> T: + cache_item = self._get_cache(item_id) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(item_id) + self._set_cache(item_id, latent) + return latent + + def set(self, item: T) -> str: + item_id = self._underlying_storage.set(item) + self._set_cache(item_id, item) + self._on_changed(item) + return item_id + + def delete(self, item_id: str) -> None: + self._underlying_storage.delete(item_id) + if item_id in self._cache: + del self._cache[item_id] + self._on_deleted(item_id) + + def _get_cache(self, item_id: str) -> Optional[T]: + return None if item_id not in self._cache else self._cache[item_id] + + def _set_cache(self, item_id: str, data: T): + if item_id not in self._cache: + self._cache[item_id] = data + self._cache_ids.put(item_id) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index d8dd0e0664..6d02874516 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]): self._items[item_id] = item return item - def set(self, item: T) -> None: + def set(self, item: T) -> str: item_id = getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,6 +44,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]): self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) + return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. diff --git a/invokeai/app/services/pickle_storage/pickle_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py deleted file mode 100644 index 558b97c0f1..0000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_base.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Callable, Generic, TypeVar - -T = TypeVar("T") - - -class PickleStorageBase(ABC, Generic[T]): - """Responsible for storing and retrieving non-serializable data using a pickler.""" - - _on_changed_callbacks: list[Callable[[T], None]] - _on_deleted_callbacks: list[Callable[[str], None]] - - def __init__(self) -> None: - self._on_changed_callbacks = [] - self._on_deleted_callbacks = [] - - @abstractmethod - def get(self, name: str) -> T: - pass - - @abstractmethod - def save(self, name: str, data: T) -> None: - pass - - @abstractmethod - def delete(self, name: str) -> None: - pass - - def on_changed(self, on_changed: Callable[[T], None]) -> None: - """Register a callback for when an item is changed""" - self._on_changed_callbacks.append(on_changed) - - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: - """Register a callback for when an item is deleted""" - self._on_deleted_callbacks.append(on_deleted) - - def _on_changed(self, item: T) -> None: - for callback in self._on_changed_callbacks: - callback(item) - - def _on_deleted(self, item_id: str) -> None: - for callback in self._on_deleted_callbacks: - callback(item_id) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py deleted file mode 100644 index 3002d9e045..0000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageForwardCache(PickleStorageBase[T]): - def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> T: - cache_item = self._get_cache(name) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(name) - self._set_cache(name, latent) - return latent - - def save(self, name: str, data: T) -> None: - self._underlying_storage.save(name, data) - self._set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] - self._on_deleted(name) - - def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] - - def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py deleted file mode 100644 index 16f0d7bb7a..0000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageTorch(PickleStorageBase[T]): - """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" - - def __init__(self, output_folder: Path, item_type_name: "str"): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._item_type_name = item_type_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, name: str) -> T: - file_path = self._get_path(name) - return torch.load(file_path) - - def save(self, name: str, data: T) -> None: - self._output_folder.mkdir(parents=True, exist_ok=True) - file_path = self._get_path(name) - torch.save(data, file_path) - - def delete(self, name: str) -> None: - file_path = self._get_path(name) - file_path.unlink() - - def _get_path(self, name: str) -> Path: - return self._output_folder / name - - def _delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6756b1f5c6..baff47a3df 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.util.misc import uuid_string from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -224,26 +223,7 @@ class TensorsInterface(InvocationContextInterface): :param tensor: The tensor to save. """ - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save tensors, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all tensors. - # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their tensors. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the tensors file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.tensors.save( - name=name, - data=tensor, - ) + name = self._services.tensors.set(item=tensor) return name def get(self, tensor_name: str) -> Tensor: @@ -263,13 +243,7 @@ class ConditioningInterface(InvocationContextInterface): :param conditioning_context_data: The conditioning data to save. """ - # See comment in TensorsInterface.save for why we generate the name here. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.conditioning.save( - name=name, - data=conditioning_data, - ) + name = self._services.conditioning.set(item=conditioning_data) return name def get(self, conditioning_name: str) -> ConditioningFieldData: