mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): replace latents service with tensors and conditioning services
- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
This commit is contained in:
@ -216,48 +216,46 @@ class ImagesInterface(InvocationContextInterface):
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class LatentsInterface(InvocationContextInterface):
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""
|
||||
Saves a latents tensor, returning its name.
|
||||
Saves a tensor, returning its name.
|
||||
|
||||
:param tensor: The latents tensor to save.
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
||||
# "mask", "noise", "masked_latents", etc.
|
||||
#
|
||||
# Retaining that capability in this wrapper would require either many different methods
|
||||
# to save latents, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all latents.
|
||||
# to save tensors, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all tensors.
|
||||
#
|
||||
# This has a very minor impact as we don't use them after a session completes.
|
||||
|
||||
# Previously, invocations chose the name for their latents. This is a bit risky, so we
|
||||
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
|
||||
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
||||
#
|
||||
# Because the name of the latents file will includes the session and invocation IDs,
|
||||
# Because the name of the tensors file will includes the session and invocation IDs,
|
||||
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.latents.save(
|
||||
self._services.tensors.save(
|
||||
name=name,
|
||||
data=tensor,
|
||||
)
|
||||
return name
|
||||
|
||||
def get(self, latents_name: str) -> Tensor:
|
||||
def get(self, tensor_name: str) -> Tensor:
|
||||
"""
|
||||
Gets a latents tensor by name.
|
||||
Gets a tensor by name.
|
||||
|
||||
:param latents_name: The name of the latents tensor to get.
|
||||
:param tensor_name: The name of the tensor to get.
|
||||
"""
|
||||
return self._services.latents.get(latents_name)
|
||||
return self._services.tensors.get(tensor_name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
|
||||
# service, but it is typed to work with Tensors only. We have to fudge the types here.
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
@ -265,15 +263,12 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
|
||||
#
|
||||
# See comment for `LatentsInterface.save` for more info about this method (it's very
|
||||
# similar).
|
||||
# See comment in TensorsInterface.save for why we generate the name here.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
|
||||
self._services.latents.save(
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.conditioning.save(
|
||||
name=name,
|
||||
data=conditioning_data, # type: ignore [arg-type]
|
||||
data=conditioning_data,
|
||||
)
|
||||
return name
|
||||
|
||||
@ -284,7 +279,7 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_name: The name of the conditioning data to get.
|
||||
"""
|
||||
|
||||
return self._services.latents.get(conditioning_name) # type: ignore [return-value]
|
||||
return self._services.conditioning.get(conditioning_name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
@ -400,7 +395,7 @@ class InvocationContext:
|
||||
def __init__(
|
||||
self,
|
||||
images: ImagesInterface,
|
||||
latents: LatentsInterface,
|
||||
tensors: TensorsInterface,
|
||||
conditioning: ConditioningInterface,
|
||||
models: ModelsInterface,
|
||||
logger: LoggerInterface,
|
||||
@ -412,8 +407,8 @@ class InvocationContext:
|
||||
) -> None:
|
||||
self.images = images
|
||||
"""Provides methods to save, get and update images and their metadata."""
|
||||
self.latents = latents
|
||||
"""Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
|
||||
self.tensors = tensors
|
||||
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
|
||||
self.conditioning = conditioning
|
||||
"""Provides methods to save and get conditioning data."""
|
||||
self.models = models
|
||||
@ -532,7 +527,7 @@ def build_invocation_context(
|
||||
|
||||
logger = LoggerInterface(services=services, context_data=context_data)
|
||||
images = ImagesInterface(services=services, context_data=context_data)
|
||||
latents = LatentsInterface(services=services, context_data=context_data)
|
||||
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||
models = ModelsInterface(services=services, context_data=context_data)
|
||||
config = ConfigInterface(services=services, context_data=context_data)
|
||||
util = UtilInterface(services=services, context_data=context_data)
|
||||
@ -543,7 +538,7 @@ def build_invocation_context(
|
||||
images=images,
|
||||
logger=logger,
|
||||
config=config,
|
||||
latents=latents,
|
||||
tensors=tensors,
|
||||
models=models,
|
||||
context_data=context_data,
|
||||
util=util,
|
||||
|
Reference in New Issue
Block a user