wip: Remove MaskBlur / Adjust color correction

This commit is contained in:
blessedcoolant
2023-08-12 20:54:30 +12:00
parent 9f6221fe8c
commit f296e5c41e
7 changed files with 67 additions and 130 deletions

View File

@ -697,8 +697,8 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["mask_combine"] = "mask_combine"
# Inputs
mask1: Optional[ImageField] = Field(default=None, description="The first mask to combine")
mask2: Optional[ImageField] = Field(default=None, description="The second image to combine")
mask1: ImageField = Field(default=None, description="The first mask to combine")
mask2: ImageField = Field(default=None, description="The second image to combine")
# fmt: on
class Config(InvocationConfig):
@ -706,7 +706,7 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
}
def invoke(self, context: InvocationContext) -> MaskOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L")
@ -721,48 +721,7 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
is_intermediate=self.is_intermediate,
)
return MaskOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class MaskBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs a mask"""
# fmt: off
type: Literal["mask_blur"] = "mask_blur"
# Inputs
mask: Optional[ImageField] = Field(default=None, description="The mask image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mask Blur", "tags": ["mask", "blur"]},
}
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.services.images.get_pil_image(self.mask.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
)
blur_mask = mask.filter(blur)
image_dto = context.services.images.create(
image=blur_mask,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return MaskOutput(
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,