mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Mask Blur node
This commit is contained in:
parent
ce3675fc14
commit
6034fa12de
@ -706,7 +706,7 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
||||
mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L")
|
||||
|
||||
@ -721,7 +721,48 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
return MaskOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class MaskBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Blurs a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask_blur"] = "mask_blur"
|
||||
|
||||
# Inputs
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Mask Blur", "tags": ["mask", "blur"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
||||
)
|
||||
blur_mask = mask.filter(blur)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=blur_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
Loading…
Reference in New Issue
Block a user