feat(nodes): use ItemStorageABC for tensors and conditioning

Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both.

There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution.

The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
This commit is contained in:
psychedelicious 2024-02-07 19:39:03 +11:00 committed by Brandon Rising
parent f593959bea
commit 7cb8e29726
10 changed files with 145 additions and 204 deletions

View File

@ -4,9 +4,9 @@ from logging import Logger
import torch
from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache
from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -90,9 +90,9 @@ class ApiDependencies:
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor"))
conditioning = PickleStorageForwardCache(
PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning")
tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors"))
conditioning = ItemStorageForwardCache(
ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)

View File

@ -29,7 +29,6 @@ if TYPE_CHECKING:
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .pickle_storage.pickle_storage_base import PickleStorageBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState
@ -66,8 +65,8 @@ class InvocationServices:
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "PickleStorageBase[torch.Tensor]",
conditioning: "PickleStorageBase[ConditioningFieldData]",
tensors: "ItemStorageABC[torch.Tensor]",
conditioning: "ItemStorageABC[ConditioningFieldData]",
):
self.board_images = board_images
self.board_image_records = board_image_records

View File

@ -26,7 +26,7 @@ class ItemStorageABC(ABC, Generic[T]):
pass
@abstractmethod
def set(self, item: T) -> None:
def set(self, item: T) -> str:
"""
Sets the item. The id will be extracted based on id_field.
:param item: the item to set

View File

@ -0,0 +1,72 @@
import typing
from pathlib import Path
from typing import Optional, TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from invokeai.app.util.misc import uuid_string
T = TypeVar("T")
class ItemStorageEphemeralDisk(ItemStorageABC[T]):
"""Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup."""
def __init__(self, output_folder: Path):
super().__init__()
self._output_folder = output_folder
self._output_folder.mkdir(parents=True, exist_ok=True)
self.__item_class_name: Optional[str] = None
@property
def _item_class_name(self) -> str:
if not self.__item_class_name:
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
return self.__item_class_name
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_items()
def get(self, item_id: str) -> T:
file_path = self._get_path(item_id)
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
def set(self, item: T) -> str:
self._output_folder.mkdir(parents=True, exist_ok=True)
item_id = f"{self._item_class_name}_{uuid_string()}"
file_path = self._get_path(item_id)
torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType]
return item_id
def delete(self, item_id: str) -> None:
file_path = self._get_path(item_id)
file_path.unlink()
def _get_path(self, item_id: str) -> Path:
return self._output_folder / item_id
def _delete_all_items(self) -> None:
"""
Deletes all pickled items from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
if not self._invoker:
raise ValueError("Invoker is not set. Must call `start()` first.")
deleted_count = 0
freed_space = 0
for file in Path(self._output_folder).glob("*"):
if file.is_file():
freed_space += file.stat().st_size
deleted_count += 1
file.unlink()
if deleted_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)"
)

View File

@ -0,0 +1,61 @@
from queue import Queue
from typing import Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
T = TypeVar("T")
class ItemStorageForwardCache(ItemStorageABC[T]):
"""Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20):
super().__init__()
self._underlying_storage = underlying_storage
self._cache: dict[str, T] = {}
self._cache_ids = Queue[str]()
self._max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self._underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self._underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, item_id: str) -> T:
cache_item = self._get_cache(item_id)
if cache_item is not None:
return cache_item
latent = self._underlying_storage.get(item_id)
self._set_cache(item_id, latent)
return latent
def set(self, item: T) -> str:
item_id = self._underlying_storage.set(item)
self._set_cache(item_id, item)
self._on_changed(item)
return item_id
def delete(self, item_id: str) -> None:
self._underlying_storage.delete(item_id)
if item_id in self._cache:
del self._cache[item_id]
self._on_deleted(item_id)
def _get_cache(self, item_id: str) -> Optional[T]:
return None if item_id not in self._cache else self._cache[item_id]
def _set_cache(self, item_id: str, data: T):
if item_id not in self._cache:
self._cache[item_id] = data
self._cache_ids.put(item_id)
if self._cache_ids.qsize() > self._max_cache_size:
self._cache.pop(self._cache_ids.get())

View File

@ -34,7 +34,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
self._items[item_id] = item
return item
def set(self, item: T) -> None:
def set(self, item: T) -> str:
item_id = getattr(item, self._id_field)
if item_id in self._items:
# If item already exists, remove it and add it to the end
@ -44,6 +44,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
self._items.popitem(last=False)
self._items[item_id] = item
self._on_changed(item)
return item_id
def delete(self, item_id: str) -> None:
# This is a no-op if the item doesn't exist.

View File

@ -1,45 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar
T = TypeVar("T")
class PickleStorageBase(ABC, Generic[T]):
"""Responsible for storing and retrieving non-serializable data using a pickler."""
_on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
@abstractmethod
def get(self, name: str) -> T:
pass
@abstractmethod
def save(self, name: str, data: T) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: T) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)

View File

@ -1,58 +0,0 @@
from queue import Queue
from typing import Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
T = TypeVar("T")
class PickleStorageForwardCache(PickleStorageBase[T]):
def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20):
super().__init__()
self._underlying_storage = underlying_storage
self._cache: dict[str, T] = {}
self._cache_ids = Queue[str]()
self._max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self._underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self._underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, name: str) -> T:
cache_item = self._get_cache(name)
if cache_item is not None:
return cache_item
latent = self._underlying_storage.get(name)
self._set_cache(name, latent)
return latent
def save(self, name: str, data: T) -> None:
self._underlying_storage.save(name, data)
self._set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None:
self._underlying_storage.delete(name)
if name in self._cache:
del self._cache[name]
self._on_deleted(name)
def _get_cache(self, name: str) -> Optional[T]:
return None if name not in self._cache else self._cache[name]
def _set_cache(self, name: str, data: T):
if name not in self._cache:
self._cache[name] = data
self._cache_ids.put(name)
if self._cache_ids.qsize() > self._max_cache_size:
self._cache.pop(self._cache_ids.get())

View File

@ -1,63 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
T = TypeVar("T")
class PickleStorageTorch(PickleStorageBase[T]):
"""Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`."""
def __init__(self, output_folder: Path, item_type_name: "str"):
super().__init__()
self._output_folder = output_folder
self._output_folder.mkdir(parents=True, exist_ok=True)
self._item_type_name = item_type_name
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_items()
def get(self, name: str) -> T:
file_path = self._get_path(name)
return torch.load(file_path)
def save(self, name: str, data: T) -> None:
self._output_folder.mkdir(parents=True, exist_ok=True)
file_path = self._get_path(name)
torch.save(data, file_path)
def delete(self, name: str) -> None:
file_path = self._get_path(name)
file_path.unlink()
def _get_path(self, name: str) -> Path:
return self._output_folder / name
def _delete_all_items(self) -> None:
"""
Deletes all pickled items from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
if not self._invoker:
raise ValueError("Invoker is not set. Must call `start()` first.")
deleted_count = 0
freed_space = 0
for file in Path(self._output_folder).glob("*"):
if file.is_file():
freed_space += file.stat().st_size
deleted_count += 1
file.unlink()
if deleted_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)"
)

View File

@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
@ -224,26 +223,7 @@ class TensorsInterface(InvocationContextInterface):
:param tensor: The tensor to save.
"""
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
# "mask", "noise", "masked_latents", etc.
#
# Retaining that capability in this wrapper would require either many different methods
# to save tensors, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all tensors.
#
# This has a very minor impact as we don't use them after a session completes.
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
# will generate a name for them instead. We use a uuid to ensure the name is unique.
#
# Because the name of the tensors file will includes the session and invocation IDs,
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.tensors.save(
name=name,
data=tensor,
)
name = self._services.tensors.set(item=tensor)
return name
def get(self, tensor_name: str) -> Tensor:
@ -263,13 +243,7 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_context_data: The conditioning data to save.
"""
# See comment in TensorsInterface.save for why we generate the name here.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.conditioning.save(
name=name,
data=conditioning_data,
)
name = self._services.conditioning.set(item=conditioning_data)
return name
def get(self, conditioning_name: str) -> ConditioningFieldData: