mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Simplify handling of inpainting models. Improve the in-code documentation around inpainting.
This commit is contained in:
parent
875673c9ba
commit
22704dd542
@ -38,40 +38,6 @@ class PipelineIntermediateState:
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskLatents:
|
||||
"""Add the channels required for inpainting model input.
|
||||
|
||||
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
||||
and the latent encoding of the base image.
|
||||
|
||||
This class assumes the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
|
||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
mask: torch.Tensor
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings, **kwargs)
|
||||
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
# duplicate mask and latents for each batch
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
# add mask and image as additional channels
|
||||
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
|
||||
return model_input
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.Tensor
|
||||
@ -273,6 +239,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
raise Exception("Should not be called")
|
||||
|
||||
def add_inpainting_channels_to_latents(
|
||||
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
|
||||
):
|
||||
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
|
||||
|
||||
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
|
||||
input of shape (N, 9, H, W). The 9 channels are defined as follows:
|
||||
- Channel 0-3: The latents being denoised.
|
||||
- Channel 4: The mask indicating which parts of the image are being inpainted.
|
||||
- Channel 5-8: The latent representation of the masked reference image being inpainted.
|
||||
|
||||
This function assumes that the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
# Validate assumptions about input tensor shapes.
|
||||
batch_size, latent_channels, latent_height, latent_width = latents.shape
|
||||
assert latent_channels == 4
|
||||
assert masked_ref_image_latents.shape == [1, 4, latent_height, latent_width]
|
||||
assert inpainting_mask == [1, 1, latent_height, latent_width]
|
||||
|
||||
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
|
||||
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
|
||||
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
|
||||
|
||||
# Concatenate along the channel dimension.
|
||||
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
|
||||
|
||||
def latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
@ -302,22 +294,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
|
||||
# Handle mask guidance (a.k.a. inpainting).
|
||||
mask_guidance: AddsMaskGuidance | None = None
|
||||
if mask is not None:
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
if mask is not None and not is_inpainting_model(self.unet):
|
||||
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
|
||||
# apply mask guidance to the latents.
|
||||
|
||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
# We still need noise for inpainting, so we generate it from the seed here.
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
@ -326,9 +317,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
|
||||
mask_guidance = AddsMaskGuidance(
|
||||
mask=mask,
|
||||
mask_latents=orig_latents,
|
||||
scheduler=self.scheduler,
|
||||
noise=noise,
|
||||
is_gradient_mask=is_gradient_mask,
|
||||
)
|
||||
|
||||
try:
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
@ -338,10 +334,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||
[
|
||||
{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
|
||||
for ipa in ip_adapter_data
|
||||
]
|
||||
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
||||
if use_ip_adapter
|
||||
else None
|
||||
)
|
||||
@ -362,13 +355,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
step_output = self.step(
|
||||
batched_t,
|
||||
latents,
|
||||
conditioning_data,
|
||||
t=batched_t,
|
||||
latents=latents,
|
||||
conditioning_data=conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=mask_guidance,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
@ -387,9 +382,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
@ -411,7 +403,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
mask_guidance: AddsMaskGuidance | None = None,
|
||||
mask_guidance: AddsMaskGuidance | None,
|
||||
mask: torch.Tensor | None,
|
||||
masked_latents: torch.Tensor | None,
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
@ -419,7 +413,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
# Handle masked image-to-image (a.k.a inpainting).
|
||||
if mask_guidance is not None:
|
||||
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
|
||||
latents = mask_guidance(latents, timestep)
|
||||
|
||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||
@ -468,6 +464,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
# Handle inpainting models.
|
||||
if is_inpainting_model(self.unet):
|
||||
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
|
||||
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
|
||||
# latents.
|
||||
if mask is not None:
|
||||
if masked_latents is None:
|
||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
|
||||
)
|
||||
else:
|
||||
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
|
||||
# We generate a global mask and empty original image so that we can still generate in this
|
||||
# configuration.
|
||||
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
|
||||
# to do this.
|
||||
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
|
||||
# mask and original image once rather than on every denoising step.
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input,
|
||||
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
|
||||
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
|
||||
)
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||
|
Loading…
Reference in New Issue
Block a user