2024-07-21 19:17:29 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-07-22 20:47:39 +00:00
|
|
|
from typing import TYPE_CHECKING, Optional
|
2024-07-21 19:17:29 +00:00
|
|
|
|
|
|
|
import einops
|
|
|
|
import torch
|
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
|
|
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
|
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
|
|
|
|
|
|
|
|
|
|
class InpaintExt(ExtensionBase):
    """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
    models.
    """

    def __init__(
        self,
        mask: torch.Tensor,
        is_gradient_mask: bool,
    ):
        """Initialize InpaintExt.

        Args:
            mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
                expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
                inpainted.
            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
                1.
        """
        super().__init__()
        self._mask = mask
        self._is_gradient_mask = is_gradient_mask

        # Noise used to noisify the unmasked part of the image.
        # If noise is provided in the context, it is used as-is.
        # If no noise is provided, it is generated from the seed (see `init_tensors`).
        self._noise: Optional[torch.Tensor] = None

    @staticmethod
    def _is_normal_model(unet: UNet2DConditionModel) -> bool:
        """Checks if the provided UNet belongs to a regular model.

        The `in_channels` of a UNet vary depending on model type:
        - normal - 4
        - depth - 5
        - inpaint - 9
        """
        return unet.conv_in.in_channels == 4

    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Blend `latents` with the original latents re-noised to timestep `t`, according to the mask.

        Args:
            ctx: The denoise context. Reads `ctx.scheduler` and `ctx.inputs.orig_latents`.
            latents: The latents to be masked. The first dimension is treated as the batch dimension.
            t: The timestep to noise the original latents to. May be a 0-dim or a 1-dim tensor.

        Returns:
            The masked latents.

        Raises:
            RuntimeError: If the noise has not been initialized yet (i.e. `init_tensors` has not run).
        """
        if self._noise is None:
            # Fail with a clear message instead of an opaque error from inside the scheduler.
            raise RuntimeError("InpaintExt noise is not initialized; init_tensors must run before masking.")

        batch_size = latents.size(0)
        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, "-> batch", batch=batch_size)

        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)

        # Only expand along the batch dimension when add_noise has not already produced `batch_size` samples
        # (it may broadcast over the batched timestep). Repeating unconditionally would yield batch_size**2
        # samples, which cannot be combined with `latents`.
        noised_batch_size = mask_latents.size(0)
        if noised_batch_size != batch_size:
            mask_latents = einops.repeat(
                mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size // noised_batch_size
            )

        if self._is_gradient_mask:
            # `t` may hold one entry per batch element at this point, and Tensor.item() only works on
            # single-element tensors. Every entry is expected to carry the same (current) timestep, so read the
            # first one.
            threshold = t.flatten()[0].item() / ctx.scheduler.config.num_train_timesteps
            # Mask values above the threshold keep the current latents; the rest take the re-noised originals.
            mask_bool = mask > threshold
            masked_input = torch.where(mask_bool, latents, mask_latents)
        else:
            # mask == 0 -> re-noised original latents, mask == 1 -> current latents.
            masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
        return masked_input

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext):
        """Validate the model type and prepare the mask and noise tensors before the denoise loop.

        Raises:
            ValueError: If the UNet in `ctx` is not a normal (4 input channels) model.
        """
        if not self._is_normal_model(ctx.unet):
            raise ValueError("InpaintExt should be used only on normal models!")

        # Match the mask to the device and dtype of the latents it will be combined with.
        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

        self._noise = ctx.inputs.noise
        # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
        # We still need noise for inpainting, so we generate it from the seed here.
        if self._noise is None:
            # Sampled on the CPU in float32 with a seeded generator, then moved to the latents' device and dtype.
            self._noise = torch.randn(
                ctx.latents.shape,
                dtype=torch.float32,
                device="cpu",
                generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)

    # TODO: order value
    @callback(ExtensionCallbackType.PRE_STEP, order=-100)
    def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
        """Apply the mask to the current latents at the current timestep before each step."""
        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)

    # TODO: order value
    # TODO: redo this with preview events rewrite
    @callback(ExtensionCallbackType.POST_STEP, order=-100)
    def apply_mask_to_step_output(self, ctx: DenoiseContext):
        """Apply the mask to the step output, using the last timestep of the scheduler.

        The masked result is written to `denoised` if the step output has it, otherwise to
        `pred_original_sample` (computed from `prev_sample` when the step output has neither attribute).
        """
        timestep = ctx.scheduler.timesteps[-1]
        if hasattr(ctx.step_output, "denoised"):
            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
        elif hasattr(ctx.step_output, "pred_original_sample"):
            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
        else:
            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)

    # TODO: should here be used order?
    # restore unmasked part after the last step is completed
    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
    def restore_unmasked(self, ctx: DenoiseContext):
        """Restore the original latents in the unmasked area once the denoise loop has finished."""
        if self._is_gradient_mask:
            # Any non-zero mask value keeps the denoised latents; zero restores the original latents.
            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
        else:
            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
|