From fc26f3e430ce6c400ce70cab6ddc93b6738f136e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 9 Apr 2024 20:27:03 +1000 Subject: [PATCH] feat(nodes): add alpha mask to tensor invocation --- invokeai/app/invocations/mask.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index a7f3207764..72c5886336 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -1,7 +1,8 @@ +import numpy as np import torch from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation -from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata +from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata from invokeai.app.invocations.primitives import MaskOutput @@ -34,3 +35,25 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata): width=self.width, height=self.height, ) + + +@invocation( + "alpha_mask_to_tensor", + title="Alpha Mask to Tensor", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class AlphaMaskToTensorInvocation(BaseInvocation): + """Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0.""" + + image: ImageField = InputField(description="The mask image to convert.") + + def invoke(self, context: InvocationContext) -> MaskOutput: + image = context.images.get_pil(self.image.image_name) + mask = torch.zeros((1, image.height, image.width), dtype=torch.bool) + mask[0] = torch.tensor(np.array(image)[:, :, 3] > 0, dtype=torch.bool) + + return MaskOutput( + mask=TensorField(tensor_name=context.tensors.save(mask)), height=image.height, width=image.width + )