Update MaskTensorToImageInvocation to support input mask tensors with or without a channel dimension.

This commit is contained in:
Ryan Dick 2024-08-29 17:41:06 +00:00
parent bfa9de6826
commit e0f12c762e

View File

@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
title="Tensor Mask to Image", title="Tensor Mask to Image",
tags=["mask"], tags=["mask"],
category="mask", category="mask",
version="1.0.0", version="1.1.0",
) )
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard): class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image.""" """Convert a mask tensor to an image."""
@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name) mask = context.tensors.load(self.mask.tensor_name)
# Squeeze the channel dimension if it exists.
if mask.dim() == 3:
mask = mask.squeeze(0)
# Ensure that the mask is binary. # Ensure that the mask is binary.
if mask.dtype != torch.bool: if mask.dtype != torch.bool:
mask = mask > 0.5 mask = mask > 0.5