From de0b72528cbb3f7ffcdafcad5c3232e81e748fd2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:41:23 +1100 Subject: [PATCH] feat(nodes): replace latents service with tensors and conditioning services - New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods --- invokeai/app/api/dependencies.py | 18 +++-- invokeai/app/invocations/latent.py | 36 +++++----- invokeai/app/invocations/noise.py | 2 +- invokeai/app/invocations/primitives.py | 6 +- .../invocation_cache_memory.py | 3 +- invokeai/app/services/invocation_services.py | 12 +++- .../app/services/latents_storage/__init__.py | 0 .../latents_storage/latents_storage_disk.py | 58 ---------------- .../latents_storage_forward_cache.py | 68 ------------------- .../pickle_storage_base.py} | 18 ++--- .../pickle_storage_forward_cache.py | 58 ++++++++++++++++ .../pickle_storage/pickle_storage_torch.py | 62 +++++++++++++++++ .../app/services/shared/invocation_context.py | 49 ++++++------- 13 files changed, 197 insertions(+), 193 deletions(-) delete mode 100644 invokeai/app/services/latents_storage/__init__.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_disk.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_forward_cache.py rename invokeai/app/services/{latents_storage/latents_storage_base.py => pickle_storage/pickle_storage_base.py} (68%) create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c8309e1729..6bb0915cb6 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,9 +2,14 @@ from logging import Logger +import torch + from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache +from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -23,8 +28,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker -from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage -from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL @@ -68,6 +71,9 @@ class ApiDependencies: logger.debug(f"Internet connectivity is {config.internet_available}") output_folder = config.output_path + if output_folder is None: + raise ValueError("Output folder is not set") + image_files = DiskImageFileStorage(f"{output_folder}/images") db = init_db(config=config, logger=logger, image_files=image_files) @@ -84,7 +90,10 @@ class ApiDependencies: image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) + tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) + conditioning = PickleStorageForwardCache( + PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) download_queue_service = DownloadQueueService(event_bus=events) @@ -117,7 +126,6 @@ class ApiDependencies: image_records=image_records, images=images, invocation_cache=invocation_cache, - latents=latents, logger=logger, model_manager=model_manager, model_records=model_record_service, @@ -131,6 +139,8 @@ class ApiDependencies: session_queue=session_queue, urls=urls, workflow_records=workflow_records, + tensors=tensors, + conditioning=conditioning, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5449ec9af7..94440d3e2a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -163,11 +163,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation): # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = context.latents.save(tensor=masked_latents) + masked_latents_name = context.tensors.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = context.latents.save(tensor=mask) + mask_name = context.tensors.save(tensor=mask) return DenoiseMaskOutput.build( mask_name=mask_name, @@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation): if self.denoise_mask is None: return None, None - mask = context.latents.get(self.denoise_mask.mask_name) + mask = context.tensors.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation): seed = None noise = None if self.noise is not None: - noise = context.latents.get(self.noise.latents_name) + noise = context.tensors.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -752,7 +752,7 @@ class DenoiseLatentsInvocation(BaseInvocation): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=result_latents) + name = context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -888,7 +888,7 @@ class ResizeLatentsInvocation(BaseInvocation): if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -930,7 +930,7 @@ class ScaleLatentsInvocation(BaseInvocation): if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1011,7 +1011,7 @@ class ImageToLatentsInvocation(BaseInvocation): latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) latents = latents.to("cpu") - name = context.latents.save(tensor=latents) + name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.latents.get(self.latents_a.latents_name) - latents_b = context.latents.get(self.latents_b.latents_name) + latents_a = context.tensors.get(self.latents_a.latents_name) + latents_b = context.tensors.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1103,7 +1103,7 @@ class BlendLatentsInvocation(BaseInvocation): if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=blended_latents) + name = context.tensors.save(tensor=blended_latents) return LatentsOutput.build(latents_name=name, latents=blended_latents) @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1158,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation): cropped_latents = latents[..., y1:y2, x1:x2] - name = context.latents.save(tensor=cropped_latents) + name = context.tensors.save(tensor=cropped_latents) return LatentsOutput.build(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 78f13cc52d..74b3d6e4cb 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -121,5 +121,5 @@ class NoiseInvocation(BaseInvocation): seed=self.seed, use_cpu=self.use_cpu, ) - name = context.latents.save(tensor=noise) + name = context.tensors.save(tensor=noise) return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index a77939943a..082d5432cc 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -313,9 +313,7 @@ class DenoiseMaskOutput(BaseInvocationOutput): class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" - latents: LatentsField = OutputField( - description=FieldDescriptions.latents, - ) + latents: LatentsField = OutputField(description=FieldDescriptions.latents) width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) @@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 4a503b3c6b..c700f81186 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -37,7 +37,8 @@ class MemoryInvocationCache(InvocationCacheBase): if self._max_cache_size == 0: return self._invoker.services.images.on_deleted(self._delete_by_match) - self._invoker.services.latents.on_deleted(self._delete_by_match) + self._invoker.services.tensors.on_deleted(self._delete_by_match) + self._invoker.services.conditioning.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: with self._lock: diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 51bfd5d77a..81885781ac 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -6,6 +6,10 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from logging import Logger + import torch + + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData + from .board_image_records.board_image_records_base import BoardImageRecordStorageBase from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase @@ -21,11 +25,11 @@ if TYPE_CHECKING: from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC - from .latents_storage.latents_storage_base import LatentsStorageBase from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase + from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -48,7 +52,6 @@ class InvocationServices: images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", - latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", model_records: "ModelRecordServiceBase", @@ -63,6 +66,8 @@ class InvocationServices: names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", + tensors: "PickleStorageBase[torch.Tensor]", + conditioning: "PickleStorageBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records @@ -74,7 +79,6 @@ class InvocationServices: self.images = images self.image_files = image_files self.image_records = image_records - self.latents = latents self.logger = logger self.model_manager = model_manager self.model_records = model_records @@ -89,3 +93,5 @@ class InvocationServices: self.names = names self.urls = urls self.workflow_records = workflow_records + self.tensors = tensors + self.conditioning = conditioning diff --git a/invokeai/app/services/latents_storage/__init__.py b/invokeai/app/services/latents_storage/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py deleted file mode 100644 index 9192b9147f..0000000000 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import Union - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class DiskLatentsStorage(LatentsStorageBase): - """Stores latents in a folder on disk without caching""" - - __output_folder: Path - - def __init__(self, output_folder: Union[str, Path]): - self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__output_folder.mkdir(parents=True, exist_ok=True) - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_latents() - - def get(self, name: str) -> torch.Tensor: - latent_path = self.get_path(name) - return torch.load(latent_path) - - def save(self, name: str, data: torch.Tensor) -> None: - self.__output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self.get_path(name) - torch.save(data, latent_path) - - def delete(self, name: str) -> None: - latent_path = self.get_path(name) - latent_path.unlink() - - def get_path(self, name: str) -> Path: - return self.__output_folder / name - - def _delete_all_latents(self) -> None: - """ - Deletes all latents from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - deleted_latents_count = 0 - freed_space = 0 - for latents_file in Path(self.__output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py deleted file mode 100644 index 6232b76a27..0000000000 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from queue import Queue -from typing import Dict, Optional - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class ForwardCacheLatentsStorage(LatentsStorageBase): - """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" - - __cache: Dict[str, torch.Tensor] - __cache_ids: Queue - __max_cache_size: int - __underlying_storage: LatentsStorageBase - - def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): - super().__init__() - self.__underlying_storage = underlying_storage - self.__cache = {} - self.__cache_ids = Queue() - self.__max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self.__underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self.__underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> torch.Tensor: - cache_item = self.__get_cache(name) - if cache_item is not None: - return cache_item - - latent = self.__underlying_storage.get(name) - self.__set_cache(name, latent) - return latent - - def save(self, name: str, data: torch.Tensor) -> None: - self.__underlying_storage.save(name, data) - self.__set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self.__underlying_storage.delete(name) - if name in self.__cache: - del self.__cache[name] - self._on_deleted(name) - - def __get_cache(self, name: str) -> Optional[torch.Tensor]: - return None if name not in self.__cache else self.__cache[name] - - def __set_cache(self, name: str, data: torch.Tensor): - if name not in self.__cache: - self.__cache[name] = data - self.__cache_ids.put(name) - if self.__cache_ids.qsize() > self.__max_cache_size: - self.__cache.pop(self.__cache_ids.get()) diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py similarity index 68% rename from invokeai/app/services/latents_storage/latents_storage_base.py rename to invokeai/app/services/pickle_storage/pickle_storage_base.py index 9fa42b0ae6..558b97c0f1 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_base.py @@ -1,15 +1,15 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Generic, TypeVar -import torch +T = TypeVar("T") -class LatentsStorageBase(ABC): - """Responsible for storing and retrieving latents.""" +class PickleStorageBase(ABC, Generic[T]): + """Responsible for storing and retrieving non-serializable data using a pickler.""" - _on_changed_callbacks: list[Callable[[torch.Tensor], None]] + _on_changed_callbacks: list[Callable[[T], None]] _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: @@ -17,18 +17,18 @@ class LatentsStorageBase(ABC): self._on_deleted_callbacks = [] @abstractmethod - def get(self, name: str) -> torch.Tensor: + def get(self, name: str) -> T: pass @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: T) -> None: pass @abstractmethod def delete(self, name: str) -> None: pass - def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None: + def on_changed(self, on_changed: Callable[[T], None]) -> None: """Register a callback for when an item is changed""" self._on_changed_callbacks.append(on_changed) @@ -36,7 +36,7 @@ class LatentsStorageBase(ABC): """Register a callback for when an item is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_changed(self, item: torch.Tensor) -> None: + def _on_changed(self, item: T) -> None: for callback in self._on_changed_callbacks: callback(item) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py new file mode 100644 index 0000000000..3002d9e045 --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py @@ -0,0 +1,58 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageForwardCache(PickleStorageBase[T]): + def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, name: str) -> T: + cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(name) + self._set_cache(name, latent) + return latent + + def save(self, name: str, data: T) -> None: + self._underlying_storage.save(name, data) + self._set_cache(name, data) + self._on_changed(data) + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py new file mode 100644 index 0000000000..0b3c9af7a3 --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from pathlib import Path +from typing import TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageTorch(PickleStorageBase[T]): + """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" + + def __init__(self, output_folder: Path, item_type_name: "str"): + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self._item_type_name = item_type_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, name: str) -> T: + latent_path = self._get_path(name) + return torch.load(latent_path) + + def save(self, name: str, data: T) -> None: + self._output_folder.mkdir(parents=True, exist_ok=True) + latent_path = self._get_path(name) + torch.save(data, latent_path) + + def delete(self, name: str) -> None: + latent_path = self._get_path(name) + latent_path.unlink() + + def _get_path(self, name: str) -> Path: + return self._output_folder / name + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. Must call `start()` first.") + + deleted_latents_count = 0 + freed_space = 0 + for latents_file in Path(self._output_folder).glob("*"): + if latents_file.is_file(): + freed_space += latents_file.stat().st_size + deleted_latents_count += 1 + latents_file.unlink() + if deleted_latents_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 97a62246fb..6756b1f5c6 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -216,48 +216,46 @@ class ImagesInterface(InvocationContextInterface): return self._services.images.get_dto(image_name) -class LatentsInterface(InvocationContextInterface): +class TensorsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: """ - Saves a latents tensor, returning its name. + Saves a tensor, returning its name. - :param tensor: The latents tensor to save. + :param tensor: The tensor to save. """ # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. # "mask", "noise", "masked_latents", etc. # # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. + # to save tensors, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all tensors. # # This has a very minor impact as we don't use them after a session completes. - # Previously, invocations chose the name for their latents. This is a bit risky, so we + # Previously, invocations chose the name for their tensors. This is a bit risky, so we # will generate a name for them instead. We use a uuid to ensure the name is unique. # - # Because the name of the latents file will includes the session and invocation IDs, + # Because the name of the tensors file will includes the session and invocation IDs, # we don't need to worry about collisions. A truncated UUIDv4 is fine. name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.latents.save( + self._services.tensors.save( name=name, data=tensor, ) return name - def get(self, latents_name: str) -> Tensor: + def get(self, tensor_name: str) -> Tensor: """ - Gets a latents tensor by name. + Gets a tensor by name. - :param latents_name: The name of the latents tensor to get. + :param tensor_name: The name of the tensor to get. """ - return self._services.latents.get(latents_name) + return self._services.tensors.get(tensor_name) class ConditioningInterface(InvocationContextInterface): - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. def save(self, conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. @@ -265,15 +263,12 @@ class ConditioningInterface(InvocationContextInterface): :param conditioning_context_data: The conditioning data to save. """ - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. - # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). + # See comment in TensorsInterface.save for why we generate the name here. - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - self._services.latents.save( + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.conditioning.save( name=name, - data=conditioning_data, # type: ignore [arg-type] + data=conditioning_data, ) return name @@ -284,7 +279,7 @@ class ConditioningInterface(InvocationContextInterface): :param conditioning_name: The name of the conditioning data to get. """ - return self._services.latents.get(conditioning_name) # type: ignore [return-value] + return self._services.conditioning.get(conditioning_name) class ModelsInterface(InvocationContextInterface): @@ -400,7 +395,7 @@ class InvocationContext: def __init__( self, images: ImagesInterface, - latents: LatentsInterface, + tensors: TensorsInterface, conditioning: ConditioningInterface, models: ModelsInterface, logger: LoggerInterface, @@ -412,8 +407,8 @@ class InvocationContext: ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" - self.latents = latents - """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" + self.tensors = tensors + """Provides methods to save and get tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning """Provides methods to save and get conditioning data.""" self.models = models @@ -532,7 +527,7 @@ def build_invocation_context( logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) - latents = LatentsInterface(services=services, context_data=context_data) + tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) @@ -543,7 +538,7 @@ def build_invocation_context( images=images, logger=logger, config=config, - latents=latents, + tensors=tensors, models=models, context_data=context_data, util=util,