From 5b5a4204a1da0de1c4b2a124d844d1abeed0bd5b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 12:27:57 -0400 Subject: [PATCH] Fix dimensions of mask produced by ExtractMasksAndPromptsInvocation. Also, added a clearer error message in case the same error is introduced in the future. --- invokeai/app/invocations/latent.py | 5 +++++ invokeai/app/invocations/mask.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 764e744a2e..db7cd20172 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -401,6 +401,11 @@ class DenoiseLatentsInvocation(BaseInvocation): tf = torchvision.transforms.Resize( (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST ) + + if len(mask.shape) != 3 or mask.shape[0] != 1: + raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).") + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) return resized_mask diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 572fd7c15d..31eb70e056 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -88,10 +88,12 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation): image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8)) for pair in self.prompt_color_pairs: + # TODO(ryand): Make this work for both RGB and RGBA images. mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) + # Add explicit channel dimension. + mask = mask.unsqueeze(0) mask_name = context.tensors.save(mask) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) - return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)