feat(nodes): replace latents service with tensors and conditioning services

- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
This commit is contained in:
psychedelicious 2024-02-07 17:41:23 +11:00 committed by Brandon Rising
parent 3563d4ecf7
commit 0c149cbd3b
13 changed files with 197 additions and 193 deletions

View File

@ -2,9 +2,14 @@
from logging import Logger
import torch
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache
from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -23,8 +28,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
@ -68,6 +71,9 @@ class ApiDependencies:
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path
if output_folder is None:
raise ValueError("Output folder is not set")
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = init_db(config=config, logger=logger, image_files=image_files)
@ -84,7 +90,10 @@ class ApiDependencies:
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor"))
conditioning = PickleStorageForwardCache(
PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning")
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
download_queue_service = DownloadQueueService(event_bus=events)
@ -117,7 +126,6 @@ class ApiDependencies:
image_records=image_records,
images=images,
invocation_cache=invocation_cache,
latents=latents,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
@ -131,6 +139,8 @@ class ApiDependencies:
session_queue=session_queue,
urls=urls,
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
)
ApiDependencies.invoker = Invoker(services)

View File

@ -163,11 +163,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.latents.save(tensor=masked_latents)
masked_latents_name = context.tensors.save(tensor=masked_latents)
else:
masked_latents_name = None
mask_name = context.latents.save(tensor=mask)
mask_name = context.tensors.save(tensor=mask)
return DenoiseMaskOutput.build(
mask_name=mask_name,
@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.denoise_mask is None:
return None, None
mask = context.latents.get(self.denoise_mask.mask_name)
mask = context.tensors.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.latents.get(self.denoise_mask.masked_latents_name)
masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
else:
masked_latents = None
@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed = None
noise = None
if self.noise is not None:
noise = context.latents.get(self.noise.latents_name)
noise = context.tensors.get(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
@ -752,7 +752,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=result_latents)
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
@ -888,7 +888,7 @@ class ResizeLatentsInvocation(BaseInvocation):
if device == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=resized_latents)
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
@ -930,7 +930,7 @@ class ScaleLatentsInvocation(BaseInvocation):
if device == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=resized_latents)
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -1011,7 +1011,7 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = latents.to("cpu")
name = context.latents.save(tensor=latents)
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.latents.get(self.latents_a.latents_name)
latents_b = context.latents.get(self.latents_b.latents_name)
latents_a = context.tensors.get(self.latents_a.latents_name)
latents_b = context.tensors.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
@ -1103,7 +1103,7 @@ class BlendLatentsInvocation(BaseInvocation):
if device == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=blended_latents)
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
@ -1158,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.latents.save(tensor=cropped_latents)
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)

View File

@ -121,5 +121,5 @@ class NoiseInvocation(BaseInvocation):
seed=self.seed,
use_cpu=self.use_cpu,
)
name = context.latents.save(tensor=noise)
name = context.tensors.save(tensor=noise)
return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)

View File

@ -313,9 +313,7 @@ class DenoiseMaskOutput(BaseInvocationOutput):
class LatentsOutput(BaseInvocationOutput):
"""Base class for nodes that output a single latents tensor"""
latents: LatentsField = OutputField(
description=FieldDescriptions.latents,
)
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)
@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation):
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
return LatentsOutput.build(self.latents.latents_name, latents)

View File

@ -37,7 +37,8 @@ class MemoryInvocationCache(InvocationCacheBase):
if self._max_cache_size == 0:
return
self._invoker.services.images.on_deleted(self._delete_by_match)
self._invoker.services.latents.on_deleted(self._delete_by_match)
self._invoker.services.tensors.on_deleted(self._delete_by_match)
self._invoker.services.conditioning.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
with self._lock:

View File

@ -6,6 +6,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
from .board_images.board_images_base import BoardImagesServiceABC
from .board_records.board_records_base import BoardRecordStorageBase
@ -21,11 +25,11 @@ if TYPE_CHECKING:
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .pickle_storage.pickle_storage_base import PickleStorageBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState
@ -48,7 +52,6 @@ class InvocationServices:
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
@ -63,6 +66,8 @@ class InvocationServices:
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "PickleStorageBase[torch.Tensor]",
conditioning: "PickleStorageBase[ConditioningFieldData]",
):
self.board_images = board_images
self.board_image_records = board_image_records
@ -74,7 +79,6 @@ class InvocationServices:
self.images = images
self.image_files = image_files
self.image_records = image_records
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
@ -89,3 +93,5 @@ class InvocationServices:
self.names = names
self.urls = urls
self.workflow_records = workflow_records
self.tensors = tensors
self.conditioning = conditioning

View File

@ -1,58 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import Union
import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: Path
def __init__(self, output_folder: Union[str, Path]):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_latents()
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
latent_path.unlink()
def get_path(self, name: str) -> Path:
return self.__output_folder / name
def _delete_all_latents(self) -> None:
"""
Deletes all latents from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
deleted_latents_count = 0
freed_space = 0
for latents_file in Path(self.__output_folder).glob("*"):
if latents_file.is_file():
freed_space += latents_file.stat().st_size
deleted_latents_count += 1
latents_file.unlink()
if deleted_latents_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
)

View File

@ -1,68 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from queue import Queue
from typing import Dict, Optional
import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self.__underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self.__underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
return cache_item
latent = self.__underlying_storage.get(name)
self.__set_cache(name, latent)
return latent
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
self._on_deleted(name)
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if name not in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())

View File

@ -1,15 +1,15 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, Generic, TypeVar
import torch
T = TypeVar("T")
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
class PickleStorageBase(ABC, Generic[T]):
"""Responsible for storing and retrieving non-serializable data using a pickler."""
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
_on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
@ -17,18 +17,18 @@ class LatentsStorageBase(ABC):
self._on_deleted_callbacks = []
@abstractmethod
def get(self, name: str) -> torch.Tensor:
def get(self, name: str) -> T:
pass
@abstractmethod
def save(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: T) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
def on_changed(self, on_changed: Callable[[T], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
@ -36,7 +36,7 @@ class LatentsStorageBase(ABC):
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: torch.Tensor) -> None:
def _on_changed(self, item: T) -> None:
for callback in self._on_changed_callbacks:
callback(item)

View File

@ -0,0 +1,58 @@
from queue import Queue
from typing import Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
T = TypeVar("T")
class PickleStorageForwardCache(PickleStorageBase[T]):
def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20):
super().__init__()
self._underlying_storage = underlying_storage
self._cache: dict[str, T] = {}
self._cache_ids = Queue[str]()
self._max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self._underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self._underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, name: str) -> T:
cache_item = self._get_cache(name)
if cache_item is not None:
return cache_item
latent = self._underlying_storage.get(name)
self._set_cache(name, latent)
return latent
def save(self, name: str, data: T) -> None:
self._underlying_storage.save(name, data)
self._set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None:
self._underlying_storage.delete(name)
if name in self._cache:
del self._cache[name]
self._on_deleted(name)
def _get_cache(self, name: str) -> Optional[T]:
return None if name not in self._cache else self._cache[name]
def _set_cache(self, name: str, data: T):
if name not in self._cache:
self._cache[name] = data
self._cache_ids.put(name)
if self._cache_ids.qsize() > self._max_cache_size:
self._cache.pop(self._cache_ids.get())

View File

@ -0,0 +1,62 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
T = TypeVar("T")
class PickleStorageTorch(PickleStorageBase[T]):
"""Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`."""
def __init__(self, output_folder: Path, item_type_name: "str"):
self._output_folder = output_folder
self._output_folder.mkdir(parents=True, exist_ok=True)
self._item_type_name = item_type_name
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_items()
def get(self, name: str) -> T:
latent_path = self._get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: T) -> None:
self._output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self._get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self._get_path(name)
latent_path.unlink()
def _get_path(self, name: str) -> Path:
return self._output_folder / name
def _delete_all_items(self) -> None:
"""
Deletes all pickled items from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
if not self._invoker:
raise ValueError("Invoker is not set. Must call `start()` first.")
deleted_latents_count = 0
freed_space = 0
for latents_file in Path(self._output_folder).glob("*"):
if latents_file.is_file():
freed_space += latents_file.stat().st_size
deleted_latents_count += 1
latents_file.unlink()
if deleted_latents_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)"
)

View File

@ -216,48 +216,46 @@ class ImagesInterface(InvocationContextInterface):
return self._services.images.get_dto(image_name)
class LatentsInterface(InvocationContextInterface):
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:
"""
Saves a latents tensor, returning its name.
Saves a tensor, returning its name.
:param tensor: The latents tensor to save.
:param tensor: The tensor to save.
"""
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
# "mask", "noise", "masked_latents", etc.
#
# Retaining that capability in this wrapper would require either many different methods
# to save latents, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all latents.
# to save tensors, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all tensors.
#
# This has a very minor impact as we don't use them after a session completes.
# Previously, invocations chose the name for their latents. This is a bit risky, so we
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
# will generate a name for them instead. We use a uuid to ensure the name is unique.
#
# Because the name of the latents file will includes the session and invocation IDs,
# Because the name of the tensors file will includes the session and invocation IDs,
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.latents.save(
self._services.tensors.save(
name=name,
data=tensor,
)
return name
def get(self, latents_name: str) -> Tensor:
def get(self, tensor_name: str) -> Tensor:
"""
Gets a latents tensor by name.
Gets a tensor by name.
:param latents_name: The name of the latents tensor to get.
:param tensor_name: The name of the tensor to get.
"""
return self._services.latents.get(latents_name)
return self._services.tensors.get(tensor_name)
class ConditioningInterface(InvocationContextInterface):
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
# service, but it is typed to work with Tensors only. We have to fudge the types here.
def save(self, conditioning_data: ConditioningFieldData) -> str:
"""
Saves a conditioning data object, returning its name.
@ -265,15 +263,12 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_context_data: The conditioning data to save.
"""
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
#
# See comment for `LatentsInterface.save` for more info about this method (it's very
# similar).
# See comment in TensorsInterface.save for why we generate the name here.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
self._services.latents.save(
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.conditioning.save(
name=name,
data=conditioning_data, # type: ignore [arg-type]
data=conditioning_data,
)
return name
@ -284,7 +279,7 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_name: The name of the conditioning data to get.
"""
return self._services.latents.get(conditioning_name) # type: ignore [return-value]
return self._services.conditioning.get(conditioning_name)
class ModelsInterface(InvocationContextInterface):
@ -400,7 +395,7 @@ class InvocationContext:
def __init__(
self,
images: ImagesInterface,
latents: LatentsInterface,
tensors: TensorsInterface,
conditioning: ConditioningInterface,
models: ModelsInterface,
logger: LoggerInterface,
@ -412,8 +407,8 @@ class InvocationContext:
) -> None:
self.images = images
"""Provides methods to save, get and update images and their metadata."""
self.latents = latents
"""Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
self.tensors = tensors
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
self.conditioning = conditioning
"""Provides methods to save and get conditioning data."""
self.models = models
@ -532,7 +527,7 @@ def build_invocation_context(
logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data)
latents = LatentsInterface(services=services, context_data=context_data)
tensors = TensorsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data)
@ -543,7 +538,7 @@ def build_invocation_context(
images=images,
logger=logger,
config=config,
latents=latents,
tensors=tensors,
models=models,
context_data=context_data,
util=util,