mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
gradient mask node test for inpaint
This commit is contained in:
parent
d7b5ad02e8
commit
0063014f2b
@ -185,7 +185,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
|||||||
title="Create Gradient Mask",
|
title="Create Gradient Mask",
|
||||||
tags=["mask", "denoise"],
|
tags=["mask", "denoise"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.1.0",
|
||||||
)
|
)
|
||||||
class CreateGradientMaskInvocation(BaseInvocation):
|
class CreateGradientMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
@ -198,6 +198,32 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
minimum_denoise: float = InputField(
|
minimum_denoise: float = InputField(
|
||||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||||
)
|
)
|
||||||
|
unet: Optional[UNetField] = InputField(
|
||||||
|
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||||
|
input=Input.Connection,
|
||||||
|
title="[OPTIONAL] UNet",
|
||||||
|
ui_order=5,
|
||||||
|
)
|
||||||
|
image: Optional[ImageField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||||
|
title="[OPTIONAL] Image",
|
||||||
|
ui_order=6
|
||||||
|
)
|
||||||
|
vae: Optional[VAEField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||||
|
title="[OPTIONAL] VAE",
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=7
|
||||||
|
)
|
||||||
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
||||||
|
fp32: bool = InputField(
|
||||||
|
default=DEFAULT_PRECISION == "float32",
|
||||||
|
description=FieldDescriptions.fp32,
|
||||||
|
ui_order=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||||
@ -233,8 +259,31 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||||
|
|
||||||
|
masked_latents_name = None
|
||||||
|
# Check for Inpaint model and generate masked_latents
|
||||||
|
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||||
|
#all three fields must be present at the same time
|
||||||
|
unet_info: UNet2DConditionModel = context.models.load(self.unet.unet)
|
||||||
|
quick_info = context.models.get_config(self.unet.unet)
|
||||||
|
quick_inpaint = quick_info.variant == "inpaint"
|
||||||
|
print(f"quick_inpaint: {quick_info.variant}")
|
||||||
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
|
is_inpaint = unet_info.conv_in.in_channels == 9
|
||||||
|
if is_inpaint:
|
||||||
|
mask = blur_tensor
|
||||||
|
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||||
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
if image_tensor.dim() == 3:
|
||||||
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
|
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
|
masked_image = image_tensor * torch.where(img_mask < 1, 0.0, 1.0) # <1 to include gradient area
|
||||||
|
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||||
|
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||||
|
|
||||||
|
|
||||||
return GradientMaskOutput(
|
return GradientMaskOutput(
|
||||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user