diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
index 2a7b82b694..bcc1f4f8c3 100644
--- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
+++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 from contextlib import nullcontext
 from typing import Any, Callable, Optional
 
@@ -61,7 +62,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
             return latents
 
-        batch_size = latents.shape[0]
+        batch_size, _, latent_height, latent_width = latents.shape
         batched_init_timestep = init_timestep.expand(batch_size)
 
         # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
@@ -85,6 +86,16 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None)
         attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
 
+        # Populate a weighted mask that will be used to combine the results from each region after every step.
+        # For now, we assume that each region has the same weight (1.0).
+        region_weight_mask = torch.zeros(
+            (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
+        )
+        for region in regions:
+            region_weight_mask[
+                :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+            ] += 1.0
+
         with attn_ctx:
             callback(
                 PipelineIntermediateState(
@@ -98,20 +109,41 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
 
             for i, t in enumerate(self.progress_bar(timesteps)):
                 batched_t = t.expand(batch_size)
-                step_output = self.step(
-                    t=batched_t,
-                    latents=latents,
-                    conditioning_data=conditioning_data,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    mask_guidance=None,
-                    mask=None,
-                    masked_latents=None,
-                    control_data=control_data,
-                )
-                latents = step_output.prev_sample
-                predicted_original = getattr(step_output, "pred_original_sample", None)
+
+                prev_samples_by_region: list[torch.Tensor] = []
+                pred_original_by_region: list[torch.Tensor | None] = []
+                for region in regions:
+                    # Run a denoising step on the region.
+                    step_output = self._region_step(
+                        region=region,
+                        t=batched_t,
+                        latents=latents,
+                        conditioning_data=conditioning_data,
+                        step_index=i,
+                        total_step_count=len(timesteps),
+                        scheduler_step_kwargs=scheduler_step_kwargs,
+                        control_data=control_data,
+                    )
+                    prev_samples_by_region.append(step_output.prev_sample)
+                    pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
+
+                # Merge the prev_sample results from each region.
+                merged_latents = torch.zeros_like(latents)
+                for region_idx, region in enumerate(regions):
+                    merged_latents[
+                        :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                    ] += prev_samples_by_region[region_idx]
+                latents = merged_latents / region_weight_mask
+
+                # Merge the predicted_original results from each region.
+                predicted_original = None
+                if all(pred_original_by_region):
+                    merged_pred_original = torch.zeros_like(latents)
+                    for region_idx, region in enumerate(regions):
+                        merged_pred_original[
+                            :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                        ] += pred_original_by_region[region_idx]
+                    predicted_original = merged_pred_original / region_weight_mask
 
                 callback(
                     PipelineIntermediateState(
@@ -125,3 +157,48 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                 )
 
         return latents
+
+    @torch.inference_mode()
+    def _region_step(
+        self,
+        region: Tile,
+        t: torch.Tensor,
+        latents: torch.Tensor,
+        conditioning_data: TextConditioningData,
+        step_index: int,
+        total_step_count: int,
+        scheduler_step_kwargs: dict[str, Any],
+        control_data: list[ControlNetData] | None = None,
+    ):
+        # Crop the inputs to the region.
+        region_latents = latents[
+            :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+        ]
+
+        region_control_data: list[ControlNetData] | None = None
+        if control_data is not None:
+            region_control_data = [self._crop_controlnet_data(c, region) for c in control_data]
+
+        # Run the denoising step on the region.
+        return self.step(
+            t=t,
+            latents=region_latents,
+            conditioning_data=conditioning_data,
+            step_index=step_index,
+            total_step_count=total_step_count,
+            scheduler_step_kwargs=scheduler_step_kwargs,
+            mask_guidance=None,
+            mask=None,
+            masked_latents=None,
+            control_data=region_control_data,
+        )
+
+    def _crop_controlnet_data(self, control_data: ControlNetData, region: Tile) -> ControlNetData:
+        """Crop a ControlNetData object to a region."""
+        # Create a shallow copy of the control_data object.
+        control_data_copy = copy.copy(control_data)
+        # The ControlNet reference image is the only attribute that needs to be cropped.
+        control_data_copy.image_tensor = control_data.image_tensor[
+            :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+        ]
+        return control_data_copy
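
For reviewers unfamiliar with the MultiDiffusion merge, below is a minimal standalone sketch (not part of the diff) of the weighted-merge scheme implemented above: each region's result is accumulated into a full-size buffer, and overlapping pixels are averaged by dividing by the per-pixel region count. The (top, bottom, left, right) tuples are hypothetical stand-ins for the pipeline's Tile objects, and the per-region "denoising step" is a placeholder.

import torch

latents = torch.randn(1, 4, 64, 64)
# Two vertically overlapping regions, given as hypothetical (top, bottom, left, right) tuples.
regions = [(0, 40, 0, 64), (24, 64, 0, 64)]

# Per-pixel count of how many regions cover each latent position (all regions weighted 1.0).
region_weight_mask = torch.zeros(1, 1, 64, 64)
for top, bottom, left, right in regions:
    region_weight_mask[:, :, top:bottom, left:right] += 1.0

merged = torch.zeros_like(latents)
for top, bottom, left, right in regions:
    # Placeholder for the per-region denoising step; the real code calls self._region_step(...).
    region_result = latents[:, :, top:bottom, left:right]
    merged[:, :, top:bottom, left:right] += region_result

# Pixels covered by a single region pass through unchanged; overlap pixels are averaged.
latents = merged / region_weight_mask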