add unet check in gradient mask node

This commit is contained in:
dunkeroni 2024-04-11 18:15:00 -04:00 committed by Kent Keirsey
parent 0063014f2b
commit c094bad233

View File

@ -200,6 +200,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
@ -263,12 +264,9 @@ class CreateGradientMaskInvocation(BaseInvocation):
# Check for Inpaint model and generate masked_latents
if self.unet is not None and self.vae is not None and self.image is not None:
#all three fields must be present at the same time
unet_info: UNet2DConditionModel = context.models.load(self.unet.unet)
quick_info = context.models.get_config(self.unet.unet)
quick_inpaint = quick_info.variant == "inpaint"
print(f"quick_inpaint: {quick_info.variant}")
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
is_inpaint = unet_info.conv_in.in_channels == 9
is_inpaint = unet_info.model.conv_in.in_channels == 9
if is_inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)