diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py deleted file mode 100644 index 9579d80009..0000000000 --- a/invokeai/app/invocations/conditioning.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np -import torch -from PIL.Image import Image - -from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, InvocationContext, invocation -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField - - -@invocation( - "add_conditioning_mask", - title="Add Conditioning Mask", - tags=["conditioning"], - category="conditioning", - version="1.0.0", -) -class AddConditioningMaskInvocation(BaseInvocation): - """Add a mask to an existing conditioning tensor.""" - - conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") - image: ImageField = InputField( - description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. " - "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask." - ) - - @staticmethod - def convert_image_to_mask(image: Image) -> torch.Tensor: - """Convert a PIL image to a uint8 mask tensor.""" - np_image = np.array(image) - torch_image = torch.from_numpy(np_image[0, :, :]) - mask = torch_image >= 128 - return mask.to(dtype=torch.uint8) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = self.convert_image_to_mask(image) - - mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask" - context.services.latents.save(mask_name, mask) - - self.conditioning.mask_name = mask_name - return ConditioningOutput(conditioning=self.conditioning) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py new file mode 100644 index 0000000000..e892a766c1 --- /dev/null +++ b/invokeai/app/invocations/mask.py @@ -0,0 +1,40 @@ +import torch + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + InvocationContext, + invocation, +) +from invokeai.app.invocations.fields import InputField, MaskField, WithMetadata +from invokeai.app.invocations.primitives import MaskOutput + + +@invocation( + "rectangle_mask", + title="Create Rectangle Mask", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class RectangleMaskInvocation(BaseInvocation, WithMetadata): + """Create a rectangular mask.""" + + height: int = InputField(description="The height of the entire mask.") + width: int = InputField(description="The width of the entire mask.") + y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).") + x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).") + rectangle_height: int = InputField(description="The height of the rectangular masked region.") + rectangle_width: int = InputField(description="The width of the rectangular masked region.") + + def invoke(self, context: InvocationContext) -> MaskOutput: + mask = torch.zeros((1, self.height, self.width), dtype=torch.bool) + mask[ + :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width + ] = True + + mask_name = context.tensors.save(mask) + return MaskOutput( + mask=MaskField(mask_name=mask_name), + width=self.width, + height=self.height, + ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 6a8e4e4531..25930f7d00 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -14,6 +14,7 @@ from invokeai.app.invocations.fields import ( Input, InputField, LatentsField, + MaskField, OutputField, UIComponent, ) @@ -405,9 +406,19 @@ class ColorInvocation(BaseInvocation): # endregion + # region Conditioning +@invocation_output("mask_output") +class MaskOutput(BaseInvocationOutput): + """A torch mask tensor.""" + + mask: MaskField = OutputField(description="The mask.") + width: int = OutputField(description="The width of the mask in pixels.") + height: int = OutputField(description="The height of the mask in pixels.") + + @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor"""