Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Add utility to_standard_float_mask(...) to convert various mask formats to a standardized format.
This commit is contained in:
parent 338bf808d6
commit 182810337c
@@ -61,6 +61,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -386,25 +387,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return text_embeddings, text_embeddings_masks

     def _preprocess_regional_prompt_mask(
-        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
     ) -> torch.Tensor:
         """Preprocess a regional prompt mask to match the target height and width.
         If mask is None, returns a mask of all ones with the target height and width.
         If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.

         Returns:
-            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
         """

         if mask is None:
-            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
+            return torch.ones((1, 1, target_height, target_width), dtype=dtype)

+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
         tf = torchvision.transforms.Resize(
             (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
         )

-        if len(mask.shape) != 3 or mask.shape[0] != 1:
-            raise ValueError(f"Invalid regional prompt mask shape: {mask.shape}. Expected shape (1, h, w).")
         # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
         mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
         resized_mask = tf(mask)
@@ -416,6 +417,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         masks: Optional[list[Optional[torch.Tensor]]],
         latent_height: int,
         latent_width: int,
+        dtype: torch.dtype,
     ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
         """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
         if masks is None:
@@ -465,7 +467,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
                 )
             )
-            processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
+            processed_masks.append(
+                self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
+            )

             cur_text_embedding_len += text_embedding_info.embeds.shape[1]

@@ -524,12 +528,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 masks=cond_text_embedding_masks,
                 latent_height=latent_height,
                 latent_width=latent_width,
+                dtype=unet.dtype,
             )
             uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
                 text_conditionings=uncond_text_embeddings,
                 masks=uncond_text_embedding_masks,
                 latent_height=latent_height,
                 latent_width=latent_width,
+                dtype=unet.dtype,
             )

             conditioning_data = TextConditioningData(
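Taken together, the hunks above thread the UNet's dtype down into the regional mask preprocessing, so region masks become float tensors matching the model dtype instead of bool tensors. A rough standalone sketch of the resulting flow (the function name and free-standing form are illustrative, not the repository's method):

from typing import Optional

import torch
import torchvision.transforms

from invokeai.backend.util.mask import to_standard_float_mask


def preprocess_regional_mask_sketch(
    mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
    # No mask provided: the whole latent grid is covered by the prompt.
    if mask is None:
        return torch.ones((1, 1, target_height, target_width), dtype=dtype)

    # Standardize to a (1, h, w) float mask holding only 0.0 and 1.0.
    mask = to_standard_float_mask(mask, out_dtype=dtype)

    # Resize to the latent resolution with nearest-neighbor interpolation.
    tf = torchvision.transforms.Resize(
        (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
    )
    mask = mask.unsqueeze(0)  # torchvision expects (batch, channels, h, w).
    return tf(mask)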
@@ -90,8 +90,6 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
         for pair in self.prompt_color_pairs:
             # TODO(ryand): Make this work for both RGB and RGBA images.
             mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
-            # Add explicit channel dimension.
-            mask = mask.unsqueeze(0)
             mask_tensor_name = context.tensors.save(mask)
             prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
         return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
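For context, the torch.all(...) line above builds one mask per prompt by marking every pixel whose color exactly matches the prompt's color. A toy illustration (values invented for this example):

import torch

# Toy 2x2 RGB image: two red pixels, one green, one blue.
image_as_tensor = torch.tensor(
    [
        [[255, 0, 0], [0, 255, 0]],
        [[255, 0, 0], [0, 0, 255]],
    ]
)
color = (255, 0, 0)

# True wherever all three channels match the target color.
mask = torch.all(image_as_tensor == torch.tensor(color), dim=-1)
# mask == tensor([[True, False],
#                 [True, False]])
# This (h, w) bool mask is one of the formats that to_standard_float_mask() accepts.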
@@ -46,7 +46,6 @@ class RegionalPromptData:
         for batch_sample_regions in regions:
             batch_sample_masks_by_seq_len.append({})

-            # Convert the bool masks to float masks so that max pooling can be applied.
             batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)

             # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
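The comments above refer to repeated 2x downsampling of the masks via max pooling, which is why the masks need a float dtype rather than bool. A rough sketch of that idea (the explicit loop and the max_downscale_factor value are assumptions for illustration, not the repository code):

import torch
import torch.nn.functional as F

# A (batch, channels, h, w) float mask, as produced after this change.
masks = torch.zeros((1, 1, 64, 64), dtype=torch.float32)
masks[:, :, 16:48, 16:48] = 1.0

max_downscale_factor = 8
downscale_factor = 1
while downscale_factor < max_downscale_factor:
    # 2x2 max pooling halves each spatial dimension; a cell stays 1.0 if any
    # pixel in its 2x2 window was 1.0. max_pool2d is not implemented for bool tensors.
    masks = F.max_pool2d(masks, kernel_size=2, stride=2)
    downscale_factor *= 2

print(masks.shape)  # torch.Size([1, 1, 8, 8])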
@@ -360,7 +360,7 @@ class InvokeAIDiffuserComponent:
                 # Create a dummy mask and range for text conditioning that doesn't have region masks.
                 _, _, h, w = x.shape
                 r = TextConditioningRegions(
-                    masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                    masks=torch.ones((1, 1, h, w), dtype=x.dtype),
                     ranges=[Range(start=0, end=c.embeds.shape[1])],
                 )
                 regions.append(r)
invokeai/backend/util/mask.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch


def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
    """Standardize the dimensions of a mask tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).

    Returns:
        torch.Tensor: The output mask tensor. The shape is (1, h, w).
    """
    # Get the mask height and width.
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    elif mask.ndim == 3 and mask.shape[0] == 1:
        pass
    else:
        raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")

    return mask


def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    """Standardize the format of a mask tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
            or (h, w).

        out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.

    Returns:
        torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
            or 1.0.
    """

    if not out_dtype.is_floating_point:
        raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")

    mask = to_standard_mask_dim(mask)
    mask = mask.to(out_dtype)

    # Set masked regions to 1.0.
    if mask.dtype == torch.bool:
        mask = mask.to(out_dtype)
    else:
        mask = mask.to(out_dtype)
        mask_region = mask > 0.5
        mask[mask_region] = 1.0
        mask[~mask_region] = 0.0

    return mask
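A quick usage sketch of the new utility (toy shapes and values): a bool (h, w) mask comes back as a (1, h, w) float mask containing only 0.0 and 1.0.

import torch

from invokeai.backend.util.mask import to_standard_float_mask

bool_mask = torch.zeros((4, 6), dtype=torch.bool)
bool_mask[1, 2] = True

float_mask = to_standard_float_mask(bool_mask, out_dtype=torch.float16)
print(float_mask.shape)  # torch.Size([1, 4, 6])
print(float_mask.dtype)  # torch.float16
print(float_mask[0, 1, 2].item())  # 1.0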
tests/backend/util/test_mask.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import pytest
import torch

from invokeai.backend.util.mask import to_standard_float_mask


def test_to_standard_float_mask_wrong_ndim():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32)


def test_to_standard_float_mask_wrong_shape():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32)


def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor):
    """Helper function to check the result of `to_standard_float_mask()`."""
    assert mask.shape == expected_mask.shape
    assert mask.dtype == expected_mask.dtype
    assert torch.allclose(mask, expected_mask)


def test_to_standard_float_mask_ndim_2():
    """Test the case where the input mask has shape (h, w)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 1.0
    mask[1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)


def test_to_standard_float_mask_ndim_3():
    """Test the case where the input mask has shape (1, h, w)."""
    mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    mask[0, 0, 0] = 1.0
    mask[0, 1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has dtype bool."""
    mask = torch.zeros((3, 2), dtype=torch.bool)
    mask[0, 0] = True
    mask[1, 1] = True

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has type float (but not all values are 0.0 or 1.0)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 0.1  # Should be converted to 0.0.
    mask[0, 1] = 0.9  # Should be converted to 1.0.

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)