mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
import torch
|
|
|
|
|
|
def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
    """Return the mask with a leading singleton dimension, i.e. shaped (1, h, w).

    Args:
        mask (torch.Tensor): A mask tensor of shape (1, h, w) or (h, w).

    Raises:
        ValueError: If the mask is neither 2-D nor 3-D with a leading dimension of size 1.

    Returns:
        torch.Tensor: The mask with shape (1, h, w). A (1, h, w) input is returned as-is; a (h, w) input is returned
            with a new leading dimension inserted via unsqueeze.
    """
    rank = mask.ndim

    # Already in the standard layout: hand the same tensor back untouched.
    if rank == 3 and mask.shape[0] == 1:
        return mask

    # Plain 2-D mask: prepend the singleton dimension.
    if rank == 2:
        return mask.unsqueeze(0)

    raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")
|
|
|
|
|
|
def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    """Standardize the format of a mask tensor.

    The mask is reshaped to (1, h, w), cast to out_dtype, and then binarized: every value strictly greater than 0.5
    becomes 1.0 and every other value (including NaN, which never compares greater than 0.5) becomes 0.0.

    The input tensor is never modified; the result is always a newly allocated tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
            or (h, w).
        out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.

    Raises:
        ValueError: If out_dtype is not a floating-point dtype, or if the mask shape is not (1, h, w) or (h, w).

    Returns:
        torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
            or 1.0.
    """
    if not out_dtype.is_floating_point:
        raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")

    mask = to_standard_mask_dim(mask)

    # Cast first so that the threshold comparison below is evaluated at out_dtype precision.
    mask = mask.to(out_dtype)

    # Binarize out-of-place. `.to()` hands back the very same tensor when the dtype already matches, and the
    # dimension standardization may return a view of the input, so writing into `mask` with indexed assignment
    # would silently overwrite the caller's data. The comparison allocates a fresh bool tensor instead.
    return (mask > 0.5).to(out_dtype)
|