Add utility to_standard_float_mask(...) to convert various mask formats to a standardized format.

Ryan Dick 2024-04-08 15:07:49 -04:00 committed by Kent Keirsey
parent 338bf808d6
commit 182810337c
6 changed files with 155 additions and 11 deletions
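For orientation, a minimal usage sketch of the new utility (grounded in the invokeai.backend.util.mask module added below; the float16 target dtype is only illustrative):

import torch

from invokeai.backend.util.mask import to_standard_float_mask

# A boolean region mask with shape (h, w).
bool_mask = torch.zeros((3, 2), dtype=torch.bool)
bool_mask[0, 0] = True

# Standardized to shape (1, h, w) with values of exactly 0.0 or 1.0 in the requested float dtype.
float_mask = to_standard_float_mask(bool_mask, out_dtype=torch.float16)
assert float_mask.shape == (1, 3, 2)
assert float_mask.dtype == torch.float16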

View File

@@ -61,6 +61,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -386,25 +387,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
if len(mask.shape) != 3 or mask.shape[0] != 1:
raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).")
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
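Illustratively, a minimal sketch of the resize step performed above, with hypothetical values (a 2x2 float mask matched to a 4x4 latent grid):

import torch
import torchvision

mask = torch.tensor([[[0.0, 1.0], [1.0, 0.0]]])  # Shape (1, h, w), already a float mask.
mask = mask.unsqueeze(0)  # Shape (1, 1, h, w); torchvision expects (batch, channels, h, w).
tf = torchvision.transforms.Resize(
    (4, 4), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
resized_mask = tf(mask)  # Shape (1, 1, 4, 4); nearest interpolation keeps values at exactly 0.0 or 1.0.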
@@ -416,6 +417,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
@@ -465,7 +467,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@@ -524,12 +528,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
conditioning_data = TextConditioningData(

View File

@@ -90,8 +90,6 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
for pair in self.prompt_color_pairs:
# TODO(ryand): Make this work for both RGB and RGBA images.
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
# Add explicit channel dimension.
mask = mask.unsqueeze(0)
mask_tensor_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
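Since to_standard_float_mask() accepts an (h, w) mask and adds the channel dimension itself, the explicit unsqueeze is no longer needed here. A rough sketch of the shapes involved (the image tensor and color are illustrative):

import torch

# Illustrative (h, w, 3) RGB image tensor and target color.
image_as_tensor = torch.zeros((4, 4, 3), dtype=torch.uint8)
color = torch.tensor([255, 0, 0], dtype=torch.uint8)

# Comparing against the color and reducing over the channel axis leaves an (h, w) bool mask.
mask = torch.all(image_as_tensor == color, dim=-1)
assert mask.shape == (4, 4) and mask.dtype == torch.bool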

View File

@@ -46,7 +46,6 @@ class RegionalPromptData:
for batch_sample_regions in regions:
batch_sample_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.

View File

@@ -360,7 +360,7 @@ class InvokeAIDiffuserComponent:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)

View File

@@ -0,0 +1,53 @@
import torch


def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
    """Standardize the dimensions of a mask tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).

    Returns:
        torch.Tensor: The output mask tensor. The shape is (1, h, w).
    """
    # Add a channel dimension if the mask is missing one.
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    elif mask.ndim == 3 and mask.shape[0] == 1:
        pass
    else:
        raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")

    return mask


def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    """Standardize the format of a mask tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
            or (h, w).
        out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.

    Returns:
        torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
            or 1.0.
    """
    if not out_dtype.is_floating_point:
        raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")

    mask = to_standard_mask_dim(mask)

    # Set masked regions to 1.0.
    if mask.dtype == torch.bool:
        mask = mask.to(out_dtype)
    else:
        mask = mask.to(out_dtype)
        mask_region = mask > 0.5
        mask[mask_region] = 1.0
        mask[~mask_region] = 0.0

    return mask
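A short usage sketch of the two helpers above (the values are illustrative):

import torch

from invokeai.backend.util.mask import to_standard_float_mask, to_standard_mask_dim

# An (h, w) input gains a leading channel dimension.
assert to_standard_mask_dim(torch.zeros((3, 2))).shape == (1, 3, 2)

# Non-binary float values are snapped to 0.0 or 1.0 at a 0.5 threshold.
soft_mask = torch.tensor([[0.1, 0.9], [1.0, 0.0]])
hard_mask = to_standard_float_mask(soft_mask, out_dtype=torch.float32)
assert torch.equal(hard_mask, torch.tensor([[[0.0, 1.0], [1.0, 0.0]]]))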

View File

@@ -0,0 +1,88 @@
import pytest
import torch

from invokeai.backend.util.mask import to_standard_float_mask


def test_to_standard_float_mask_wrong_ndim():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32)


def test_to_standard_float_mask_wrong_shape():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32)


def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor):
    """Helper function to check the result of `to_standard_float_mask()`."""
    assert mask.shape == expected_mask.shape
    assert mask.dtype == expected_mask.dtype
    assert torch.allclose(mask, expected_mask)


def test_to_standard_float_mask_ndim_2():
    """Test the case where the input mask has shape (h, w)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 1.0
    mask[1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)


def test_to_standard_float_mask_ndim_3():
    """Test the case where the input mask has shape (1, h, w)."""
    mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    mask[0, 0, 0] = 1.0
    mask[0, 1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has dtype bool."""
    mask = torch.zeros((3, 2), dtype=torch.bool)
    mask[0, 0] = True
    mask[1, 1] = True

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has type float (but not all values are 0.0 or 1.0)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 0.1  # Should be converted to 0.0
    mask[0, 1] = 0.9  # Should be converted to 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)
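One input format the docstring mentions but the tests above do not exercise is an integer mask; a hedged sketch of such a test (not part of this commit), following the same pattern:

def test_to_standard_float_mask_int_to_float():
    """Sketch (not in the commit): an int mask of 0s and 1s should behave like a bool mask."""
    mask = torch.zeros((3, 2), dtype=torch.int64)
    mask[0, 0] = 1
    mask[1, 1] = 1

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)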