Add utility to_standard_float_mask(...) to convert various mask formats to a standardized format.

This commit is contained in:
Ryan Dick 2024-04-08 15:07:49 -04:00 committed by Kent Keirsey
parent 338bf808d6
commit 182810337c
6 changed files with 155 additions and 11 deletions

View File

@ -61,6 +61,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData, TextConditioningData,
TextConditioningRegions, TextConditioningRegions,
) )
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -386,25 +387,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask( def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor: ) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width. """Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width. If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
Returns: Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width). torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
""" """
if mask is None: if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool) return torch.ones((1, 1, target_height, target_width), dtype=dtype)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize( tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
) )
if len(mask.shape) != 3 or mask.shape[0] != 1:
raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).")
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask) resized_mask = tf(mask)
@ -416,6 +417,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks: Optional[list[Optional[torch.Tensor]]], masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None: if masks is None:
@ -465,7 +467,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
) )
) )
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width)) processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1] cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@ -524,12 +528,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks=cond_text_embedding_masks, masks=cond_text_embedding_masks,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
dtype=unet.dtype,
) )
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings, text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks, masks=uncond_text_embedding_masks,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
dtype=unet.dtype,
) )
conditioning_data = TextConditioningData( conditioning_data = TextConditioningData(

View File

@ -90,8 +90,6 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
for pair in self.prompt_color_pairs: for pair in self.prompt_color_pairs:
# TODO(ryand): Make this work for both RGB and RGBA images. # TODO(ryand): Make this work for both RGB and RGBA images.
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
# Add explicit channel dimension.
mask = mask.unsqueeze(0)
mask_tensor_name = context.tensors.save(mask) mask_tensor_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name))) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)

View File

@ -46,7 +46,6 @@ class RegionalPromptData:
for batch_sample_regions in regions: for batch_sample_regions in regions:
batch_sample_masks_by_seq_len.append({}) batch_sample_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype) batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.

View File

@ -360,7 +360,7 @@ class InvokeAIDiffuserComponent:
# Create a dummy mask and range for text conditioning that doesn't have region masks. # Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape _, _, h, w = x.shape
r = TextConditioningRegions( r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool), masks=torch.ones((1, 1, h, w), dtype=x.dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])], ranges=[Range(start=0, end=c.embeds.shape[1])],
) )
regions.append(r) regions.append(r)

View File

@ -0,0 +1,53 @@
import torch
def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
"""Standardize the dimensions of a mask tensor.
Args:
mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).
Returns:
torch.Tensor: The output mask tensor. The shape is (1, h, w).
"""
# Get the mask height and width.
if mask.ndim == 2:
mask = mask.unsqueeze(0)
elif mask.ndim == 3 and mask.shape[0] == 1:
pass
else:
raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")
return mask
def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
"""Standardize the format of a mask tensor.
Args:
mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
or (h, w).
out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.
Returns:
torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
or 1.0.
"""
if not out_dtype.is_floating_point:
raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")
mask = to_standard_mask_dim(mask)
mask = mask.to(out_dtype)
# Set masked regions to 1.0.
if mask.dtype == torch.bool:
mask = mask.to(out_dtype)
else:
mask = mask.to(out_dtype)
mask_region = mask > 0.5
mask[mask_region] = 1.0
mask[~mask_region] = 0.0
return mask

View File

@ -0,0 +1,88 @@
import pytest
import torch
from invokeai.backend.util.mask import to_standard_float_mask
def test_to_standard_float_mask_wrong_ndim():
with pytest.raises(ValueError):
to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32)
def test_to_standard_float_mask_wrong_shape():
with pytest.raises(ValueError):
to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32)
def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor):
"""Helper function to check the result of `to_standard_float_mask()`."""
assert mask.shape == expected_mask.shape
assert mask.dtype == expected_mask.dtype
assert torch.allclose(mask, expected_mask)
def test_to_standard_float_mask_ndim_2():
"""Test the case where the input mask has shape (h, w)."""
mask = torch.zeros((3, 2), dtype=torch.float32)
mask[0, 0] = 1.0
mask[1, 1] = 1.0
expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
expected_mask[0, 0, 0] = 1.0
expected_mask[0, 1, 1] = 1.0
new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
check_mask_result(mask=new_mask, expected_mask=expected_mask)
def test_to_standard_float_mask_ndim_3():
"""Test the case where the input mask has shape (1, h, w)."""
mask = torch.zeros((1, 3, 2), dtype=torch.float32)
mask[0, 0, 0] = 1.0
mask[0, 1, 1] = 1.0
expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
expected_mask[0, 0, 0] = 1.0
expected_mask[0, 1, 1] = 1.0
new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
check_mask_result(mask=new_mask, expected_mask=expected_mask)
@pytest.mark.parametrize(
"out_dtype",
[torch.float32, torch.float16],
)
def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype):
"""Test the case where the input mask has dtype bool."""
mask = torch.zeros((3, 2), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
expected_mask[0, 0, 0] = 1.0
expected_mask[0, 1, 1] = 1.0
new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
check_mask_result(mask=new_mask, expected_mask=expected_mask)
@pytest.mark.parametrize(
"out_dtype",
[torch.float32, torch.float16],
)
def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype):
"""Test the case where the input mask has type float (but not all values are 0.0 or 1.0)."""
mask = torch.zeros((3, 2), dtype=torch.float32)
mask[0, 0] = 0.1 # Should be converted to 0.0
mask[0, 1] = 0.9 # Should be converted to 1.0
expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
expected_mask[0, 0, 1] = 1.0
new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
check_mask_result(mask=new_mask, expected_mask=expected_mask)