Fix dimensions of mask produced by ExtractMasksAndPromptsInvocation. Also, added a clearer error message in case the same error is introduced in the future.

This commit is contained in:
Ryan Dick 2024-04-08 12:27:57 -04:00 committed by Kent Keirsey
parent 75ef473748
commit 5b5a4204a1
2 changed files with 8 additions and 1 deletions

View File

@ -401,6 +401,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
tf = torchvision.transforms.Resize( tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
) )
if len(mask.shape) != 3 or mask.shape[0] != 1:
raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).")
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask) resized_mask = tf(mask)
return resized_mask return resized_mask

View File

@ -88,10 +88,12 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8)) image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8))
for pair in self.prompt_color_pairs: for pair in self.prompt_color_pairs:
# TODO(ryand): Make this work for both RGB and RGBA images.
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
# Add explicit channel dimension.
mask = mask.unsqueeze(0)
mask_name = context.tensors.save(mask) mask_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)