@invocation(
    "image_mask_to_tensor",
    title="Image Mask to Tensor",
    tags=["conditioning"],
    category="conditioning",
    version="1.0.0",
)
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
    """Convert a mask image to a tensor. Converts the image to grayscale and uses thresholding at the specified value."""

    # Reference to the stored image that is thresholded into a mask.
    image: ImageField = InputField(description="The mask image to convert.")
    # Threshold, constrained to 0..255; pixels strictly below it are selected when `invert` is False.
    cutoff: int = InputField(ge=0, le=255, description="Cutoff (<)", default=128)
    # When True the comparison flips to "at or above the cutoff".
    invert: bool = InputField(default=False, description="Whether to invert the mask.")

    def invoke(self, context: InvocationContext) -> MaskOutput:
        """Threshold the referenced image and save the result as a boolean mask tensor.

        The image is loaded in mode "L". A pixel is True in the mask when its value is
        strictly below `cutoff`; with `invert` set, a pixel is True when its value is at
        or above `cutoff`.

        Returns:
            A MaskOutput naming the saved tensor, whose shape is (1, height, width),
            along with that height and width.
        """
        pil_mask = context.images.get_pil(self.image.image_name, mode="L")
        pixels = np.array(pil_mask)

        # The two comparisons are exact complements of each other.
        selected = pixels >= self.cutoff if self.invert else pixels < self.cutoff

        # Prepend a singleton dimension so the tensor is laid out as (1, height, width).
        mask_tensor = torch.tensor(selected, dtype=torch.bool).unsqueeze(0)
        _, mask_height, mask_width = mask_tensor.shape

        return MaskOutput(
            mask=TensorField(tensor_name=context.tensors.save(mask_tensor)),
            height=mask_height,
            width=mask_width,
        )