From 182810337c28fc5314776b4a8accf892c87dfabf Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 15:07:49 -0400 Subject: [PATCH] Add utility to_standard_float_mask(...) to convert various mask formats to a standardized format. --- invokeai/app/invocations/latent.py | 20 +++-- invokeai/app/invocations/mask.py | 2 - .../diffusion/regional_prompt_data.py | 1 - .../diffusion/shared_invokeai_diffusion.py | 2 +- invokeai/backend/util/mask.py | 53 +++++++++++ tests/backend/util/test_mask.py | 88 +++++++++++++++++++ 6 files changed, 155 insertions(+), 11 deletions(-) create mode 100644 invokeai/backend/util/mask.py create mode 100644 tests/backend/util/test_mask.py diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3070cd1e70..d5babe42cc 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -61,6 +61,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( TextConditioningData, TextConditioningRegions, ) +from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.silence_warnings import SilenceWarnings from ...backend.stable_diffusion.diffusers_pipeline import ( @@ -386,25 +387,25 @@ class DenoiseLatentsInvocation(BaseInvocation): return text_embeddings, text_embeddings_masks def _preprocess_regional_prompt_mask( - self, mask: Optional[torch.Tensor], target_height: int, target_width: int + self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype ) -> torch.Tensor: """Preprocess a regional prompt mask to match the target height and width. If mask is None, returns a mask of all ones with the target height and width. If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. Returns: - torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width). + torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). """ + if mask is None: - return torch.ones((1, 1, target_height, target_width), dtype=torch.bool) + return torch.ones((1, 1, target_height, target_width), dtype=dtype) + + mask = to_standard_float_mask(mask, out_dtype=dtype) tf = torchvision.transforms.Resize( (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST ) - if len(mask.shape) != 3 or mask.shape[0] != 1: - raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).") - # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) @@ -416,6 +417,7 @@ class DenoiseLatentsInvocation(BaseInvocation): masks: Optional[list[Optional[torch.Tensor]]], latent_height: int, latent_width: int, + dtype: torch.dtype, ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" if masks is None: @@ -465,7 +467,9 @@ class DenoiseLatentsInvocation(BaseInvocation): start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1] ) ) - processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width)) + processed_masks.append( + self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) + ) cur_text_embedding_len += text_embedding_info.embeds.shape[1] @@ -524,12 +528,14 @@ class DenoiseLatentsInvocation(BaseInvocation): masks=cond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, + dtype=unet.dtype, ) uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( text_conditionings=uncond_text_embeddings, masks=uncond_text_embedding_masks, latent_height=latent_height, latent_width=latent_width, + dtype=unet.dtype, ) conditioning_data = TextConditioningData( diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index de4887e20d..b3588f204d 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -90,8 +90,6 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation): for pair in self.prompt_color_pairs: # TODO(ryand): Make this work for both RGB and RGBA images. mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) - # Add explicit channel dimension. - mask = mask.unsqueeze(0) mask_tensor_name = context.tensors.save(mask) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name))) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 95f81b1f93..85331013d5 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -46,7 +46,6 @@ class RegionalPromptData: for batch_sample_regions in regions: batch_sample_masks_by_seq_len.append({}) - # Convert the bool masks to float masks so that max pooling can be applied. batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype) # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 8ba988a0eb..4d95cb8f0d 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -360,7 +360,7 @@ class InvokeAIDiffuserComponent: # Create a dummy mask and range for text conditioning that doesn't have region masks. _, _, h, w = x.shape r = TextConditioningRegions( - masks=torch.ones((1, 1, h, w), dtype=torch.bool), + masks=torch.ones((1, 1, h, w), dtype=x.dtype), ranges=[Range(start=0, end=c.embeds.shape[1])], ) regions.append(r) diff --git a/invokeai/backend/util/mask.py b/invokeai/backend/util/mask.py new file mode 100644 index 0000000000..45aa32061c --- /dev/null +++ b/invokeai/backend/util/mask.py @@ -0,0 +1,53 @@ +import torch + + +def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor: + """Standardize the dimensions of a mask tensor. + + Args: + mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w). + + Returns: + torch.Tensor: The output mask tensor. The shape is (1, h, w). + """ + # Get the mask height and width. + if mask.ndim == 2: + mask = mask.unsqueeze(0) + elif mask.ndim == 3 and mask.shape[0] == 1: + pass + else: + raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).") + + return mask + + +def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor: + """Standardize the format of a mask tensor. + + Args: + mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w) + or (h, w). + + out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type. + + Returns: + torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0 + or 1.0. + """ + + if not out_dtype.is_floating_point: + raise ValueError(f"out_dtype must be a float type, but got {out_dtype}") + + mask = to_standard_mask_dim(mask) + mask = mask.to(out_dtype) + + # Set masked regions to 1.0. + if mask.dtype == torch.bool: + mask = mask.to(out_dtype) + else: + mask = mask.to(out_dtype) + mask_region = mask > 0.5 + mask[mask_region] = 1.0 + mask[~mask_region] = 0.0 + + return mask diff --git a/tests/backend/util/test_mask.py b/tests/backend/util/test_mask.py new file mode 100644 index 0000000000..96d3aab07f --- /dev/null +++ b/tests/backend/util/test_mask.py @@ -0,0 +1,88 @@ +import pytest +import torch + +from invokeai.backend.util.mask import to_standard_float_mask + + +def test_to_standard_float_mask_wrong_ndim(): + with pytest.raises(ValueError): + to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32) + + +def test_to_standard_float_mask_wrong_shape(): + with pytest.raises(ValueError): + to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32) + + +def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor): + """Helper function to check the result of `to_standard_float_mask()`.""" + assert mask.shape == expected_mask.shape + assert mask.dtype == expected_mask.dtype + assert torch.allclose(mask, expected_mask) + + +def test_to_standard_float_mask_ndim_2(): + """Test the case where the input mask has shape (h, w).""" + mask = torch.zeros((3, 2), dtype=torch.float32) + mask[0, 0] = 1.0 + mask[1, 1] = 1.0 + + expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32) + expected_mask[0, 0, 0] = 1.0 + expected_mask[0, 1, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) + + +def test_to_standard_float_mask_ndim_3(): + """Test the case where the input mask has shape (1, h, w).""" + mask = torch.zeros((1, 3, 2), dtype=torch.float32) + mask[0, 0, 0] = 1.0 + mask[0, 1, 1] = 1.0 + + expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32) + expected_mask[0, 0, 0] = 1.0 + expected_mask[0, 1, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) + + +@pytest.mark.parametrize( + "out_dtype", + [torch.float32, torch.float16], +) +def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype): + """Test the case where the input mask has dtype bool.""" + mask = torch.zeros((3, 2), dtype=torch.bool) + mask[0, 0] = True + mask[1, 1] = True + + expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype) + expected_mask[0, 0, 0] = 1.0 + expected_mask[0, 1, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype) + + check_mask_result(mask=new_mask, expected_mask=expected_mask) + + +@pytest.mark.parametrize( + "out_dtype", + [torch.float32, torch.float16], +) +def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype): + """Test the case where the input mask has type float (but not all values are 0.0 or 1.0).""" + mask = torch.zeros((3, 2), dtype=torch.float32) + mask[0, 0] = 0.1 # Should be converted to 0.0 + mask[0, 1] = 0.9 # Should be converted to 1.0 + + expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype) + expected_mask[0, 0, 1] = 1.0 + + new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype) + + check_mask_result(mask=new_mask, expected_mask=expected_mask)