From 6bcf48aa37c2010d602b0e7266cc171ed370ac6f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 18 Jun 2024 14:35:41 -0400 Subject: [PATCH] WIP - Started working towards MultiDiffusion batching. --- .../multi_diffusion_pipeline.py | 107 ++++++++++++++---- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index 2f945cfeca..f74c118091 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -15,6 +15,10 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData from invokeai.backend.tiles.utils import TBLR +# The maximum number of regions with compatible sizes that will be batched together. +# Larger batch sizes improve speed, but require more device memory. +MAX_REGION_BATCH_SIZE = 4 + @dataclass class MultiDiffusionRegionConditioning: @@ -27,6 +31,38 @@ class MultiDiffusionRegionConditioning: class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising.""" + def _split_into_region_batches( + self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] + ) -> list[list[MultiDiffusionRegionConditioning]]: + # Group the regions by shape. Only regions with the same shape can be batched together. + conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {} + for region_conditioning in multi_diffusion_conditioning: + shape_hw = ( + region_conditioning.region.bottom - region_conditioning.region.top, + region_conditioning.region.right - region_conditioning.region.left, + ) + # In python, a tuple of hashable objects is hashable, so can be used as a key in a dict. + if shape_hw not in conditioning_by_shape: + conditioning_by_shape[shape_hw] = [] + conditioning_by_shape[shape_hw].append(region_conditioning) + + # Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint. + region_conditioning_batches = [] + for region_conditioning_batch in conditioning_by_shape.values(): + for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE): + region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE]) + + return region_conditioning_batches + + def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]): + """Check the input conditioning and confirm that regional prompting is not used.""" + for region_conditioning in multi_diffusion_conditioning: + if ( + region_conditioning.text_conditioning_data.cond_regions is not None + or region_conditioning.text_conditioning_data.uncond_regions is not None + ): + raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") + def multi_diffusion_denoise( self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning], @@ -37,6 +73,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): init_timestep: torch.Tensor, callback: Callable[[PipelineIntermediateState], None], ) -> torch.Tensor: + self._check_regional_prompting(multi_diffusion_conditioning) + # TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle # cases where densoisings_start and denoising_end are set such that there are no timesteps. if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0: @@ -57,7 +95,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): self._adjust_memory_efficient_attention(latents) # Populate a weighted mask that will be used to combine the results from each region after every step. - # For now, we assume that each regions has the same weight (1.0). + # For now, we assume that each region has the same weight (1.0). region_weight_mask = torch.zeros( (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype ) @@ -65,11 +103,15 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): region = region_conditioning.region region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0 + # Group the region conditioning into batches for faster processing. + # region_conditioning_batches[b][r] is the r'th region in the b'th batch. + region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning) + # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a # separate scheduler state for each region batch. region_batch_schedulers: list[SchedulerMixin] = [ - copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning + copy.deepcopy(self.scheduler) for _ in region_conditioning_batches ] callback( @@ -87,20 +129,52 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): merged_latents = torch.zeros_like(latents) merged_pred_original: torch.Tensor | None = None - for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): + for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches): # Switch to the scheduler for the region batch. - self.scheduler = region_batch_schedulers[region_idx] + self.scheduler = region_batch_schedulers[region_batch_idx] - # Run a denoising step on the region. - step_output = self._region_step( - region_conditioning=region_conditioning, - t=batched_t, - latents=latents, - step_index=i, - total_step_count=len(timesteps), - scheduler_step_kwargs=scheduler_step_kwargs, + # TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1. + + # Prepare the latents for the region batch. + batch_latents = torch.cat( + [ + latents[ + :, + :, + region_conditioning.region.top : region_conditioning.region.bottom, + region_conditioning.region.left : region_conditioning.region.right, + ] + for region_conditioning in region_conditioning_batch + ], ) + # TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step() + # handle broadcasting properly? + + # TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # Run the denoising step on the region. + step_output = self.step( + t=batched_t, + latents=batch_latents, + conditioning_data=region_conditioning.text_conditioning_data, + step_index=i, + total_step_count=total_step_count, + scheduler_step_kwargs=scheduler_step_kwargs, + mask_guidance=None, + mask=None, + masked_latents=None, + control_data=region_conditioning.control_data, + ) + # Run a denoising step on the region. + # step_output = self._region_step( + # region_conditioning=region_conditioning, + # t=batched_t, + # latents=latents, + # step_index=i, + # total_step_count=len(timesteps), + # scheduler_step_kwargs=scheduler_step_kwargs, + # ) + # Store the results from the region. region = region_conditioning.region merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample @@ -136,7 +210,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): return latents @torch.inference_mode() - def _region_step( + def _region_batch_step( self, region_conditioning: MultiDiffusionRegionConditioning, t: torch.Tensor, @@ -145,13 +219,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): total_step_count: int, scheduler_step_kwargs: dict[str, Any], ): - use_regional_prompting = ( - region_conditioning.text_conditioning_data.cond_regions is not None - or region_conditioning.text_conditioning_data.uncond_regions is not None - ) - if use_regional_prompting: - raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") - # Crop the inputs to the region. region_latents = latents[ :,