fix(canvas): use corrected mask for pasteback

This commit is contained in:
dunkeroni 2024-03-02 19:31:13 -05:00 committed by Ryan Dick
parent f1d8865bb5
commit 589cd3d654
5 changed files with 45 additions and 31 deletions

View File

@ -181,6 +181,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
) )
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation( @invocation(
"create_gradient_mask", "create_gradient_mask",
title="Create Gradient Mask", title="Create Gradient Mask",
@ -201,38 +211,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
) )
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L") mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.coherence_mode == "Box Blur": if self.edge_radius > 0:
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) if self.coherence_mode == "Box Blur":
else: # Gaussian Blur OR Staged blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation else: # Gaussian Blur OR Staged
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1 # redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2 blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else: else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build( # compute a [0, 1] mask from the blur_tensor
mask_name=mask_name, expanded_mask = torch.where((blur_tensor < 1), 0, 1)
masked_latents_name=None, expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
gradient=True, expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
) )

View File

@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MASK_RESIZE_UP, node_id: INPAINT_CREATE_MASK,
field: 'image', field: 'expanded_mask_area',
}, },
destination: { destination: {
node_id: MASK_RESIZE_DOWN, node_id: MASK_RESIZE_DOWN,

View File

@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
}, },
{ {
source: { source: {
node_id: MASK_RESIZE_UP, node_id: INPAINT_CREATE_MASK,
field: 'image', field: 'expanded_mask_area',
}, },
destination: { destination: {
node_id: MASK_RESIZE_DOWN, node_id: MASK_RESIZE_DOWN,

View File

@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MASK_RESIZE_UP, node_id: INPAINT_CREATE_MASK,
field: 'image', field: 'expanded_mask_area',
}, },
destination: { destination: {
node_id: MASK_RESIZE_DOWN, node_id: MASK_RESIZE_DOWN,

View File

@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
}, },
{ {
source: { source: {
node_id: MASK_RESIZE_UP, node_id: INPAINT_CREATE_MASK,
field: 'image', field: 'expanded_mask_area',
}, },
destination: { destination: {
node_id: MASK_RESIZE_DOWN, node_id: MASK_RESIZE_DOWN,