mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): use ItemStorageABC
for tensors and conditioning
Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both. There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution. The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
This commit is contained in:
parent
ca09bd63a3
commit
a50c7c1cd7
@ -4,9 +4,9 @@ from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
|
||||
from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@ -90,9 +90,9 @@ class ApiDependencies:
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor"))
|
||||
conditioning = PickleStorageForwardCache(
|
||||
PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning")
|
||||
tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors"))
|
||||
conditioning = ItemStorageForwardCache(
|
||||
ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
|
@ -29,7 +29,6 @@ if TYPE_CHECKING:
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .pickle_storage.pickle_storage_base import PickleStorageBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState
|
||||
@ -66,8 +65,8 @@ class InvocationServices:
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "PickleStorageBase[torch.Tensor]",
|
||||
conditioning: "PickleStorageBase[ConditioningFieldData]",
|
||||
tensors: "ItemStorageABC[torch.Tensor]",
|
||||
conditioning: "ItemStorageABC[ConditioningFieldData]",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
|
@ -26,7 +26,7 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, item: T) -> None:
|
||||
def set(self, item: T) -> str:
|
||||
"""
|
||||
Sets the item. The id will be extracted based on id_field.
|
||||
:param item: the item to set
|
||||
|
@ -0,0 +1,72 @@
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ItemStorageEphemeralDisk(ItemStorageABC[T]):
|
||||
"""Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup."""
|
||||
|
||||
def __init__(self, output_folder: Path):
|
||||
super().__init__()
|
||||
self._output_folder = output_folder
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
self.__item_class_name: Optional[str] = None
|
||||
|
||||
@property
|
||||
def _item_class_name(self) -> str:
|
||||
if not self.__item_class_name:
|
||||
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
|
||||
self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
|
||||
return self.__item_class_name
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all_items()
|
||||
|
||||
def get(self, item_id: str) -> T:
|
||||
file_path = self._get_path(item_id)
|
||||
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
|
||||
def set(self, item: T) -> str:
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
item_id = f"{self._item_class_name}_{uuid_string()}"
|
||||
file_path = self._get_path(item_id)
|
||||
torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
return item_id
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
file_path = self._get_path(item_id)
|
||||
file_path.unlink()
|
||||
|
||||
def _get_path(self, item_id: str) -> Path:
|
||||
return self._output_folder / item_id
|
||||
|
||||
def _delete_all_items(self) -> None:
|
||||
"""
|
||||
Deletes all pickled items from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
|
||||
if not self._invoker:
|
||||
raise ValueError("Invoker is not set. Must call `start()` first.")
|
||||
|
||||
deleted_count = 0
|
||||
freed_space = 0
|
||||
for file in Path(self._output_folder).glob("*"):
|
||||
if file.is_file():
|
||||
freed_space += file.stat().st_size
|
||||
deleted_count += 1
|
||||
file.unlink()
|
||||
if deleted_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -0,0 +1,61 @@
|
||||
from queue import Queue
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ItemStorageForwardCache(ItemStorageABC[T]):
|
||||
"""Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
|
||||
|
||||
def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, item_id: str) -> T:
|
||||
cache_item = self._get_cache(item_id)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self._underlying_storage.get(item_id)
|
||||
self._set_cache(item_id, latent)
|
||||
return latent
|
||||
|
||||
def set(self, item: T) -> str:
|
||||
item_id = self._underlying_storage.set(item)
|
||||
self._set_cache(item_id, item)
|
||||
self._on_changed(item)
|
||||
return item_id
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
self._underlying_storage.delete(item_id)
|
||||
if item_id in self._cache:
|
||||
del self._cache[item_id]
|
||||
self._on_deleted(item_id)
|
||||
|
||||
def _get_cache(self, item_id: str) -> Optional[T]:
|
||||
return None if item_id not in self._cache else self._cache[item_id]
|
||||
|
||||
def _set_cache(self, item_id: str, data: T):
|
||||
if item_id not in self._cache:
|
||||
self._cache[item_id] = data
|
||||
self._cache_ids.put(item_id)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
@ -34,7 +34,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
|
||||
self._items[item_id] = item
|
||||
return item
|
||||
|
||||
def set(self, item: T) -> None:
|
||||
def set(self, item: T) -> str:
|
||||
item_id = getattr(item, self._id_field)
|
||||
if item_id in self._items:
|
||||
# If item already exists, remove it and add it to the end
|
||||
@ -44,6 +44,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
|
||||
self._items.popitem(last=False)
|
||||
self._items[item_id] = item
|
||||
self._on_changed(item)
|
||||
return item_id
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
# This is a no-op if the item doesn't exist.
|
||||
|
@ -1,45 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PickleStorageBase(ABC, Generic[T]):
|
||||
"""Responsible for storing and retrieving non-serializable data using a pickler."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = []
|
||||
self._on_deleted_callbacks = []
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> T:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: T) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: T) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
@ -1,58 +0,0 @@
|
||||
from queue import Queue
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PickleStorageForwardCache(PickleStorageBase[T]):
|
||||
def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, name: str) -> T:
|
||||
cache_item = self._get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self._underlying_storage.get(name)
|
||||
self._set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: T) -> None:
|
||||
self._underlying_storage.save(name, data)
|
||||
self._set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self._underlying_storage.delete(name)
|
||||
if name in self._cache:
|
||||
del self._cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def _get_cache(self, name: str) -> Optional[T]:
|
||||
return None if name not in self._cache else self._cache[name]
|
||||
|
||||
def _set_cache(self, name: str, data: T):
|
||||
if name not in self._cache:
|
||||
self._cache[name] = data
|
||||
self._cache_ids.put(name)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
@ -1,63 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PickleStorageTorch(PickleStorageBase[T]):
|
||||
"""Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`."""
|
||||
|
||||
def __init__(self, output_folder: Path, item_type_name: "str"):
|
||||
super().__init__()
|
||||
self._output_folder = output_folder
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
self._item_type_name = item_type_name
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all_items()
|
||||
|
||||
def get(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
return torch.load(file_path)
|
||||
|
||||
def save(self, name: str, data: T) -> None:
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
file_path = self._get_path(name)
|
||||
torch.save(data, file_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
file_path = self._get_path(name)
|
||||
file_path.unlink()
|
||||
|
||||
def _get_path(self, name: str) -> Path:
|
||||
return self._output_folder / name
|
||||
|
||||
def _delete_all_items(self) -> None:
|
||||
"""
|
||||
Deletes all pickled items from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
|
||||
if not self._invoker:
|
||||
raise ValueError("Invoker is not set. Must call `start()` first.")
|
||||
|
||||
deleted_count = 0
|
||||
freed_space = 0
|
||||
for file in Path(self._output_folder).glob("*"):
|
||||
if file.is_file():
|
||||
freed_space += file.stat().st_size
|
||||
deleted_count += 1
|
||||
file.unlink()
|
||||
if deleted_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
@ -224,26 +223,7 @@ class TensorsInterface(InvocationContextInterface):
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
||||
# "mask", "noise", "masked_latents", etc.
|
||||
#
|
||||
# Retaining that capability in this wrapper would require either many different methods
|
||||
# to save tensors, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all tensors.
|
||||
#
|
||||
# This has a very minor impact as we don't use them after a session completes.
|
||||
|
||||
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
|
||||
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
||||
#
|
||||
# Because the name of the tensors file will includes the session and invocation IDs,
|
||||
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.tensors.save(
|
||||
name=name,
|
||||
data=tensor,
|
||||
)
|
||||
name = self._services.tensors.set(item=tensor)
|
||||
return name
|
||||
|
||||
def get(self, tensor_name: str) -> Tensor:
|
||||
@ -263,13 +243,7 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
# See comment in TensorsInterface.save for why we generate the name here.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.conditioning.save(
|
||||
name=name,
|
||||
data=conditioning_data,
|
||||
)
|
||||
name = self._services.conditioning.set(item=conditioning_data)
|
||||
return name
|
||||
|
||||
def get(self, conditioning_name: str) -> ConditioningFieldData:
|
||||
|
Loading…
Reference in New Issue
Block a user