From 9e7b470189e3f8f9cee7baace4ccb37be9f8d61e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 20:45:55 +0300 Subject: [PATCH] Handle inpaint models --- invokeai/app/invocations/denoise_latents.py | 7 ++ .../extensions/inpaint_model.py | 66 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 invokeai/backend/stable_diffusion/extensions/inpaint_model.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index ccacc3303c..1f28252408 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP @@ -790,6 +791,12 @@ class DenoiseLatentsInvocation(BaseInvocation): ext_manager.add_extension(PreviewExt(step_callback)) + ### inpaint + # TODO: add inpainting on normal model + mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents) + if unet_config.variant == "inpaint": # ModelVariantType.Inpaint: + ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask)) + # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py new file mode 100644 index 0000000000..190e0fa931 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class InpaintModelExt(ExtensionBase): + def __init__( + self, + mask: Optional[torch.Tensor], + masked_latents: Optional[torch.Tensor], + is_gradient_mask: bool, + ): + super().__init__() + self.mask = mask + self.masked_latents = masked_latents + self.is_gradient_mask = is_gradient_mask + + @staticmethod + def _is_inpaint_model(unet: UNet2DConditionModel): + return unet.conv_in.in_channels == 9 + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def init_tensors(self, ctx: DenoiseContext): + if not self._is_inpaint_model(ctx.unet): + raise Exception("InpaintModelExt should be used only on inpaint model!") + + if self.mask is None: + self.mask = torch.ones_like(ctx.latents[:1, :1]) + self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + + if self.masked_latents is None: + self.masked_latents = torch.zeros_like(ctx.latents[:1]) + self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + + # TODO: any ideas about order value? + # do last so that other extensions works with normal latents + @callback(ExtensionCallbackType.PRE_UNET, order=1000) + def append_inpaint_layers(self, ctx: DenoiseContext): + batch_size = ctx.unet_kwargs.sample.shape[0] + b_mask = torch.cat([self.mask] * batch_size) + b_masked_latents = torch.cat([self.masked_latents] * batch_size) + ctx.unet_kwargs.sample = torch.cat( + [ctx.unet_kwargs.sample, b_mask, b_masked_latents], + dim=1, + ) + + # TODO: should here be used order? + # restore unmasked part as inpaint model can change unmasked part slightly + @callback(ExtensionCallbackType.POST_DENOISE_LOOP) + def restore_unmasked(self, ctx: DenoiseContext): + if self.mask is None: + return + + if self.is_gradient_mask: + ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents) + else: + ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)