From 22704dd5425f03ff0286c9010dc0a96d4279c3b8 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 13 Jun 2024 17:58:15 -0400
Subject: [PATCH] Simplify handling of inpainting models. Improve the in-code documentation around inpainting.

---
 .../stable_diffusion/diffusers_pipeline.py   | 227 ++++++++++--------
 1 file changed, 124 insertions(+), 103 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index ab2f12be1d..7752c632aa 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -38,40 +38,6 @@ class PipelineIntermediateState:
     predicted_original: Optional[torch.Tensor] = None
 
 
-@dataclass
-class AddsMaskLatents:
-    """Add the channels required for inpainting model input.
-
-    The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
-    and the latent encoding of the base image.
-
-    This class assumes the same mask and base image should apply to all items in the batch.
-    """
-
-    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-    mask: torch.Tensor
-    initial_image_latents: torch.Tensor
-
-    def __call__(
-        self,
-        latents: torch.Tensor,
-        t: torch.Tensor,
-        text_embeddings: torch.Tensor,
-        **kwargs,
-    ) -> torch.Tensor:
-        model_input = self.add_mask_channels(latents)
-        return self.forward(model_input, t, text_embeddings, **kwargs)
-
-    def add_mask_channels(self, latents):
-        batch_size = latents.size(0)
-        # duplicate mask and latents for each batch
-        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        # add mask and image as additional channels
-        model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
-        return model_input
-
-
 @dataclass
 class AddsMaskGuidance:
     mask: torch.Tensor
@@ -273,6 +239,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         raise Exception("Should not be called")
 
+    def add_inpainting_channels_to_latents(
+        self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
+    ):
+        """Given a `latents` tensor, adds the mask and image latent channels expected by inpainting models.
+
+        Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
+        input of shape (N, 9, H, W). The 9 channels are defined as follows:
+        - Channels 0-3: The latents being denoised.
+        - Channel 4: The mask indicating which parts of the image are being inpainted.
+        - Channels 5-8: The latent representation of the masked reference image being inpainted.
+
+        This function assumes that the same mask and base image should apply to all items in the batch.
+        """
+        # Validate assumptions about input tensor shapes.
+        batch_size, latent_channels, latent_height, latent_width = latents.shape
+        assert latent_channels == 4
+        assert masked_ref_image_latents.shape == (1, 4, latent_height, latent_width)
+        assert inpainting_mask.shape == (1, 1, latent_height, latent_width)
+
+        # Expand masked_ref_image_latents and inpainting_mask to match the latents batch size.
+        original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
+        inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
+
+        # Concatenate along the channel dimension.
+        return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
+
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
@@ -302,94 +294,94 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
         if noise is not None:
+            # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
+            # full noise. Investigate the history of why this got commented out.
             # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
 
         self._adjust_memory_efficient_attention(latents)
 
+        # Handle mask guidance (a.k.a. inpainting).
         mask_guidance: AddsMaskGuidance | None = None
-        if mask is not None:
-            if is_inpainting_model(self.unet):
-                if masked_latents is None:
-                    raise Exception("Source image required for inpaint mask when inpaint model used!")
+        if mask is not None and not is_inpainting_model(self.unet):
+            # We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
+            # apply mask guidance to the latents.
 
-                self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
-                    self._unet_forward, mask, masked_latents
-                )
-            else:
-                # if no noise provided, noisify unmasked area based on seed
-                if noise is None:
-                    noise = torch.randn(
-                        orig_latents.shape,
-                        dtype=torch.float32,
-                        device="cpu",
-                        generator=torch.Generator(device="cpu").manual_seed(seed),
-                    ).to(device=orig_latents.device, dtype=orig_latents.dtype)
 
+            # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
+            # We still need noise for inpainting, so we generate it from the seed here.
+            if noise is None:
+                noise = torch.randn(
+                    orig_latents.shape,
+                    dtype=torch.float32,
+                    device="cpu",
+                    generator=torch.Generator(device="cpu").manual_seed(seed),
+                ).to(device=orig_latents.device, dtype=orig_latents.dtype)
 
-            mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
-
-        try:
-            use_ip_adapter = ip_adapter_data is not None
-            use_regional_prompting = (
-                conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+            mask_guidance = AddsMaskGuidance(
+                mask=mask,
+                mask_latents=orig_latents,
+                scheduler=self.scheduler,
+                noise=noise,
+                is_gradient_mask=is_gradient_mask,
             )
-            unet_attention_patcher = None
-            attn_ctx = nullcontext()
-            if use_ip_adapter or use_regional_prompting:
-                ip_adapters: Optional[List[UNetIPAdapterData]] = (
-                    [
-                        {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
-                        for ipa in ip_adapter_data
-                    ]
-                    if use_ip_adapter
-                    else None
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = (
+            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+        )
+        unet_attention_patcher = None
+        attn_ctx = nullcontext()
+
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters: Optional[List[UNetIPAdapterData]] = (
+                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
+                if use_ip_adapter
+                else None
+            )
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+
+        with attn_ctx:
+            callback(
+                PipelineIntermediateState(
+                    step=-1,
+                    order=self.scheduler.order,
+                    total_steps=len(timesteps),
+                    timestep=self.scheduler.config.num_train_timesteps,
+                    latents=latents,
                 )
-                unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
-                attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+            )
+
+            for i, t in enumerate(self.progress_bar(timesteps)):
+                batched_t = t.expand(batch_size)
+                step_output = self.step(
+                    t=batched_t,
+                    latents=latents,
+                    conditioning_data=conditioning_data,
+                    step_index=i,
+                    total_step_count=len(timesteps),
+                    scheduler_step_kwargs=scheduler_step_kwargs,
+                    mask_guidance=mask_guidance,
+                    mask=mask,
+                    masked_latents=masked_latents,
+                    control_data=control_data,
+                    ip_adapter_data=ip_adapter_data,
+                    t2i_adapter_data=t2i_adapter_data,
+                )
+                latents = step_output.prev_sample
+                predicted_original = getattr(step_output, "pred_original_sample", None)
 
-            with attn_ctx:
                 callback(
                     PipelineIntermediateState(
-                        step=-1,
+                        step=i,
                         order=self.scheduler.order,
                         total_steps=len(timesteps),
-                        timestep=self.scheduler.config.num_train_timesteps,
+                        timestep=int(t),
                         latents=latents,
+                        predicted_original=predicted_original,
                     )
                 )
 
-                for i, t in enumerate(self.progress_bar(timesteps)):
-                    batched_t = t.expand(batch_size)
-                    step_output = self.step(
-                        batched_t,
-                        latents,
-                        conditioning_data,
-                        step_index=i,
-                        total_step_count=len(timesteps),
-                        scheduler_step_kwargs=scheduler_step_kwargs,
-                        mask_guidance=mask_guidance,
-                        control_data=control_data,
-                        ip_adapter_data=ip_adapter_data,
-                        t2i_adapter_data=t2i_adapter_data,
-                    )
-                    latents = step_output.prev_sample
-                    predicted_original = getattr(step_output, "pred_original_sample", None)
-
-                    callback(
-                        PipelineIntermediateState(
-                            step=i,
-                            order=self.scheduler.order,
-                            total_steps=len(timesteps),
-                            timestep=int(t),
-                            latents=latents,
-                            predicted_original=predicted_original,
-                        )
-                    )
-
-        finally:
-            self.invokeai_diffuser.model_forward_callback = self._unet_forward
-
         # restore unmasked part after the last step is completed
         # in-process masking happens before each step
         if mask is not None:
@@ -411,7 +403,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         step_index: int,
         total_step_count: int,
         scheduler_step_kwargs: dict[str, Any],
-        mask_guidance: AddsMaskGuidance | None = None,
+        mask_guidance: AddsMaskGuidance | None,
+        mask: torch.Tensor | None,
+        masked_latents: torch.Tensor | None,
         control_data: list[ControlNetData] | None = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
@@ -419,7 +413,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
 
+        # Handle masked image-to-image (a.k.a. inpainting).
         if mask_guidance is not None:
+            # NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
             latents = mask_guidance(latents, timestep)
 
         # TODO: should this scaling happen here or inside self._unet_forward?
@@ -468,6 +464,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
             down_intrablock_additional_residuals = accum_adapter_state
 
+        # Handle inpainting models.
+        if is_inpainting_model(self.unet):
+            # NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
+            # self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
+            # latents.
+            if mask is not None:
+                if masked_latents is None:
+                    raise ValueError("Source image required for inpaint mask when inpaint model used!")
+                latent_model_input = self.add_inpainting_channels_to_latents(
+                    latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
+                )
+            else:
+                # We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
+                # We generate a global mask and empty original image so that we can still generate in this
+                # configuration.
+                # TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
+                # to do this.
+                # TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
+                # mask and original image once rather than on every denoising step.
+                latent_model_input = self.add_inpainting_channels_to_latents(
+                    latents=latent_model_input,
+                    masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
+                    inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
+                )
+
         uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
             sample=latent_model_input,
             timestep=t,  # TODO: debug how handled batched and non batched timesteps
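
Reviewer note (illustrative, not part of the patch): the docstring of add_inpainting_channels_to_latents(...) describes the 9-channel input layout that SD inpainting UNets expect. The standalone sketch below shows the same channel packing outside of the pipeline class. The function name build_inpainting_unet_input and the tensor names are hypothetical, chosen for this example only.

    import torch

    def build_inpainting_unet_input(
        latents: torch.Tensor,  # (N, 4, H, W) latents being denoised
        masked_ref_image_latents: torch.Tensor,  # (1, 4, H, W) latents of the masked reference image
        inpainting_mask: torch.Tensor,  # (1, 1, H, W); assumed convention: 1.0 = region being inpainted
    ) -> torch.Tensor:
        """Pack the (N, 9, H, W) input expected by an SD inpainting UNet (sketch)."""
        batch_size = latents.shape[0]
        # Broadcast the single mask / reference image across the batch without copying memory.
        mask = inpainting_mask.expand(batch_size, -1, -1, -1)
        ref = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
        # Channels 0-3: latents, channel 4: mask, channels 5-8: reference image latents.
        return torch.cat([latents, mask, ref], dim=1)

    # Quick shape check with dummy tensors.
    x = build_inpainting_unet_input(
        latents=torch.randn(2, 4, 64, 64),
        masked_ref_image_latents=torch.randn(1, 4, 64, 64),
        inpainting_mask=torch.ones(1, 1, 64, 64),
    )
    assert x.shape == (2, 9, 64, 64)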
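
Reviewer note (illustrative, not part of the patch): for standard 4-channel UNets the patch keeps the AddsMaskGuidance path, whose implementation is not shown in this diff. The sketch below illustrates the general technique that per-step mask guidance applies: before each denoising step, the region outside the mask is pinned to the original image latents re-noised to the current timestep. The function name apply_mask_guidance and the mask convention (1.0 = region being inpainted) are assumptions for this example, not InvokeAI's actual API.

    import torch

    def apply_mask_guidance(
        latents: torch.Tensor,  # (N, 4, H, W) latents currently being denoised
        orig_latents: torch.Tensor,  # (N, 4, H, W) latents of the original image
        noise: torch.Tensor,  # (N, 4, H, W) fixed noise drawn once per run (e.g. from the seed)
        mask: torch.Tensor,  # (N, 1, H, W); assumed convention: 1.0 = region being inpainted
        scheduler,  # a diffusers scheduler exposing add_noise(original_samples, noise, timesteps)
        t: torch.Tensor,  # current timestep(s)
    ) -> torch.Tensor:
        """Keep the region outside the mask pinned to the (re-noised) original image (sketch)."""
        # Noise the original latents to the current timestep so they are statistically consistent
        # with the partially denoised working latents.
        noised_orig = scheduler.add_noise(orig_latents, noise, t)
        # Inside the mask: keep the working latents. Outside the mask: restore the original content.
        return mask * latents + (1.0 - mask) * noised_orig

This is also why the patch generates noise from the seed when 'noise' is None but a mask is provided: the guidance step needs a consistent noise tensor to re-noise orig_latents at every timestep.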