mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update MaskTensorToImageInvocation to support input mask tensors with or without a channel dimension.
This commit is contained in:
parent
bfa9de6826
commit
e0f12c762e
@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
|||||||
title="Tensor Mask to Image",
|
title="Tensor Mask to Image",
|
||||||
tags=["mask"],
|
tags=["mask"],
|
||||||
category="mask",
|
category="mask",
|
||||||
version="1.0.0",
|
version="1.1.0",
|
||||||
)
|
)
|
||||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Convert a mask tensor to an image."""
|
"""Convert a mask tensor to an image."""
|
||||||
@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
mask = context.tensors.load(self.mask.tensor_name)
|
mask = context.tensors.load(self.mask.tensor_name)
|
||||||
|
|
||||||
|
# Squeeze the channel dimension if it exists.
|
||||||
|
if mask.dim() == 3:
|
||||||
|
mask = mask.squeeze(0)
|
||||||
|
|
||||||
# Ensure that the mask is binary.
|
# Ensure that the mask is binary.
|
||||||
if mask.dtype != torch.bool:
|
if mask.dtype != torch.bool:
|
||||||
mask = mask > 0.5
|
mask = mask > 0.5
|
||||||
|
Loading…
Reference in New Issue
Block a user