|
|
|
@ -173,6 +173,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@invocation_output("gradient_mask_output")
|
|
|
|
|
class GradientMaskOutput(BaseInvocationOutput):
|
|
|
|
|
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
|
|
|
|
|
|
|
|
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
|
|
|
|
expanded_mask_area: ImageField = OutputField(
|
|
|
|
|
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@invocation(
|
|
|
|
|
"create_gradient_mask",
|
|
|
|
|
title="Create Gradient Mask",
|
|
|
|
@ -193,38 +203,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
|
|
|
|
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
|
|
|
|
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
|
|
|
|
if self.coherence_mode == "Box Blur":
|
|
|
|
|
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
|
|
|
|
else: # Gaussian Blur OR Staged
|
|
|
|
|
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
|
|
|
|
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
|
|
|
|
if self.edge_radius > 0:
|
|
|
|
|
if self.coherence_mode == "Box Blur":
|
|
|
|
|
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
|
|
|
|
else: # Gaussian Blur OR Staged
|
|
|
|
|
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
|
|
|
|
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
|
|
|
|
|
|
|
|
|
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
|
|
|
|
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
|
|
|
|
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
|
|
|
|
|
|
|
|
|
# redistribute blur so that the edges are 0 and blur out to 1
|
|
|
|
|
blur_tensor = (blur_tensor - 0.5) * 2
|
|
|
|
|
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
|
|
|
|
blur_tensor = (blur_tensor - 0.5) * 2
|
|
|
|
|
|
|
|
|
|
threshold = 1 - self.minimum_denoise
|
|
|
|
|
threshold = 1 - self.minimum_denoise
|
|
|
|
|
|
|
|
|
|
if self.coherence_mode == "Staged":
|
|
|
|
|
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
|
|
|
|
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
|
|
|
|
else:
|
|
|
|
|
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
|
|
|
|
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
|
|
|
|
|
|
|
|
|
if self.coherence_mode == "Staged":
|
|
|
|
|
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
|
|
|
|
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
|
|
|
|
else:
|
|
|
|
|
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
|
|
|
|
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
|
|
|
|
|
|
|
|
|
# multiply original mask to force actually masked regions to 0
|
|
|
|
|
blur_tensor = mask_tensor * blur_tensor
|
|
|
|
|
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
|
|
|
|
|
|
|
|
|
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
|
|
|
|
|
|
|
|
|
return DenoiseMaskOutput.build(
|
|
|
|
|
mask_name=mask_name,
|
|
|
|
|
masked_latents_name=None,
|
|
|
|
|
gradient=True,
|
|
|
|
|
# compute a [0, 1] mask from the blur_tensor
|
|
|
|
|
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
|
|
|
|
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
|
|
|
|
expanded_image_dto = context.images.save(expanded_mask_image)
|
|
|
|
|
|
|
|
|
|
return GradientMaskOutput(
|
|
|
|
|
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
|
|
|
|
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|