feat(nodes): replace latents service with tensors and conditioning services

- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
This commit is contained in:
psychedelicious
2024-02-07 17:41:23 +11:00
parent 31db62ba99
commit 0710fb3fb0
13 changed files with 197 additions and 193 deletions

View File

@ -216,48 +216,46 @@ class ImagesInterface(InvocationContextInterface):
return self._services.images.get_dto(image_name)
class LatentsInterface(InvocationContextInterface):
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:
"""
Saves a latents tensor, returning its name.
Saves a tensor, returning its name.
:param tensor: The latents tensor to save.
:param tensor: The tensor to save.
"""
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
# "mask", "noise", "masked_latents", etc.
#
# Retaining that capability in this wrapper would require either many different methods
# to save latents, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all latents.
# to save tensors, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all tensors.
#
# This has a very minor impact as we don't use them after a session completes.
# Previously, invocations chose the name for their latents. This is a bit risky, so we
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
# will generate a name for them instead. We use a uuid to ensure the name is unique.
#
# Because the name of the latents file will includes the session and invocation IDs,
# Because the name of the tensors file will includes the session and invocation IDs,
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.latents.save(
self._services.tensors.save(
name=name,
data=tensor,
)
return name
def get(self, latents_name: str) -> Tensor:
def get(self, tensor_name: str) -> Tensor:
"""
Gets a latents tensor by name.
Gets a tensor by name.
:param latents_name: The name of the latents tensor to get.
:param tensor_name: The name of the tensor to get.
"""
return self._services.latents.get(latents_name)
return self._services.tensors.get(tensor_name)
class ConditioningInterface(InvocationContextInterface):
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
# service, but it is typed to work with Tensors only. We have to fudge the types here.
def save(self, conditioning_data: ConditioningFieldData) -> str:
"""
Saves a conditioning data object, returning its name.
@ -265,15 +263,12 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_context_data: The conditioning data to save.
"""
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
#
# See comment for `LatentsInterface.save` for more info about this method (it's very
# similar).
# See comment in TensorsInterface.save for why we generate the name here.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
self._services.latents.save(
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.conditioning.save(
name=name,
data=conditioning_data, # type: ignore [arg-type]
data=conditioning_data,
)
return name
@ -284,7 +279,7 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_name: The name of the conditioning data to get.
"""
return self._services.latents.get(conditioning_name) # type: ignore [return-value]
return self._services.conditioning.get(conditioning_name)
class ModelsInterface(InvocationContextInterface):
@ -400,7 +395,7 @@ class InvocationContext:
def __init__(
self,
images: ImagesInterface,
latents: LatentsInterface,
tensors: TensorsInterface,
conditioning: ConditioningInterface,
models: ModelsInterface,
logger: LoggerInterface,
@ -412,8 +407,8 @@ class InvocationContext:
) -> None:
self.images = images
"""Provides methods to save, get and update images and their metadata."""
self.latents = latents
"""Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
self.tensors = tensors
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
self.conditioning = conditioning
"""Provides methods to save and get conditioning data."""
self.models = models
@ -532,7 +527,7 @@ def build_invocation_context(
logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data)
latents = LatentsInterface(services=services, context_data=context_data)
tensors = TensorsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data)
@ -543,7 +538,7 @@ def build_invocation_context(
images=images,
logger=logger,
config=config,
latents=latents,
tensors=tensors,
models=models,
context_data=context_data,
util=util,