diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py index 9579d80009..2e46149271 100644 --- a/invokeai/app/invocations/conditioning.py +++ b/invokeai/app/invocations/conditioning.py @@ -1,9 +1,16 @@ import numpy as np import torch -from PIL.Image import Image +from PIL import Image -from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, InvocationContext, invocation -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + InputField, + InvocationContext, + WithMetadata, + invocation, +) +from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin @invocation( @@ -23,7 +30,7 @@ class AddConditioningMaskInvocation(BaseInvocation): ) @staticmethod - def convert_image_to_mask(image: Image) -> torch.Tensor: + def convert_image_to_mask(image: Image.Image) -> torch.Tensor: """Convert a PIL image to a uint8 mask tensor.""" np_image = np.array(image) torch_image = torch.from_numpy(np_image[0, :, :]) @@ -39,3 +46,43 @@ class AddConditioningMaskInvocation(BaseInvocation): self.conditioning.mask_name = mask_name return ConditioningOutput(conditioning=self.conditioning) + + +@invocation( + "rectangle_mask", + title="Create Rectangle Mask", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class RectangleMaskInvocation(BaseInvocation, WithMetadata): + """Create a mask image containing a rectangular mask region.""" + + height: int = InputField(description="The height of the image.") + width: int = InputField(description="The width of the image.") + y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).") + y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).") + x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).") + x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).") + + def invoke(self, context: InvocationContext) -> ImageOutput: + mask = np.zeros((self.height, self.width, 3), dtype=np.uint8) + mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255 + mask_image = Image.fromarray(mask) + + image_dto = context.services.images.create( + image=mask_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + metadata=self.metadata, + workflow=context.workflow, + ) + + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + )