mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add unet check in gradient mask node
This commit is contained in:
parent
0063014f2b
commit
c094bad233
@ -200,6 +200,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||||
|
default=None,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="[OPTIONAL] UNet",
|
title="[OPTIONAL] UNet",
|
||||||
ui_order=5,
|
ui_order=5,
|
||||||
@ -263,12 +264,9 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
# Check for Inpaint model and generate masked_latents
|
# Check for Inpaint model and generate masked_latents
|
||||||
if self.unet is not None and self.vae is not None and self.image is not None:
|
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||||
#all three fields must be present at the same time
|
#all three fields must be present at the same time
|
||||||
unet_info: UNet2DConditionModel = context.models.load(self.unet.unet)
|
unet_info = context.models.load(self.unet.unet)
|
||||||
quick_info = context.models.get_config(self.unet.unet)
|
|
||||||
quick_inpaint = quick_info.variant == "inpaint"
|
|
||||||
print(f"quick_inpaint: {quick_info.variant}")
|
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
is_inpaint = unet_info.conv_in.in_channels == 9
|
is_inpaint = unet_info.model.conv_in.in_channels == 9
|
||||||
if is_inpaint:
|
if is_inpaint:
|
||||||
mask = blur_tensor
|
mask = blur_tensor
|
||||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user