diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 9957853738..05ffc0d67b 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -967,3 +967,56 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): image_dto = context.images.save(image=source_image) return ImageOutput.build(image_dto) + + +@invocation( + "mask_from_id", + title="Mask from ID", + tags=["image", "mask", "id"], + category="image", + version="1.0.0", +) +class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generate a mask for a particular color in an ID Map""" + + image: ImageField = InputField(description="The image to create the mask from") + color: ColorField = InputField(description="ID color to mask") + threshold: int = InputField(default=100, description="Threshold for color detection") + invert: bool = InputField(default=False, description="Whether or not to invert the mask") + + def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]): + r, g, b, a = rgba_color + hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255)) + return hex_code + + def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100): + if id_mask.mode != "RGB": + id_mask = id_mask.convert("RGB") + + # Can directly just use the tuple but I'll leave this rgba_to_hex here + # incase anyone prefers using hex codes directly instead of the color picker + hex_color_str = self.rgba_to_hex(color) + rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)]) + + # Maybe there's a faster way to calculate this distance but I can't think of any right now. + color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1) + + # Create a mask based on the threshold and the distance calculated above + binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255 + + # Convert the mask back to PIL + binary_mask_pil = Image.fromarray(binary_mask) + + return binary_mask_pil + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + + mask = self.id_to_mask(image, self.color.tuple(), self.threshold) + + if self.invert: + mask = ImageOps.invert(mask) + + image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK) + + return ImageOutput.build(image_dto)