from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional

import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin

from invokeai.backend.stable_diffusion.diffusers_pipeline import (
    ControlNetData,
    PipelineIntermediateState,
    StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import Tile


@dataclass
class MultiDiffusionRegionConditioning:
    # Region coords in latent space.
    region: Tile
    text_conditioning_data: TextConditioningData
    control_data: list[ControlNetData]


class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
    """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""

    def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
        """Validate that regional conditioning is not used.

        Raises:
            NotImplementedError: If any region's text conditioning has `cond_regions` or `uncond_regions` set.
        """
        for region_conditioning in multi_diffusion_conditioning:
            if (
                region_conditioning.text_conditioning_data.cond_regions is not None
                or region_conditioning.text_conditioning_data.uncond_regions is not None
            ):
                raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")

    @staticmethod
    def _build_region_weight(
        region: Tile, height: int, width: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """Build a (1, 1, height, width) weight map that applies gradient blending to the edges of the region.

        The map starts as all ones. For each side of `region` with a positive overlap, the outermost `overlap`
        rows/columns on that side are multiplied by a linear ramp (0 -> 1 on the left/top, 1 -> 0 on the
        right/bottom). Corners that fall in two ramps receive the product of both.
        """
        weight = torch.ones((1, 1, height, width), dtype=dtype, device=device)
        overlap = region.overlap

        if overlap.left > 0:
            left_grad = torch.linspace(0, 1, overlap.left, device=device, dtype=dtype).view((1, 1, 1, -1))
            weight[:, :, :, : overlap.left] *= left_grad
        if overlap.top > 0:
            top_grad = torch.linspace(0, 1, overlap.top, device=device, dtype=dtype).view((1, 1, -1, 1))
            weight[:, :, : overlap.top, :] *= top_grad
        if overlap.right > 0:
            right_grad = torch.linspace(1, 0, overlap.right, device=device, dtype=dtype).view((1, 1, 1, -1))
            weight[:, :, :, -overlap.right :] *= right_grad
        if overlap.bottom > 0:
            bottom_grad = torch.linspace(1, 0, overlap.bottom, device=device, dtype=dtype).view((1, 1, -1, 1))
            weight[:, :, -overlap.bottom :, :] *= bottom_grad

        return weight

    def multi_diffusion_denoise(
        self,
        multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
        target_overlap: int,
        latents: torch.Tensor,
        scheduler_step_kwargs: dict[str, Any],
        noise: Optional[torch.Tensor],
        timesteps: torch.Tensor,
        init_timestep: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None],
    ) -> torch.Tensor:
        """Denoise `latents` region-by-region and blend the per-region results at every timestep.

        Args:
            multi_diffusion_conditioning: One entry per region, giving the region's latent-space tile and the
                text/ControlNet conditioning to use for it.
            target_overlap: Not used by this method.
            latents: 4-D latents tensor; the last two dims are treated as (height, width).
            scheduler_step_kwargs: Forwarded unchanged to `self.step(...)` for every region.
            noise: If not None, it is added to `latents` via `self.scheduler.add_noise(...)` at `init_timestep`.
                Can be None if the latents have already been noised (e.g. when running the SDXL refiner).
            timesteps: The timesteps to iterate over.
            init_timestep: Timestep used for the initial noising. If its first dim is empty, `latents` is
                returned unchanged.
            callback: Called once before the loop (with `step=-1`) and once after every timestep.

        Returns:
            The latents after the final timestep.

        Raises:
            NotImplementedError: If any region uses regional prompting.

        Note:
            `self.scheduler` is reassigned to a per-region deep copy inside the loop and is not restored, so after
            this method returns it refers to the last region's scheduler copy.
        """
        self._check_regional_prompting(multi_diffusion_conditioning)

        if init_timestep.shape[0] == 0:
            return latents

        batch_size, _, latent_height, latent_width = latents.shape
        batched_init_timestep = init_timestep.expand(batch_size)

        # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
        if noise is not None:
            # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
            # full noise. Investigate the history of why this got commented out.
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
            assert isinstance(latents, torch.Tensor)  # For static type checking.

        # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
        # cropping into regions.
        self._adjust_memory_efficient_attention(latents)

        # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
        # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
        # separate scheduler state for each region batch.
        # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
        # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
        # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
        # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
        # scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
        region_batch_schedulers: list[SchedulerMixin] = [
            copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
        ]

        total_step_count = len(timesteps)

        callback(
            PipelineIntermediateState(
                step=-1,
                order=self.scheduler.order,
                total_steps=total_step_count,
                timestep=self.scheduler.config.num_train_timesteps,
                latents=latents,
            )
        )

        for i, t in enumerate(self.progress_bar(timesteps)):
            batched_t = t.expand(batch_size)

            # Accumulators for the weighted sum of region outputs and the sum of weights at each latent pixel.
            merged_latents = torch.zeros_like(latents)
            merged_latents_weights = torch.zeros(
                (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
            )
            merged_pred_original: torch.Tensor | None = None

            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
                # Switch to the scheduler for the region batch.
                self.scheduler = region_batch_schedulers[region_idx]

                region = region_conditioning.region
                rows = slice(region.coords.top, region.coords.bottom)
                cols = slice(region.coords.left, region.coords.right)

                # Crop the inputs to the region.
                region_latents = latents[:, :, rows, cols]

                # Run the denoising step on the region.
                step_output = self.step(
                    t=batched_t,
                    latents=region_latents,
                    conditioning_data=region_conditioning.text_conditioning_data,
                    step_index=i,
                    total_step_count=total_step_count,
                    scheduler_step_kwargs=scheduler_step_kwargs,
                    mask_guidance=None,
                    mask=None,
                    masked_latents=None,
                    control_data=region_conditioning.control_data,
                )

                # Build a region_weight matrix that applies gradient blending to the edges of the region.
                _, _, region_height, region_width = step_output.prev_sample.shape
                region_weight = self._build_region_weight(
                    region=region,
                    height=region_height,
                    width=region_width,
                    dtype=latents.dtype,
                    device=latents.device,
                )

                # Update the merged results with the region results.
                merged_latents[:, :, rows, cols] += step_output.prev_sample * region_weight
                merged_latents_weights[:, :, rows, cols] += region_weight

                pred_orig_sample = getattr(step_output, "pred_original_sample", None)
                if pred_orig_sample is not None:
                    # If one region has pred_original_sample, then we can assume that all regions will have it, because
                    # they all use the same scheduler.
                    if merged_pred_original is None:
                        merged_pred_original = torch.zeros_like(latents)
                    # The sum is normalized by merged_latents_weights below, so each contribution must be scaled by
                    # the same region_weight that was added to the weights; otherwise overlapping areas are not a
                    # weighted average.
                    merged_pred_original[:, :, rows, cols] += pred_orig_sample * region_weight

            # Normalize the merged results. Pixels with zero total weight are left as their (un-normalized) sum.
            latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
            # For debugging, uncomment this line to visualize the region seams:
            # latents = torch.where(merged_latents_weights > 1, 0.0, latents)

            predicted_original = None
            if merged_pred_original is not None:
                predicted_original = torch.where(
                    merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
                )

            callback(
                PipelineIntermediateState(
                    step=i,
                    order=self.scheduler.order,
                    total_steps=total_step_count,
                    timestep=int(t),
                    latents=latents,
                    predicted_original=predicted_original,
                )
            )

        return latents