From 6034fa12debd5dd95a2088a4d75e999e0431654d Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 12 Aug 2023 16:20:58 +1200 Subject: [PATCH] feat: Add Mask Blur node --- invokeai/app/invocations/image.py | 45 +++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index cc05b529b5..846812435d 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -706,7 +706,7 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): "ui": {"title": "Mask Combine", "tags": ["mask", "combine"]}, } - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context: InvocationContext) -> MaskOutput: mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L") @@ -721,7 +721,48 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): is_intermediate=self.is_intermediate, ) - return ImageOutput( + return MaskOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + + +class MaskBlurInvocation(BaseInvocation, PILInvocationConfig): + """Blurs a mask""" + + # fmt: off + type: Literal["mask_blur"] = "mask_blur" + + # Inputs + mask: Optional[ImageField] = Field(default=None, description="The mask image to blur") + radius: float = Field(default=8.0, ge=0, description="The blur radius") + blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") + # fmt: on + + class Config(InvocationConfig): + schema_extra = { + "ui": {"title": "Mask Blur", "tags": ["mask", "blur"]}, + } + + def invoke(self, context: InvocationContext) -> MaskOutput: + mask = context.services.images.get_pil_image(self.mask.image_name) + + blur = ( + ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) + ) + blur_mask = mask.filter(blur) + + image_dto = context.services.images.create( + image=blur_mask, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return MaskOutput( image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height,