From 36473fc52a858eed206b0d327efc18c9fdac8cc2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 17 Jun 2024 14:40:43 -0400 Subject: [PATCH] Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. --- .../multi_diffusion_pipeline.py | 111 ++++++++---------- 1 file changed, 52 insertions(+), 59 deletions(-) diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index bcc1f4f8c3..435140523f 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -from contextlib import nullcontext from typing import Any, Callable, Optional import torch @@ -12,7 +11,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( StableDiffusionGeneratorPipeline, ) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData -from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.tiles.utils import Tile @@ -79,12 +77,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): use_regional_prompting = ( conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None ) - unet_attention_patcher = None - attn_ctx = nullcontext() - if use_regional_prompting: - unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None) - attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) + raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") # Populate a weighted mask that will be used to combine the results from each region after every step. # For now, we assume that each regions has the same weight (1.0). @@ -96,66 +90,65 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right ] += 1.0 - with attn_ctx: + callback( + PipelineIntermediateState( + step=-1, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=self.scheduler.config.num_train_timesteps, + latents=latents, + ) + ) + + for i, t in enumerate(self.progress_bar(timesteps)): + batched_t = t.expand(batch_size) + + prev_samples_by_region: list[torch.Tensor] = [] + pred_original_by_region: list[torch.Tensor | None] = [] + for region in regions: + # Run a denoising step on the region. + step_output = self._region_step( + region=region, + t=batched_t, + latents=latents, + conditioning_data=conditioning_data, + step_index=i, + total_step_count=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, + control_data=control_data, + ) + prev_samples_by_region.append(step_output.prev_sample) + pred_original_by_region.append(getattr(step_output, "pred_original_sample", None)) + + # Merge the prev_sample results from each region. + merged_latents = torch.zeros_like(latents) + for region_idx, region in enumerate(regions): + merged_latents[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += prev_samples_by_region[region_idx] + latents = merged_latents / region_weight_mask + + # Merge the predicted_original results from each region. + predicted_original = None + if all(pred_original_by_region): + merged_pred_original = torch.zeros_like(latents) + for region_idx, region in enumerate(regions): + merged_pred_original[ + :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + ] += pred_original_by_region[region_idx] + predicted_original = merged_pred_original / region_weight_mask + callback( PipelineIntermediateState( - step=-1, + step=i, order=self.scheduler.order, total_steps=len(timesteps), - timestep=self.scheduler.config.num_train_timesteps, + timestep=int(t), latents=latents, + predicted_original=predicted_original, ) ) - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t = t.expand(batch_size) - - prev_samples_by_region: list[torch.Tensor] = [] - pred_original_by_region: list[torch.Tensor | None] = [] - for region in regions: - # Run a denoising step on the region. - step_output = self._region_step( - region=region, - t=batched_t, - latents=latents, - conditioning_data=conditioning_data, - step_index=i, - total_step_count=len(timesteps), - scheduler_step_kwargs=scheduler_step_kwargs, - control_data=control_data, - ) - prev_samples_by_region.append(step_output.prev_sample) - pred_original_by_region.append(getattr(step_output, "pred_original_sample", None)) - - # Merge the prev_sample results from each region. - merged_latents = torch.zeros_like(latents) - for region_idx, region in enumerate(regions): - merged_latents[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += prev_samples_by_region[region_idx] - latents = merged_latents / region_weight_mask - - # Merge the predicted_original results from each region. - predicted_original = None - if all(pred_original_by_region): - merged_pred_original = torch.zeros_like(latents) - for region_idx, region in enumerate(regions): - merged_pred_original[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += pred_original_by_region[region_idx] - predicted_original = merged_pred_original / region_weight_mask - - callback( - PipelineIntermediateState( - step=i, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=int(t), - latents=latents, - predicted_original=predicted_original, - ) - ) - return latents @torch.inference_mode()