Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported.

This commit is contained in:
Ryan Dick 2024-06-17 14:40:43 -04:00
parent 49ca42f84a
commit 865c2335de

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import copy
from contextlib import nullcontext
from typing import Any, Callable, Optional
import torch
@ -12,7 +11,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.tiles.utils import Tile
@ -79,12 +77,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
attn_ctx = nullcontext()
if use_regional_prompting:
unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
# Populate a weighted mask that will be used to combine the results from each region after every step.
# For now, we assume that each region has the same weight (1.0).
@ -96,66 +90,65 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += 1.0
with attn_ctx:
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
prev_samples_by_region: list[torch.Tensor] = []
pred_original_by_region: list[torch.Tensor | None] = []
for region in regions:
# Run a denoising step on the region.
step_output = self._region_step(
region=region,
t=batched_t,
latents=latents,
conditioning_data=conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
control_data=control_data,
)
prev_samples_by_region.append(step_output.prev_sample)
pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
# Merge the prev_sample results from each region.
merged_latents = torch.zeros_like(latents)
for region_idx, region in enumerate(regions):
merged_latents[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += prev_samples_by_region[region_idx]
latents = merged_latents / region_weight_mask
# Merge the predicted_original results from each region.
predicted_original = None
if all(pred_original_by_region):
merged_pred_original = torch.zeros_like(latents)
for region_idx, region in enumerate(regions):
merged_pred_original[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += pred_original_by_region[region_idx]
predicted_original = merged_pred_original / region_weight_mask
callback(
PipelineIntermediateState(
step=-1,
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
prev_samples_by_region: list[torch.Tensor] = []
pred_original_by_region: list[torch.Tensor | None] = []
for region in regions:
# Run a denoising step on the region.
step_output = self._region_step(
region=region,
t=batched_t,
latents=latents,
conditioning_data=conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
control_data=control_data,
)
prev_samples_by_region.append(step_output.prev_sample)
pred_original_by_region.append(getattr(step_output, "pred_original_sample", None))
# Merge the prev_sample results from each region.
merged_latents = torch.zeros_like(latents)
for region_idx, region in enumerate(regions):
merged_latents[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += prev_samples_by_region[region_idx]
latents = merged_latents / region_weight_mask
# Merge the predicted_original results from each region.
predicted_original = None
if all(pred_original_by_region):
merged_pred_original = torch.zeros_like(latents)
for region_idx, region in enumerate(regions):
merged_pred_original[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += pred_original_by_region[region_idx]
predicted_original = merged_pred_original / region_weight_mask
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
@torch.inference_mode()