import torch


def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
    """Standardize the dimensions of a mask tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).

    Returns:
        torch.Tensor: The output mask tensor. The shape is (1, h, w). For a
            (h, w) input this is an unsqueezed view of the input; a (1, h, w)
            input is returned as-is.

    Raises:
        ValueError: If the mask shape is neither (h, w) nor (1, h, w).
    """
    if mask.ndim == 2:
        # Add the leading singleton dimension: (h, w) -> (1, h, w).
        return mask.unsqueeze(0)
    if mask.ndim == 3 and mask.shape[0] == 1:
        # Already in the standard (1, h, w) layout.
        return mask
    raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")


def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    """Standardize the format of a mask tensor.

    The input tensor is never modified; the result is always a newly allocated
    tensor.

    Args:
        mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be
            (1, h, w) or (h, w).
        out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.

    Returns:
        torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are
            either 0.0 or 1.0 (values > 0.5 map to 1.0, everything else to 0.0).

    Raises:
        ValueError: If out_dtype is not a floating point dtype, or if the mask
            shape is neither (h, w) nor (1, h, w).
    """
    if not out_dtype.is_floating_point:
        raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")

    mask = to_standard_mask_dim(mask)

    if mask.dtype == torch.bool:
        # True/False cast directly to 1.0/0.0; the dtype change forces a copy.
        return mask.to(out_dtype)

    # Threshold after casting to out_dtype so rounding behaves as it did before.
    # Building the result from the boolean comparison (instead of writing into
    # `mask` in place) guarantees the caller's tensor is left untouched even
    # when it already has dtype `out_dtype`, where `.to()` returns the same
    # tensor rather than a copy.
    masked_region = mask.to(out_dtype) > 0.5
    return masked_region.to(out_dtype)