mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix dimensions of mask produced by ExtractMasksAndPromptsInvocation. Also, added a clearer error message in case the same error is introduced in the future.
This commit is contained in:
parent
75ef473748
commit
5b5a4204a1
@ -401,6 +401,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
tf = torchvision.transforms.Resize(
|
tf = torchvision.transforms.Resize(
|
||||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(mask.shape) != 3 or mask.shape[0] != 1:
|
||||||
|
raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).")
|
||||||
|
|
||||||
|
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||||
resized_mask = tf(mask)
|
resized_mask = tf(mask)
|
||||||
return resized_mask
|
return resized_mask
|
||||||
|
@ -88,10 +88,12 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
|
|||||||
image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8))
|
image_as_tensor = torch.from_numpy(np.array(image, dtype=np.uint8))
|
||||||
|
|
||||||
for pair in self.prompt_color_pairs:
|
for pair in self.prompt_color_pairs:
|
||||||
|
# TODO(ryand): Make this work for both RGB and RGBA images.
|
||||||
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
|
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
|
||||||
|
# Add explicit channel dimension.
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
mask_name = context.tensors.save(mask)
|
mask_name = context.tensors.save(mask)
|
||||||
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
|
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
|
||||||
|
|
||||||
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
|
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user