Add better documentation/errors around the possibility that inpainting models may be incorrectly labelled as non-inpainting models.

This commit is contained in:
Ryan Dick 2024-07-26 10:18:10 -04:00
parent 416d29fb83
commit 9db0e9d696
2 changed files with 9 additions and 1 deletions

View File

@ -775,6 +775,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
### inpaint ### inpaint
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents) mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
# prevalent, we will have to revisit how we initialize the inpainting extensions.
if unet_config.variant == ModelVariantType.Inpaint: if unet_config.variant == ModelVariantType.Inpaint:
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask)) ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
elif mask is not None: elif mask is not None:

View File

@ -75,7 +75,11 @@ class InpaintExt(ExtensionBase):
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def init_tensors(self, ctx: DenoiseContext): def init_tensors(self, ctx: DenoiseContext):
if not self._is_normal_model(ctx.unet): if not self._is_normal_model(ctx.unet):
raise ValueError("InpaintExt should be used only on normal models!") raise ValueError(
"InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
"inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
"fixed by removing and re-adding the model (so that it gets re-probed)."
)
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)