From 60d1e686d83b3ecd1cd719174b75cfb69844add1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 14 Jun 2024 14:35:48 -0400 Subject: [PATCH] Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. --- .../multi_diffusion_pipeline.py | 282 ++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py new file mode 100644 index 0000000000..ef3216cd45 --- /dev/null +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import math +from contextlib import nullcontext +from typing import Any, Callable, List, Optional + +import torch + +from invokeai.backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData + + +class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): + """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising.""" + + def latents_from_embeddings( + self, + latents: torch.Tensor, + scheduler_step_kwargs: dict[str, Any], + conditioning_data: TextConditioningData, + noise: Optional[torch.Tensor], + seed: int, + timesteps: torch.Tensor, + init_timestep: torch.Tensor, + callback: Callable[[PipelineIntermediateState], None], + control_data: list[ControlNetData] | None = None, + ip_adapter_data: Optional[list[IPAdapterData]] = None, + t2i_adapter_data: Optional[list[T2IAdapterData]] = None, + mask: Optional[torch.Tensor] = None, + masked_latents: Optional[torch.Tensor] = None, + is_gradient_mask: bool = False, + ) -> torch.Tensor: + # TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle + # cases where densoisings_start and denoising_end are set such that there are no timesteps. + if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0: + return latents + + orig_latents = latents.clone() + + batch_size = latents.shape[0] + batched_init_timestep = init_timestep.expand(batch_size) + + # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner). + if noise is not None: + # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with + # full noise. Investigate the history of why this got commented out. + # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers + latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) + + self._adjust_memory_efficient_attention(latents) + + # Handle mask guidance (a.k.a. inpainting). + mask_guidance: AddsMaskGuidance | None = None + if mask is not None and not is_inpainting_model(self.unet): + # We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will + # apply mask guidance to the latents. + + # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner). + # We still need noise for inpainting, so we generate it from the seed here. + if noise is None: + noise = torch.randn( + orig_latents.shape, + dtype=torch.float32, + device="cpu", + generator=torch.Generator(device="cpu").manual_seed(seed), + ).to(device=orig_latents.device, dtype=orig_latents.dtype) + + mask_guidance = AddsMaskGuidance( + mask=mask, + mask_latents=orig_latents, + scheduler=self.scheduler, + noise=noise, + is_gradient_mask=is_gradient_mask, + ) + + use_ip_adapter = ip_adapter_data is not None + use_regional_prompting = ( + conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None + ) + unet_attention_patcher = None + attn_ctx = nullcontext() + + if use_ip_adapter or use_regional_prompting: + ip_adapters: Optional[List[UNetIPAdapterData]] = ( + [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data] + if use_ip_adapter + else None + ) + unet_attention_patcher = UNetAttentionPatcher(ip_adapters) + attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) + + with attn_ctx: + callback( + PipelineIntermediateState( + step=-1, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=self.scheduler.config.num_train_timesteps, + latents=latents, + ) + ) + + for i, t in enumerate(self.progress_bar(timesteps)): + batched_t = t.expand(batch_size) + step_output = self.step( + t=batched_t, + latents=latents, + conditioning_data=conditioning_data, + step_index=i, + total_step_count=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, + mask_guidance=mask_guidance, + mask=mask, + masked_latents=masked_latents, + control_data=control_data, + ip_adapter_data=ip_adapter_data, + t2i_adapter_data=t2i_adapter_data, + ) + latents = step_output.prev_sample + predicted_original = getattr(step_output, "pred_original_sample", None) + + callback( + PipelineIntermediateState( + step=i, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=int(t), + latents=latents, + predicted_original=predicted_original, + ) + ) + + # restore unmasked part after the last step is completed + # in-process masking happens before each step + if mask is not None: + if is_gradient_mask: + latents = torch.where(mask > 0, latents, orig_latents) + else: + latents = torch.lerp( + orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype) + ) + + return latents + + @torch.inference_mode() + def step( + self, + t: torch.Tensor, + latents: torch.Tensor, + conditioning_data: TextConditioningData, + step_index: int, + total_step_count: int, + scheduler_step_kwargs: dict[str, Any], + mask_guidance: AddsMaskGuidance | None, + mask: torch.Tensor | None, + masked_latents: torch.Tensor | None, + control_data: list[ControlNetData] | None = None, + ip_adapter_data: Optional[list[IPAdapterData]] = None, + t2i_adapter_data: Optional[list[T2IAdapterData]] = None, + ): + # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value + timestep = t[0] + + # Handle masked image-to-image (a.k.a inpainting). + if mask_guidance is not None: + # NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...). + latents = mask_guidance(latents, timestep) + + # TODO: should this scaling happen here or inside self._unet_forward? + # i.e. before or after passing it to InvokeAIDiffuserComponent + latent_model_input = self.scheduler.scale_model_input(latents, timestep) + + # Handle ControlNet(s) + down_block_additional_residuals = None + mid_block_additional_residual = None + if control_data is not None: + down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step( + control_data=control_data, + sample=latent_model_input, + timestep=timestep, + step_index=step_index, + total_step_count=total_step_count, + conditioning_data=conditioning_data, + ) + + # Handle T2I-Adapter(s) + down_intrablock_additional_residuals = None + if t2i_adapter_data is not None: + accum_adapter_state = None + for single_t2i_adapter_data in t2i_adapter_data: + # Determine the T2I-Adapter weights for the current denoising step. + first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count) + last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count) + t2i_adapter_weight = ( + single_t2i_adapter_data.weight[step_index] + if isinstance(single_t2i_adapter_data.weight, list) + else single_t2i_adapter_data.weight + ) + if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step: + # If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0 + # so it has no effect. + t2i_adapter_weight = 0.0 + + # Apply the t2i_adapter_weight, and accumulate. + if accum_adapter_state is None: + # Handle the first T2I-Adapter. + accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state] + else: + # Add to the previous adapter states. + for idx, value in enumerate(single_t2i_adapter_data.adapter_state): + accum_adapter_state[idx] += value * t2i_adapter_weight + + down_intrablock_additional_residuals = accum_adapter_state + + # Handle inpainting models. + if is_inpainting_model(self.unet): + # NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after* + # self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image + # latents. + if mask is not None: + if masked_latents is None: + raise ValueError("Source image required for inpaint mask when inpaint model used!") + latent_model_input = self.add_inpainting_channels_to_latents( + latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask + ) + else: + # We are using an inpainting model, but no mask was provided, so we are not really "inpainting". + # We generate a global mask and empty original image so that we can still generate in this + # configuration. + # TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting + # to do this. + # TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake' + # mask and original image once rather than on every denoising step. + latent_model_input = self.add_inpainting_channels_to_latents( + latents=latent_model_input, + masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]), + inpainting_mask=torch.ones_like(latent_model_input[:1, :1]), + ) + + uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( + sample=latent_model_input, + timestep=t, # TODO: debug how handled batched and non batched timesteps + step_index=step_index, + total_step_count=total_step_count, + conditioning_data=conditioning_data, + ip_adapter_data=ip_adapter_data, + down_block_additional_residuals=down_block_additional_residuals, # for ControlNet + mid_block_additional_residual=mid_block_additional_residual, # for ControlNet + down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter + ) + + guidance_scale = conditioning_data.guidance_scale + if isinstance(guidance_scale, list): + guidance_scale = guidance_scale[step_index] + + noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale) + guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier + if guidance_rescale_multiplier > 0: + noise_pred = self._rescale_cfg( + noise_pred, + c_noise_pred, + guidance_rescale_multiplier, + ) + + # compute the previous noisy sample x_t -> x_t-1 + step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs) + + # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting + # again. + if mask_guidance is not None: + # Apply the mask to any "denoised" or "pred_original_sample" fields. + if hasattr(step_output, "denoised"): + step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1]) + elif hasattr(step_output, "pred_original_sample"): + step_output.pred_original_sample = mask_guidance( + step_output.pred_original_sample, self.scheduler.timesteps[-1] + ) + else: + step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1]) + + return step_output