mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Handle inpainting on normal models
This commit is contained in:
parent
9e7b470189
commit
58f3072b91
@ -37,7 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||||
@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
@ -792,10 +793,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
|
|
||||||
### inpaint
|
### inpaint
|
||||||
# TODO: add inpainting on normal model
|
|
||||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
if unet_config.variant == "inpaint": # ModelVariantType.Inpaint:
|
if unet_config.variant == ModelVariantType.Inpaint:
|
||||||
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||||
|
elif mask is not None:
|
||||||
|
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# ext: t2i/ip adapter
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
|
91
invokeai/backend/stable_diffusion/extensions/inpaint.py
Normal file
91
invokeai/backend/stable_diffusion/extensions/inpaint.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintExt(ExtensionBase):
    """Denoising extension that applies mask-based inpainting on a UNet with 4 input channels.

    Before every step, and on every step output, the latents are combined with a noised copy of
    ``ctx.inputs.orig_latents`` according to ``mask``. After the denoise loop finishes, the final
    latents are combined with the un-noised ``ctx.inputs.orig_latents``.
    """

    def __init__(
        self,
        mask: torch.Tensor,
        is_gradient_mask: bool,
    ):
        """Store the mask and the flag that selects how the mask is applied.

        Args:
            mask: Mask tensor. It is repeated with the einops pattern ``"b c h w"``, so it is
                expected to be 4-dimensional.
            is_gradient_mask: If True, the mask is compared against a threshold to pick between
                the two sources; if False, it is used as the interpolation weight of ``torch.lerp``.
        """
        super().__init__()
        self.mask = mask
        self.is_gradient_mask = is_gradient_mask

    @staticmethod
    def _is_normal_model(unet: UNet2DConditionModel):
        """Return True when the UNet's input convolution takes exactly 4 channels."""
        return unet.conv_in.in_channels == 4

    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Combine ``latents`` with the original latents noised to timestep ``t``.

        Args:
            ctx: Denoise context; supplies the scheduler and ``inputs.orig_latents``.
            latents: Latents to be masked.
            t: Timestep at which the original latents are noised.

        Returns:
            ``latents`` where the mask selects them and the noised original latents elsewhere
            (hard selection for a gradient mask, linear interpolation otherwise).
        """
        n = latents.size(0)
        batched_mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=n)

        if t.dim() == 0:
            # Some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, "-> batch", batch=n)

        # The same noise tensor is reused on every call: noise shouldn't be re-randomized between
        # steps, because the multistep schedulers get very confused about what is happening from
        # step to step when we do that.
        noised_orig = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self.noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately
        #  scaled already?
        noised_orig = einops.repeat(noised_orig, "b c h w -> (repeat b) c h w", repeat=n)

        if not self.is_gradient_mask:
            # mask == 1 keeps `latents`, mask == 0 keeps the noised original latents.
            return torch.lerp(
                noised_orig.to(dtype=latents.dtype),
                latents,
                batched_mask.to(dtype=latents.dtype),
            )

        # Threshold is the timestep expressed as a fraction of the scheduler's training timesteps.
        threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
        # I don't know when mask got inverted, but it did.
        keep_latents = batched_mask > threshold
        return torch.where(keep_latents, latents, noised_orig)

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext):
        """Validate the UNet, move the mask next to the latents, and prepare the noise tensor.

        Raises:
            Exception: If the UNet's input convolution does not take 4 channels.
        """
        if not self._is_normal_model(ctx.unet):
            raise Exception("InpaintExt should be used only on normal models!")

        self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

        noise = ctx.inputs.noise
        if noise is None:
            # No noise was supplied: generate it deterministically on the CPU from ctx.seed,
            # then move it to the device/dtype of the latents.
            cpu_generator = torch.Generator(device="cpu").manual_seed(ctx.seed)
            noise = torch.randn(
                ctx.latents.shape,
                dtype=torch.float32,
                device="cpu",
                generator=cpu_generator,
            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
        self.noise = noise

    # TODO: order value
    @callback(ExtensionCallbackType.PRE_STEP, order=-100)
    def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
        """Mask the current latents at the current timestep before the step runs."""
        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)

    # TODO: order value
    # TODO: redo this with preview events rewrite
    @callback(ExtensionCallbackType.POST_STEP, order=-100)
    def apply_mask_to_step_output(self, ctx: DenoiseContext):
        """Mask the step output's prediction using the scheduler's last timestep.

        The first of ``denoised`` / ``pred_original_sample`` present on the step output is masked
        in place. If neither exists, ``pred_original_sample`` is set from the masked ``prev_sample``.
        """
        last_timestep = ctx.scheduler.timesteps[-1]
        step_output = ctx.step_output

        for field in ("denoised", "pred_original_sample"):
            if hasattr(step_output, field):
                source = getattr(step_output, field)
                target = field
                break
        else:
            source = step_output.prev_sample
            target = "pred_original_sample"

        setattr(step_output, target, self._apply_mask(ctx, source, last_timestep))

    # TODO: should here be used order?
    # restore unmasked part after the last step is completed
    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
    def restore_unmasked(self, ctx: DenoiseContext):
        """Combine the final latents with the un-noised original latents according to the mask."""
        original = ctx.inputs.orig_latents
        if self.is_gradient_mask:
            # Any strictly positive mask value keeps the denoised latents.
            ctx.latents = torch.where(self.mask > 0, ctx.latents, original)
            return
        ctx.latents = torch.lerp(original, ctx.latents, self.mask)
|
@ -31,7 +31,7 @@ class InpaintModelExt(ExtensionBase):
|
|||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
def init_tensors(self, ctx: DenoiseContext):
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
if not self._is_inpaint_model(ctx.unet):
|
if not self._is_inpaint_model(ctx.unet):
|
||||||
raise Exception("InpaintModelExt should be used only on inpaint model!")
|
raise Exception("InpaintModelExt should be used only on inpaint models!")
|
||||||
|
|
||||||
if self.mask is None:
|
if self.mask is None:
|
||||||
self.mask = torch.ones_like(ctx.latents[:1, :1])
|
self.mask = torch.ones_like(ctx.latents[:1, :1])
|
||||||
|
Loading…
Reference in New Issue
Block a user