mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Modular backend - inpaint (#6643)
## Summary Code for inpainting and inpaint models handling from https://github.com/invoke-ai/InvokeAI/pull/6577. Separated in 2 extensions as discussed briefly before, so wait for discussion about such implementation. ## Related Issues / Discussions #6606 https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d ## QA Instructions Run with and without set `USE_MODULAR_DENOISE` environment. Try and compare outputs between backends in cases: - Normal generation on inpaint model - Inpainting on inpaint model - Inpainting on normal model ## Merge Plan Nope. If you think that there should be some kind of tests - feel free to add. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
This commit is contained in:
commit
2ad13ac7eb
@ -37,7 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||||
@ -60,6 +60,8 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
|||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||||
@ -736,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||||
|
|
||||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
return mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_noise_and_latents(
|
def prepare_noise_and_latents(
|
||||||
@ -794,10 +796,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
dtype = TorchDevice.choose_torch_dtype()
|
dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
latents = latents.to(device=device, dtype=dtype)
|
|
||||||
if noise is not None:
|
|
||||||
noise = noise.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
conditioning_data = self.get_conditioning_data(
|
conditioning_data = self.get_conditioning_data(
|
||||||
@ -830,21 +828,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
denoise_ctx = DenoiseContext(
|
|
||||||
inputs=DenoiseInputs(
|
|
||||||
orig_latents=latents,
|
|
||||||
timesteps=timesteps,
|
|
||||||
init_timestep=init_timestep,
|
|
||||||
noise=noise,
|
|
||||||
seed=seed,
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
attention_processor_cls=CustomAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
unet=None,
|
|
||||||
scheduler=scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
unet_config = context.models.get_config(self.unet.unet.key)
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
|
||||||
@ -866,6 +849,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if self.unet.seamless_axes:
|
if self.unet.seamless_axes:
|
||||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||||
|
|
||||||
|
### inpaint
|
||||||
|
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||||
|
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||||
|
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||||
|
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||||
|
if unet_config.variant == ModelVariantType.Inpaint:
|
||||||
|
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||||
|
elif mask is not None:
|
||||||
|
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||||
|
|
||||||
|
# Initialize context for modular denoise
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
if noise is not None:
|
||||||
|
noise = noise.to(device=device, dtype=dtype)
|
||||||
|
denoise_ctx = DenoiseContext(
|
||||||
|
inputs=DenoiseInputs(
|
||||||
|
orig_latents=latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
|
noise=noise,
|
||||||
|
seed=seed,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
attention_processor_cls=CustomAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
unet=None,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
# context for loading additional models
|
# context for loading additional models
|
||||||
with ExitStack() as exit_stack:
|
with ExitStack() as exit_stack:
|
||||||
# later should be smth like:
|
# later should be smth like:
|
||||||
@ -905,6 +918,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
|
||||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
||||||
|
# We invert the mask here for compatibility with the old backend implementation.
|
||||||
|
if mask is not None:
|
||||||
|
mask = 1 - mask
|
||||||
|
|
||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
|
120
invokeai/backend/stable_diffusion/extensions/inpaint.py
Normal file
120
invokeai/backend/stable_diffusion/extensions/inpaint.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintExt(ExtensionBase):
|
||||||
|
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
is_gradient_mask: bool,
|
||||||
|
):
|
||||||
|
"""Initialize InpaintExt.
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
|
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||||
|
inpainted.
|
||||||
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
|
1.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._mask = mask
|
||||||
|
self._is_gradient_mask = is_gradient_mask
|
||||||
|
|
||||||
|
# Noise, which used to noisify unmasked part of image
|
||||||
|
# if noise provided to context, then it will be used
|
||||||
|
# if no noise provided, then noise will be generated based on seed
|
||||||
|
self._noise: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_normal_model(unet: UNet2DConditionModel):
|
||||||
|
"""Checks if the provided UNet belongs to a regular model.
|
||||||
|
The `in_channels` of a UNet vary depending on model type:
|
||||||
|
- normal - 4
|
||||||
|
- depth - 5
|
||||||
|
- inpaint - 9
|
||||||
|
"""
|
||||||
|
return unet.conv_in.in_channels == 4
|
||||||
|
|
||||||
|
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = latents.size(0)
|
||||||
|
mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
|
if t.dim() == 0:
|
||||||
|
# some schedulers expect t to be one-dimensional.
|
||||||
|
# TODO: file diffusers bug about inconsistency?
|
||||||
|
t = einops.repeat(t, "-> batch", batch=batch_size)
|
||||||
|
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||||
|
# get very confused about what is happening from step to step when we do that.
|
||||||
|
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
|
||||||
|
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||||
|
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||||
|
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
|
if self._is_gradient_mask:
|
||||||
|
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||||
|
mask_bool = mask < 1 - threshold
|
||||||
|
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||||
|
else:
|
||||||
|
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
|
||||||
|
return masked_input
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
|
if not self._is_normal_model(ctx.unet):
|
||||||
|
raise ValueError(
|
||||||
|
"InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
|
||||||
|
"inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
|
||||||
|
"fixed by removing and re-adding the model (so that it gets re-probed)."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
self._noise = ctx.inputs.noise
|
||||||
|
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||||
|
# We still need noise for inpainting, so we generate it from the seed here.
|
||||||
|
if self._noise is None:
|
||||||
|
self._noise = torch.randn(
|
||||||
|
ctx.latents.shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
|
||||||
|
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
# Use negative order to make extensions with default order work with patched latents
|
||||||
|
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
|
||||||
|
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
|
||||||
|
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
|
||||||
|
|
||||||
|
# TODO: redo this with preview events rewrite
|
||||||
|
# Use negative order to make extensions with default order work with patched latents
|
||||||
|
@callback(ExtensionCallbackType.POST_STEP, order=-100)
|
||||||
|
def apply_mask_to_step_output(self, ctx: DenoiseContext):
|
||||||
|
timestep = ctx.scheduler.timesteps[-1]
|
||||||
|
if hasattr(ctx.step_output, "denoised"):
|
||||||
|
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
|
||||||
|
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||||
|
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
|
||||||
|
else:
|
||||||
|
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
|
||||||
|
|
||||||
|
# Restore unmasked part after the last step is completed
|
||||||
|
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||||
|
def restore_unmasked(self, ctx: DenoiseContext):
|
||||||
|
if self._is_gradient_mask:
|
||||||
|
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
|
||||||
|
else:
|
||||||
|
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
|
@ -0,0 +1,88 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintModelExt(ExtensionBase):
|
||||||
|
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
masked_latents: Optional[torch.Tensor],
|
||||||
|
is_gradient_mask: bool,
|
||||||
|
):
|
||||||
|
"""Initialize InpaintModelExt.
|
||||||
|
Args:
|
||||||
|
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
|
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||||
|
inpainted.
|
||||||
|
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
||||||
|
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
||||||
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
|
1.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if mask is not None and masked_latents is None:
|
||||||
|
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||||
|
|
||||||
|
# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
|
||||||
|
self._mask = None
|
||||||
|
if mask is not None:
|
||||||
|
self._mask = 1 - mask
|
||||||
|
self._masked_latents = masked_latents
|
||||||
|
self._is_gradient_mask = is_gradient_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_inpaint_model(unet: UNet2DConditionModel):
|
||||||
|
"""Checks if the provided UNet belongs to a regular model.
|
||||||
|
The `in_channels` of a UNet vary depending on model type:
|
||||||
|
- normal - 4
|
||||||
|
- depth - 5
|
||||||
|
- inpaint - 9
|
||||||
|
"""
|
||||||
|
return unet.conv_in.in_channels == 9
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
|
if not self._is_inpaint_model(ctx.unet):
|
||||||
|
raise ValueError("InpaintModelExt should be used only on inpaint models!")
|
||||||
|
|
||||||
|
if self._mask is None:
|
||||||
|
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
||||||
|
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
if self._masked_latents is None:
|
||||||
|
self._masked_latents = torch.zeros_like(ctx.latents[:1])
|
||||||
|
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
# Do last so that other extensions works with normal latents
|
||||||
|
@callback(ExtensionCallbackType.PRE_UNET, order=1000)
|
||||||
|
def append_inpaint_layers(self, ctx: DenoiseContext):
|
||||||
|
batch_size = ctx.unet_kwargs.sample.shape[0]
|
||||||
|
b_mask = torch.cat([self._mask] * batch_size)
|
||||||
|
b_masked_latents = torch.cat([self._masked_latents] * batch_size)
|
||||||
|
ctx.unet_kwargs.sample = torch.cat(
|
||||||
|
[ctx.unet_kwargs.sample, b_mask, b_masked_latents],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Restore unmasked part as inpaint model can change unmasked part slightly
|
||||||
|
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||||
|
def restore_unmasked(self, ctx: DenoiseContext):
|
||||||
|
if self._is_gradient_mask:
|
||||||
|
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
|
||||||
|
else:
|
||||||
|
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
|
Loading…
Reference in New Issue
Block a user