From 589cd3d6544cbdfcf9e9c955fff88501456bf684 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sat, 2 Mar 2024 19:31:13 -0500 Subject: [PATCH] fix(canvas): use corrected mask for pasteback --- invokeai/app/invocations/latent.py | 60 ++++++++++++------- .../util/graph/buildCanvasInpaintGraph.ts | 4 +- .../util/graph/buildCanvasOutpaintGraph.ts | 4 +- .../util/graph/buildCanvasSDXLInpaintGraph.ts | 4 +- .../graph/buildCanvasSDXLOutpaintGraph.ts | 4 +- 5 files changed, 45 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4ab64abc23..4c766e955c 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -181,6 +181,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ) +@invocation_output("gradient_mask_output") +class GradientMaskOutput(BaseInvocationOutput): + """Outputs a denoise mask and an image representing the total gradient of the mask.""" + + denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + expanded_mask_area: ImageField = OutputField( + description="Image representing the total gradient area of the mask. For paste-back purposes." + ) + + @invocation( "create_gradient_mask", title="Create Gradient Mask", @@ -201,38 +211,42 @@ class CreateGradientMaskInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + def invoke(self, context: InvocationContext) -> GradientMaskOutput: mask_image = context.images.get_pil(self.mask.image_name, mode="L") - if self.coherence_mode == "Box Blur": - blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) - else: # Gaussian Blur OR Staged - # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation - blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) + if self.edge_radius > 0: + if self.coherence_mode == "Box Blur": + blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) + else: # Gaussian Blur OR Staged + # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation + blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) - mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) - blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) - # redistribute blur so that the edges are 0 and blur out to 1 - blur_tensor = (blur_tensor - 0.5) * 2 + # redistribute blur so that the original edges are 0 and blur outwards to 1 + blur_tensor = (blur_tensor - 0.5) * 2 - threshold = 1 - self.minimum_denoise + threshold = 1 - self.minimum_denoise + + if self.coherence_mode == "Staged": + # wherever the blur_tensor is less than fully masked, convert it to threshold + blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor) + else: + # wherever the blur_tensor is above threshold but less than 1, drop it to threshold + blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) - if self.coherence_mode == "Staged": - # wherever the blur_tensor is masked to any degree, convert it to threshold - blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor) else: - # wherever the blur_tensor is above threshold but less than 1, drop it to threshold - blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) - - # multiply original mask to force actually masked regions to 0 - blur_tensor = mask_tensor * blur_tensor + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) - return DenoiseMaskOutput.build( - mask_name=mask_name, - masked_latents_name=None, - gradient=True, + # compute a [0, 1] mask from the blur_tensor + expanded_mask = torch.where((blur_tensor < 1), 0, 1) + expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L") + expanded_image_dto = context.images.save(expanded_mask_image) + + return GradientMaskOutput( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True), + expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name), ) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index 00bad63c3b..2672cf5be3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index 75f9a15f48..a9707e50f8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index fc60805e85..9f4e75de48 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index 44950ff40a..6c5a31926a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_UP, - field: 'image', + node_id: INPAINT_CREATE_MASK, + field: 'expanded_mask_area', }, destination: { node_id: MASK_RESIZE_DOWN,