feat(nodes): replace latents service with tensors and conditioning services

- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
This commit is contained in:
psychedelicious
2024-02-07 17:41:23 +11:00
parent 31db62ba99
commit 0710fb3fb0
13 changed files with 197 additions and 193 deletions

View File

@ -163,11 +163,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.latents.save(tensor=masked_latents)
masked_latents_name = context.tensors.save(tensor=masked_latents)
else:
masked_latents_name = None
mask_name = context.latents.save(tensor=mask)
mask_name = context.tensors.save(tensor=mask)
return DenoiseMaskOutput.build(
mask_name=mask_name,
@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.denoise_mask is None:
return None, None
mask = context.latents.get(self.denoise_mask.mask_name)
mask = context.tensors.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.latents.get(self.denoise_mask.masked_latents_name)
masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
else:
masked_latents = None
@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed = None
noise = None
if self.noise is not None:
noise = context.latents.get(self.noise.latents_name)
noise = context.tensors.get(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
@ -752,7 +752,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=result_latents)
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
@ -888,7 +888,7 @@ class ResizeLatentsInvocation(BaseInvocation):
if device == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=resized_latents)
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
@ -930,7 +930,7 @@ class ScaleLatentsInvocation(BaseInvocation):
if device == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=resized_latents)
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -1011,7 +1011,7 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = latents.to("cpu")
name = context.latents.save(tensor=latents)
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.latents.get(self.latents_a.latents_name)
latents_b = context.latents.get(self.latents_b.latents_name)
latents_a = context.tensors.get(self.latents_a.latents_name)
latents_b = context.tensors.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
@ -1103,7 +1103,7 @@ class BlendLatentsInvocation(BaseInvocation):
if device == torch.device("mps"):
mps.empty_cache()
name = context.latents.save(tensor=blended_latents)
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
latents = context.tensors.get(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
@ -1158,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.latents.save(tensor=cropped_latents)
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)