mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(canvas): use corrected mask for pasteback
This commit is contained in:
parent
f1d8865bb5
commit
589cd3d654
@ -181,6 +181,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("gradient_mask_output")
|
||||||
|
class GradientMaskOutput(BaseInvocationOutput):
|
||||||
|
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||||
|
|
||||||
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||||
|
expanded_mask_area: ImageField = OutputField(
|
||||||
|
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"create_gradient_mask",
|
"create_gradient_mask",
|
||||||
title="Create Gradient Mask",
|
title="Create Gradient Mask",
|
||||||
@ -201,38 +211,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||||
if self.coherence_mode == "Box Blur":
|
if self.edge_radius > 0:
|
||||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
if self.coherence_mode == "Box Blur":
|
||||||
else: # Gaussian Blur OR Staged
|
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
else: # Gaussian Blur OR Staged
|
||||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||||
|
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||||
|
|
||||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
|
||||||
|
|
||||||
# redistribute blur so that the edges are 0 and blur out to 1
|
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||||
blur_tensor = (blur_tensor - 0.5) * 2
|
blur_tensor = (blur_tensor - 0.5) * 2
|
||||||
|
|
||||||
threshold = 1 - self.minimum_denoise
|
threshold = 1 - self.minimum_denoise
|
||||||
|
|
||||||
|
if self.coherence_mode == "Staged":
|
||||||
|
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||||
|
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||||
|
else:
|
||||||
|
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||||
|
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||||
|
|
||||||
if self.coherence_mode == "Staged":
|
|
||||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
|
||||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
|
||||||
else:
|
else:
|
||||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
|
||||||
|
|
||||||
# multiply original mask to force actually masked regions to 0
|
|
||||||
blur_tensor = mask_tensor * blur_tensor
|
|
||||||
|
|
||||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||||
|
|
||||||
return DenoiseMaskOutput.build(
|
# compute a [0, 1] mask from the blur_tensor
|
||||||
mask_name=mask_name,
|
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||||
masked_latents_name=None,
|
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||||
gradient=True,
|
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||||
|
|
||||||
|
return GradientMaskOutput(
|
||||||
|
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||||
|
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_RESIZE_UP,
|
node_id: INPAINT_CREATE_MASK,
|
||||||
field: 'image',
|
field: 'expanded_mask_area',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: MASK_RESIZE_DOWN,
|
node_id: MASK_RESIZE_DOWN,
|
||||||
|
@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_RESIZE_UP,
|
node_id: INPAINT_CREATE_MASK,
|
||||||
field: 'image',
|
field: 'expanded_mask_area',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: MASK_RESIZE_DOWN,
|
node_id: MASK_RESIZE_DOWN,
|
||||||
|
@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_RESIZE_UP,
|
node_id: INPAINT_CREATE_MASK,
|
||||||
field: 'image',
|
field: 'expanded_mask_area',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: MASK_RESIZE_DOWN,
|
node_id: MASK_RESIZE_DOWN,
|
||||||
|
@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_RESIZE_UP,
|
node_id: INPAINT_CREATE_MASK,
|
||||||
field: 'image',
|
field: 'expanded_mask_area',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: MASK_RESIZE_DOWN,
|
node_id: MASK_RESIZE_DOWN,
|
||||||
|
Loading…
Reference in New Issue
Block a user