fix(canvas): use corrected mask for pasteback

This commit is contained in:
dunkeroni 2024-03-02 19:31:13 -05:00 committed by Ryan Dick
parent f1d8865bb5
commit 589cd3d654
5 changed files with 45 additions and 31 deletions

View File

@ -181,6 +181,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
@ -201,38 +211,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1
blur_tensor = (blur_tensor - 0.5) * 2
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=None,
gradient=True,
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)

View File

@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,

View File

@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,

View File

@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,

View File

@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
node_id: INPAINT_CREATE_MASK,
field: 'expanded_mask_area',
},
destination: {
node_id: MASK_RESIZE_DOWN,