feat: Add Mask Blur node

This commit is contained in:
blessedcoolant 2023-08-12 16:20:58 +12:00
parent ce3675fc14
commit 6034fa12de

View File

@ -706,7 +706,7 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]}, "ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L") mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L")
@ -721,7 +721,48 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
) )
return ImageOutput( return MaskOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class MaskBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs a mask"""
# fmt: off
type: Literal["mask_blur"] = "mask_blur"
# Inputs
mask: Optional[ImageField] = Field(default=None, description="The mask image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mask Blur", "tags": ["mask", "blur"]},
}
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.services.images.get_pil_image(self.mask.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
)
blur_mask = mask.filter(blur)
image_dto = context.services.images.create(
image=blur_mask,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return MaskOutput(
image=ImageField(image_name=image_dto.image_name), image=ImageField(image_name=image_dto.image_name),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,