From 1bef13db37d098094cdaf1ae4c023b9c04ad5d35 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 20:25:12 +1000 Subject: [PATCH] feat(nodes): restore unet check on `CreateGradientMaskInvocation` Special handling for inpainting models --- invokeai/app/invocations/latent.py | 37 ++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 0f579cd9e7..66d6a70e40 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -51,6 +51,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType, LoadedModel +from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( @@ -204,6 +205,13 @@ class CreateGradientMaskInvocation(BaseInvocation): title="[OPTIONAL] Image", ui_order=6, ) + unet: Optional[UNetField] = InputField( + description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE", + default=None, + input=Input.Connection, + title="[OPTIONAL] UNet", + ui_order=5, + ) vae: Optional[VAEField] = InputField( default=None, description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE", @@ -253,18 +261,23 @@ class CreateGradientMaskInvocation(BaseInvocation): expanded_image_dto = context.images.save(expanded_mask_image) masked_latents_name = None - if self.vae is not None and self.image is not None: - # both fields must be present at the same time - mask = blur_tensor - vae_info: LoadedModel = context.models.load(self.vae.vae) - image = context.images.get_pil(self.image.image_name) - image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image_tensor.dim() == 3: - image_tensor = image_tensor.unsqueeze(0) - img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) - masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = context.tensors.save(tensor=masked_latents) + if self.unet is not None and self.vae is not None and self.image is not None: + # all three fields must be present at the same time + main_model_config = context.models.get_config(self.unet.unet.key) + assert isinstance(main_model_config, MainConfigBase) + if main_model_config.variant is ModelVariantType.Inpaint: + mask = blur_tensor + vae_info: LoadedModel = context.models.load(self.vae.vae) + image = context.images.get_pil(self.image.image_name) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) + masked_latents = ImageToLatentsInvocation.vae_encode( + vae_info, self.fp32, self.tiled, masked_image.clone() + ) + masked_latents_name = context.tensors.save(tensor=masked_latents) return GradientMaskOutput( denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),