mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): replace latents service with tensors and conditioning services
- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
This commit is contained in:
@ -163,11 +163,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
masked_latents_name = context.latents.save(tensor=masked_latents)
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
else:
|
||||
masked_latents_name = None
|
||||
|
||||
mask_name = context.latents.save(tensor=mask)
|
||||
mask_name = context.tensors.save(tensor=mask)
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask = context.latents.get(self.denoise_mask.mask_name)
|
||||
mask = context.tensors.get(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.latents.get(self.denoise_mask.masked_latents_name)
|
||||
masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.latents.get(self.noise.latents_name)
|
||||
noise = context.tensors.get(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
@ -752,7 +752,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=result_latents)
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -888,7 +888,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=resized_latents)
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -930,7 +930,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=resized_latents)
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@ -1011,7 +1011,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.latents.save(tensor=latents)
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.latents.get(self.latents_a.latents_name)
|
||||
latents_b = context.latents.get(self.latents_b.latents_name)
|
||||
latents_a = context.tensors.get(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
@ -1103,7 +1103,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=blended_latents)
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
|
||||
@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
x1 = self.x // LATENT_SCALE_FACTOR
|
||||
y1 = self.y // LATENT_SCALE_FACTOR
|
||||
@ -1158,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
|
||||
cropped_latents = latents[..., y1:y2, x1:x2]
|
||||
|
||||
name = context.latents.save(tensor=cropped_latents)
|
||||
name = context.tensors.save(tensor=cropped_latents)
|
||||
|
||||
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
|
||||
|
||||
|
Reference in New Issue
Block a user