feat(nodes): replace latents service with tensors and conditioning services

- New generic class `PickleStorageBase` implements the same API as `LatentsStorageBase`; it is used for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `DiskLatentsStorage`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usages of the old `latents` service to use the new services/context wrapper methods (a usage sketch follows below)
psychedelicious 2024-02-07 17:41:23 +11:00
parent 2932652787
commit de0b72528c
13 changed files with 197 additions and 193 deletions
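
To make the new API concrete, here is a hedged sketch of an invocation body using the wrapper methods. The invocation and its fields are invented for illustration; only the `context.tensors` / `context.conditioning` calls come from this commit:

    def invoke(self, context: "InvocationContext") -> "LatentsOutput":
        # Load a tensor and a conditioning object by their stored names.
        latents = context.tensors.get(self.latents.latents_name)  # torch.Tensor
        cond = context.conditioning.get(self.conditioning_name)  # ConditioningFieldData
        result = latents * 0.5  # ...actual work elided...
        # save() generates and returns a unique name for the tensor.
        name = context.tensors.save(tensor=result)
        return LatentsOutput.build(latents_name=name, latents=result)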

--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -2,9 +2,14 @@
 from logging import Logger

+import torch
+
 from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
+from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache
+from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
 from invokeai.backend.model_manager.metadata import ModelMetadataStore
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
@@ -23,8 +28,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
 from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
-from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
-from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
 from ..services.model_install import ModelInstallService
 from ..services.model_manager.model_manager_default import ModelManagerService
 from ..services.model_records import ModelRecordServiceSQL
@@ -68,6 +71,9 @@ class ApiDependencies:
         logger.debug(f"Internet connectivity is {config.internet_available}")

         output_folder = config.output_path
+        if output_folder is None:
+            raise ValueError("Output folder is not set")
+
         image_files = DiskImageFileStorage(f"{output_folder}/images")

         db = init_db(config=config, logger=logger, image_files=image_files)
@@ -84,7 +90,10 @@ class ApiDependencies:
         image_records = SqliteImageRecordStorage(db=db)
         images = ImageService()
         invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
-        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
+        tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor"))
+        conditioning = PickleStorageForwardCache(
+            PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning")
+        )
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
         download_queue_service = DownloadQueueService(event_bus=events)
@@ -117,7 +126,6 @@ class ApiDependencies:
             image_records=image_records,
             images=images,
             invocation_cache=invocation_cache,
-            latents=latents,
             logger=logger,
             model_manager=model_manager,
             model_records=model_record_service,
@@ -131,6 +139,8 @@ class ApiDependencies:
             session_queue=session_queue,
             urls=urls,
             workflow_records=workflow_records,
+            tensors=tensors,
+            conditioning=conditioning,
         )

         ApiDependencies.invoker = Invoker(services)

--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -163,11 +163,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
             # TODO:
             masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

-            masked_latents_name = context.latents.save(tensor=masked_latents)
+            masked_latents_name = context.tensors.save(tensor=masked_latents)
         else:
             masked_latents_name = None

-        mask_name = context.latents.save(tensor=mask)
+        mask_name = context.tensors.save(tensor=mask)

         return DenoiseMaskOutput.build(
             mask_name=mask_name,
@@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.denoise_mask is None:
             return None, None

-        mask = context.latents.get(self.denoise_mask.mask_name)
+        mask = context.tensors.get(self.denoise_mask.mask_name)
         mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
         if self.denoise_mask.masked_latents_name is not None:
-            masked_latents = context.latents.get(self.denoise_mask.masked_latents_name)
+            masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
         else:
             masked_latents = None
@@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
         seed = None
         noise = None
         if self.noise is not None:
-            noise = context.latents.get(self.noise.latents_name)
+            noise = context.tensors.get(self.noise.latents_name)
             seed = self.noise.seed

         if self.latents is not None:
-            latents = context.latents.get(self.latents.latents_name)
+            latents = context.tensors.get(self.latents.latents_name)
             if seed is None:
                 seed = self.latents.seed
@@ -752,7 +752,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()

-        name = context.latents.save(tensor=result_latents)
+        name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
@@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.latents.get(self.latents.latents_name)
+        latents = context.tensors.get(self.latents.latents_name)

         vae_info = context.models.load(**self.vae.vae.model_dump())
@@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
     antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.latents.get(self.latents.latents_name)
+        latents = context.tensors.get(self.latents.latents_name)

         # TODO:
         device = choose_torch_device()
@@ -888,7 +888,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         if device == torch.device("mps"):
             mps.empty_cache()

-        name = context.latents.save(tensor=resized_latents)
+        name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.latents.get(self.latents.latents_name)
+        latents = context.tensors.get(self.latents.latents_name)

         # TODO:
         device = choose_torch_device()
@@ -930,7 +930,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         if device == torch.device("mps"):
             mps.empty_cache()

-        name = context.latents.save(tensor=resized_latents)
+        name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1011,7 +1011,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)

         latents = latents.to("cpu")
-        name = context.latents.save(tensor=latents)
+        name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

     @singledispatchmethod
@@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
     alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)

     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents_a = context.latents.get(self.latents_a.latents_name)
-        latents_b = context.latents.get(self.latents_b.latents_name)
+        latents_a = context.tensors.get(self.latents_a.latents_name)
+        latents_b = context.tensors.get(self.latents_b.latents_name)

         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")
@@ -1103,7 +1103,7 @@ class BlendLatentsInvocation(BaseInvocation):
         if device == torch.device("mps"):
             mps.empty_cache()

-        name = context.latents.save(tensor=blended_latents)
+        name = context.tensors.save(tensor=blended_latents)
         return LatentsOutput.build(latents_name=name, latents=blended_latents)
@@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
     )

     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.latents.get(self.latents.latents_name)
+        latents = context.tensors.get(self.latents.latents_name)

         x1 = self.x // LATENT_SCALE_FACTOR
         y1 = self.y // LATENT_SCALE_FACTOR
@@ -1158,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
         cropped_latents = latents[..., y1:y2, x1:x2]

-        name = context.latents.save(tensor=cropped_latents)
+        name = context.tensors.save(tensor=cropped_latents)

        return LatentsOutput.build(latents_name=name, latents=cropped_latents)

--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -121,5 +121,5 @@ class NoiseInvocation(BaseInvocation):
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
-        name = context.latents.save(tensor=noise)
+        name = context.tensors.save(tensor=noise)
         return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)

--- a/invokeai/app/invocations/primitives.py
+++ b/invokeai/app/invocations/primitives.py
@@ -313,9 +313,7 @@ class DenoiseMaskOutput(BaseInvocationOutput):
 class LatentsOutput(BaseInvocationOutput):
     """Base class for nodes that output a single latents tensor"""

-    latents: LatentsField = OutputField(
-        description=FieldDescriptions.latents,
-    )
+    latents: LatentsField = OutputField(description=FieldDescriptions.latents)
     width: int = OutputField(description=FieldDescriptions.width)
     height: int = OutputField(description=FieldDescriptions.height)
@@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation):
     latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)

     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.latents.get(self.latents.latents_name)
+        latents = context.tensors.get(self.latents.latents_name)

         return LatentsOutput.build(self.latents.latents_name, latents)

--- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py
+++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py
@@ -37,7 +37,8 @@ class MemoryInvocationCache(InvocationCacheBase):
         if self._max_cache_size == 0:
             return
         self._invoker.services.images.on_deleted(self._delete_by_match)
-        self._invoker.services.latents.on_deleted(self._delete_by_match)
+        self._invoker.services.tensors.on_deleted(self._delete_by_match)
+        self._invoker.services.conditioning.on_deleted(self._delete_by_match)

     def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
         with self._lock:

--- a/invokeai/app/services/invocation_services.py
+++ b/invokeai/app/services/invocation_services.py
@@ -6,6 +6,10 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from logging import Logger

+    import torch
+
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
+
     from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
     from .board_images.board_images_base import BoardImagesServiceABC
     from .board_records.board_records_base import BoardRecordStorageBase
@@ -21,11 +25,11 @@ if TYPE_CHECKING:
     from .invocation_queue.invocation_queue_base import InvocationQueueABC
     from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
     from .item_storage.item_storage_base import ItemStorageABC
-    from .latents_storage.latents_storage_base import LatentsStorageBase
     from .model_install import ModelInstallServiceBase
     from .model_manager.model_manager_base import ModelManagerServiceBase
     from .model_records import ModelRecordServiceBase
     from .names.names_base import NameServiceBase
+    from .pickle_storage.pickle_storage_base import PickleStorageBase
     from .session_processor.session_processor_base import SessionProcessorBase
     from .session_queue.session_queue_base import SessionQueueBase
     from .shared.graph import GraphExecutionState
@@ -48,7 +52,6 @@ class InvocationServices:
         images: "ImageServiceABC",
         image_files: "ImageFileStorageBase",
         image_records: "ImageRecordStorageBase",
-        latents: "LatentsStorageBase",
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         model_records: "ModelRecordServiceBase",
@@ -63,6 +66,8 @@ class InvocationServices:
         names: "NameServiceBase",
         urls: "UrlServiceBase",
         workflow_records: "WorkflowRecordsStorageBase",
+        tensors: "PickleStorageBase[torch.Tensor]",
+        conditioning: "PickleStorageBase[ConditioningFieldData]",
     ):
         self.board_images = board_images
         self.board_image_records = board_image_records
@@ -74,7 +79,6 @@ class InvocationServices:
         self.images = images
         self.image_files = image_files
         self.image_records = image_records
-        self.latents = latents
         self.logger = logger
         self.model_manager = model_manager
         self.model_records = model_records
@@ -89,3 +93,5 @@ class InvocationServices:
         self.names = names
         self.urls = urls
         self.workflow_records = workflow_records
+        self.tensors = tensors
+        self.conditioning = conditioning

--- a/invokeai/app/services/latents_storage/latents_storage_disk.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-
-from pathlib import Path
-from typing import Union
-
-import torch
-
-from invokeai.app.services.invoker import Invoker
-
-from .latents_storage_base import LatentsStorageBase
-
-
-class DiskLatentsStorage(LatentsStorageBase):
-    """Stores latents in a folder on disk without caching"""
-
-    __output_folder: Path
-
-    def __init__(self, output_folder: Union[str, Path]):
-        self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
-        self.__output_folder.mkdir(parents=True, exist_ok=True)
-
-    def start(self, invoker: Invoker) -> None:
-        self._invoker = invoker
-        self._delete_all_latents()
-
-    def get(self, name: str) -> torch.Tensor:
-        latent_path = self.get_path(name)
-        return torch.load(latent_path)
-
-    def save(self, name: str, data: torch.Tensor) -> None:
-        self.__output_folder.mkdir(parents=True, exist_ok=True)
-        latent_path = self.get_path(name)
-        torch.save(data, latent_path)
-
-    def delete(self, name: str) -> None:
-        latent_path = self.get_path(name)
-        latent_path.unlink()
-
-    def get_path(self, name: str) -> Path:
-        return self.__output_folder / name
-
-    def _delete_all_latents(self) -> None:
-        """
-        Deletes all latents from disk.
-        Must be called after we have access to `self._invoker` (e.g. in `start()`).
-        """
-        deleted_latents_count = 0
-        freed_space = 0
-        for latents_file in Path(self.__output_folder).glob("*"):
-            if latents_file.is_file():
-                freed_space += latents_file.stat().st_size
-                deleted_latents_count += 1
-                latents_file.unlink()
-        if deleted_latents_count > 0:
-            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
-            self._invoker.services.logger.info(
-                f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
-            )

--- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-
-from queue import Queue
-from typing import Dict, Optional
-
-import torch
-
-from invokeai.app.services.invoker import Invoker
-
-from .latents_storage_base import LatentsStorageBase
-
-
-class ForwardCacheLatentsStorage(LatentsStorageBase):
-    """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
-
-    __cache: Dict[str, torch.Tensor]
-    __cache_ids: Queue
-    __max_cache_size: int
-    __underlying_storage: LatentsStorageBase
-
-    def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
-        super().__init__()
-        self.__underlying_storage = underlying_storage
-        self.__cache = {}
-        self.__cache_ids = Queue()
-        self.__max_cache_size = max_cache_size
-
-    def start(self, invoker: Invoker) -> None:
-        self._invoker = invoker
-        start_op = getattr(self.__underlying_storage, "start", None)
-        if callable(start_op):
-            start_op(invoker)
-
-    def stop(self, invoker: Invoker) -> None:
-        self._invoker = invoker
-        stop_op = getattr(self.__underlying_storage, "stop", None)
-        if callable(stop_op):
-            stop_op(invoker)
-
-    def get(self, name: str) -> torch.Tensor:
-        cache_item = self.__get_cache(name)
-        if cache_item is not None:
-            return cache_item
-
-        latent = self.__underlying_storage.get(name)
-        self.__set_cache(name, latent)
-        return latent
-
-    def save(self, name: str, data: torch.Tensor) -> None:
-        self.__underlying_storage.save(name, data)
-        self.__set_cache(name, data)
-        self._on_changed(data)
-
-    def delete(self, name: str) -> None:
-        self.__underlying_storage.delete(name)
-        if name in self.__cache:
-            del self.__cache[name]
-        self._on_deleted(name)
-
-    def __get_cache(self, name: str) -> Optional[torch.Tensor]:
-        return None if name not in self.__cache else self.__cache[name]
-
-    def __set_cache(self, name: str, data: torch.Tensor):
-        if name not in self.__cache:
-            self.__cache[name] = data
-            self.__cache_ids.put(name)
-            if self.__cache_ids.qsize() > self.__max_cache_size:
-                self.__cache.pop(self.__cache_ids.get())

--- a/invokeai/app/services/latents_storage/latents_storage_base.py
+++ b/invokeai/app/services/pickle_storage/pickle_storage_base.py
@@ -1,15 +1,15 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, Generic, TypeVar

-import torch
+T = TypeVar("T")


-class LatentsStorageBase(ABC):
-    """Responsible for storing and retrieving latents."""
+class PickleStorageBase(ABC, Generic[T]):
+    """Responsible for storing and retrieving non-serializable data using a pickler."""

-    _on_changed_callbacks: list[Callable[[torch.Tensor], None]]
+    _on_changed_callbacks: list[Callable[[T], None]]
     _on_deleted_callbacks: list[Callable[[str], None]]

     def __init__(self) -> None:
@@ -17,18 +17,18 @@ class LatentsStorageBase(ABC):
         self._on_deleted_callbacks = []

     @abstractmethod
-    def get(self, name: str) -> torch.Tensor:
+    def get(self, name: str) -> T:
         pass

     @abstractmethod
-    def save(self, name: str, data: torch.Tensor) -> None:
+    def save(self, name: str, data: T) -> None:
         pass

     @abstractmethod
     def delete(self, name: str) -> None:
         pass

-    def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
+    def on_changed(self, on_changed: Callable[[T], None]) -> None:
         """Register a callback for when an item is changed"""
         self._on_changed_callbacks.append(on_changed)
@@ -36,7 +36,7 @@ class LatentsStorageBase(ABC):
         """Register a callback for when an item is deleted"""
         self._on_deleted_callbacks.append(on_deleted)

-    def _on_changed(self, item: torch.Tensor) -> None:
+    def _on_changed(self, item: T) -> None:
         for callback in self._on_changed_callbacks:
             callback(item)

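The generic base above only fixes the get/save/delete/callback contract. As a hedged illustration (not part of the commit), a minimal dict-backed implementation of `PickleStorageBase[T]` could look like this, including the callback hooks the invocation cache relies on:

    from typing import TypeVar

    from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase

    T = TypeVar("T")


    class InMemoryPickleStorage(PickleStorageBase[T]):
        """Toy storage for illustration: keeps items in a dict."""

        def __init__(self) -> None:
            super().__init__()
            self._items: dict[str, T] = {}

        def get(self, name: str) -> T:
            return self._items[name]

        def save(self, name: str, data: T) -> None:
            self._items[name] = data
            self._on_changed(data)  # notify registered on_changed callbacks

        def delete(self, name: str) -> None:
            del self._items[name]
            self._on_deleted(name)  # notify registered on_deleted callbacks


    storage = InMemoryPickleStorage[str]()
    storage.on_deleted(lambda name: print(f"evict {name} from caches"))
    storage.save("greeting", "hello")
    storage.delete("greeting")  # prints: evict greeting from caches

This is the same callback mechanism that `MemoryInvocationCache` hooks into via `on_deleted` earlier in this diff.
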
--- /dev/null
+++ b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py
@@ -0,0 +1,58 @@
+from queue import Queue
+from typing import Optional, TypeVar
+
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
+
+T = TypeVar("T")
+
+
+class PickleStorageForwardCache(PickleStorageBase[T]):
+    def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20):
+        super().__init__()
+        self._underlying_storage = underlying_storage
+        self._cache: dict[str, T] = {}
+        self._cache_ids = Queue[str]()
+        self._max_cache_size = max_cache_size
+
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+        start_op = getattr(self._underlying_storage, "start", None)
+        if callable(start_op):
+            start_op(invoker)
+
+    def stop(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+        stop_op = getattr(self._underlying_storage, "stop", None)
+        if callable(stop_op):
+            stop_op(invoker)
+
+    def get(self, name: str) -> T:
+        cache_item = self._get_cache(name)
+        if cache_item is not None:
+            return cache_item
+
+        latent = self._underlying_storage.get(name)
+        self._set_cache(name, latent)
+        return latent
+
+    def save(self, name: str, data: T) -> None:
+        self._underlying_storage.save(name, data)
+        self._set_cache(name, data)
+        self._on_changed(data)
+
+    def delete(self, name: str) -> None:
+        self._underlying_storage.delete(name)
+        if name in self._cache:
+            del self._cache[name]
+        self._on_deleted(name)
+
+    def _get_cache(self, name: str) -> Optional[T]:
+        return None if name not in self._cache else self._cache[name]
+
+    def _set_cache(self, name: str, data: T):
+        if name not in self._cache:
+            self._cache[name] = data
+            self._cache_ids.put(name)
+            if self._cache_ids.qsize() > self._max_cache_size:
+                self._cache.pop(self._cache_ids.get())

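`PickleStorageForwardCache._set_cache` implements a simple FIFO policy: names go into a queue on first insert, and once the queue exceeds `max_cache_size` the oldest name is popped from the dict. A standalone sketch of just that policy (values invented for illustration):

    from queue import Queue

    cache: dict[str, int] = {}
    cache_ids: "Queue[str]" = Queue()
    MAX_CACHE_SIZE = 2

    def set_cache(name: str, data: int) -> None:
        if name not in cache:
            cache[name] = data
            cache_ids.put(name)
            if cache_ids.qsize() > MAX_CACHE_SIZE:
                cache.pop(cache_ids.get())  # evict the oldest inserted name

    for i, name in enumerate(["a", "b", "c"]):
        set_cache(name, i)

    print(sorted(cache))  # ['b', 'c'] -- 'a' was evicted first-in, first-out

Note that eviction only drops the in-memory copy; the underlying storage still holds the item, so a later get() falls through and re-caches it.
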
--- /dev/null
+++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+
+from pathlib import Path
+from typing import TypeVar
+
+import torch
+
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
+
+T = TypeVar("T")
+
+
+class PickleStorageTorch(PickleStorageBase[T]):
+    """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`."""
+
+    def __init__(self, output_folder: Path, item_type_name: "str"):
+        self._output_folder = output_folder
+        self._output_folder.mkdir(parents=True, exist_ok=True)
+        self._item_type_name = item_type_name
+
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+        self._delete_all_items()
+
+    def get(self, name: str) -> T:
+        latent_path = self._get_path(name)
+        return torch.load(latent_path)
+
+    def save(self, name: str, data: T) -> None:
+        self._output_folder.mkdir(parents=True, exist_ok=True)
+        latent_path = self._get_path(name)
+        torch.save(data, latent_path)
+
+    def delete(self, name: str) -> None:
+        latent_path = self._get_path(name)
+        latent_path.unlink()
+
+    def _get_path(self, name: str) -> Path:
+        return self._output_folder / name
+
+    def _delete_all_items(self) -> None:
+        """
+        Deletes all pickled items from disk.
+        Must be called after we have access to `self._invoker` (e.g. in `start()`).
+        """
+
+        if not self._invoker:
+            raise ValueError("Invoker is not set. Must call `start()` first.")
+
+        deleted_latents_count = 0
+        freed_space = 0
+        for latents_file in Path(self._output_folder).glob("*"):
+            if latents_file.is_file():
+                freed_space += latents_file.stat().st_size
+                deleted_latents_count += 1
+                latents_file.unlink()
+        if deleted_latents_count > 0:
+            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
+            self._invoker.services.logger.info(
+                f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)"
+            )

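`torch.save` pickles arbitrary Python objects, not just tensors, which is why a single `PickleStorageTorch` class can back both the `tensors` and `conditioning` services. A hedged round-trip sketch (the dataclass and path are invented; `weights_only=False` is passed explicitly because newer torch versions restrict `torch.load` to tensors by default):

    from dataclasses import dataclass
    from pathlib import Path

    import torch


    @dataclass
    class FakeConditioning:  # stand-in for ConditioningFieldData
        embeds: torch.Tensor
        label: str


    path = Path("/tmp/example_conditioning")
    torch.save(FakeConditioning(embeds=torch.zeros(2, 4), label="prompt"), path)
    loaded = torch.load(path, weights_only=False)
    assert loaded.label == "prompt" and loaded.embeds.shape == (2, 4)
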
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -216,48 +216,46 @@ class ImagesInterface(InvocationContextInterface):
         return self._services.images.get_dto(image_name)


-class LatentsInterface(InvocationContextInterface):
+class TensorsInterface(InvocationContextInterface):
     def save(self, tensor: Tensor) -> str:
         """
-        Saves a latents tensor, returning its name.
+        Saves a tensor, returning its name.

-        :param tensor: The latents tensor to save.
+        :param tensor: The tensor to save.
         """

         # Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
         # "mask", "noise", "masked_latents", etc.
         #
         # Retaining that capability in this wrapper would require either many different methods
-        # to save latents, or extra args for this method. Instead of complicating the API, we
-        # will use the same naming scheme for all latents.
+        # to save tensors, or extra args for this method. Instead of complicating the API, we
+        # will use the same naming scheme for all tensors.
         #
         # This has a very minor impact as we don't use them after a session completes.

-        # Previously, invocations chose the name for their latents. This is a bit risky, so we
+        # Previously, invocations chose the name for their tensors. This is a bit risky, so we
         # will generate a name for them instead. We use a uuid to ensure the name is unique.
         #
-        # Because the name of the latents file will includes the session and invocation IDs,
+        # Because the name of the tensors file will includes the session and invocation IDs,
         # we don't need to worry about collisions. A truncated UUIDv4 is fine.
         name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
-        self._services.latents.save(
+        self._services.tensors.save(
             name=name,
             data=tensor,
         )
         return name

-    def get(self, latents_name: str) -> Tensor:
+    def get(self, tensor_name: str) -> Tensor:
         """
-        Gets a latents tensor by name.
+        Gets a tensor by name.

-        :param latents_name: The name of the latents tensor to get.
+        :param tensor_name: The name of the tensor to get.
         """
-        return self._services.latents.get(latents_name)
+        return self._services.tensors.get(tensor_name)


 class ConditioningInterface(InvocationContextInterface):
-    # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
-    # service, but it is typed to work with Tensors only. We have to fudge the types here.

     def save(self, conditioning_data: ConditioningFieldData) -> str:
         """
         Saves a conditioning data object, returning its name.
@@ -265,15 +263,12 @@ class ConditioningInterface(InvocationContextInterface):
         :param conditioning_context_data: The conditioning data to save.
         """

-        # Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
-        #
-        # See comment for `LatentsInterface.save` for more info about this method (it's very
-        # similar).
-        name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
-        self._services.latents.save(
+        # See comment in TensorsInterface.save for why we generate the name here.
+        name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
+        self._services.conditioning.save(
             name=name,
-            data=conditioning_data,  # type: ignore [arg-type]
+            data=conditioning_data,
         )
         return name
@@ -284,7 +279,7 @@ class ConditioningInterface(InvocationContextInterface):
         :param conditioning_name: The name of the conditioning data to get.
         """

-        return self._services.latents.get(conditioning_name)  # type: ignore [return-value]
+        return self._services.conditioning.get(conditioning_name)


 class ModelsInterface(InvocationContextInterface):
@@ -400,7 +395,7 @@ class InvocationContext:
     def __init__(
         self,
         images: ImagesInterface,
-        latents: LatentsInterface,
+        tensors: TensorsInterface,
         conditioning: ConditioningInterface,
         models: ModelsInterface,
         logger: LoggerInterface,
@@ -412,8 +407,8 @@ class InvocationContext:
     ) -> None:
         self.images = images
         """Provides methods to save, get and update images and their metadata."""
-        self.latents = latents
-        """Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
+        self.tensors = tensors
+        """Provides methods to save and get tensors, including image, noise, masks, and masked images."""
         self.conditioning = conditioning
         """Provides methods to save and get conditioning data."""
         self.models = models
@@ -532,7 +527,7 @@ def build_invocation_context(
     logger = LoggerInterface(services=services, context_data=context_data)
     images = ImagesInterface(services=services, context_data=context_data)
-    latents = LatentsInterface(services=services, context_data=context_data)
+    tensors = TensorsInterface(services=services, context_data=context_data)
     models = ModelsInterface(services=services, context_data=context_data)
     config = ConfigInterface(services=services, context_data=context_data)
     util = UtilInterface(services=services, context_data=context_data)
@@ -543,7 +538,7 @@ def build_invocation_context(
         images=images,
         logger=logger,
         config=config,
-        latents=latents,
+        tensors=tensors,
         models=models,
         context_data=context_data,
         util=util,
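
Finally, the naming scheme described in `TensorsInterface.save` can be illustrated in isolation (approximation: `uuid_string()` is InvokeAI's UUID helper; `uuid4().hex` stands in for it here):

    from uuid import uuid4

    def make_name(session_id: str, invocation_id: str) -> str:
        # Session and invocation IDs scope the name, so a truncated
        # UUIDv4 suffix is enough to avoid collisions.
        return f"{session_id}__{invocation_id}__{uuid4().hex[:7]}"

    print(make_name("sess-1", "node-7"))  # e.g. sess-1__node-7__9f3c21a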