Handle inpainting on normal models

This commit is contained in:
Sergey Borisov
2024-07-21 22:17:29 +03:00
parent 9e7b470189
commit 58f3072b91
3 changed files with 97 additions and 4 deletions

View File

@ -37,7 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@ -792,10 +793,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ext_manager.add_extension(PreviewExt(step_callback))
### inpaint
# TODO: add inpainting on normal model
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
if unet_config.variant == "inpaint": # ModelVariantType.Inpaint:
if unet_config.variant == ModelVariantType.Inpaint:
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
elif mask is not None:
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)