Fix handling of stateful schedulers in MultiDiffusionPipeline.

This commit is contained in:
Ryan Dick
2024-06-18 11:33:29 -04:00
committed by Kent Keirsey
parent c881882f73
commit c22526b9d0
2 changed files with 20 additions and 19 deletions

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
region_batch_schedulers: list[SchedulerMixin] = [
copy.copy(self.scheduler) for _ in multi_diffusion_conditioning
]
callback(
PipelineIntermediateState(
step=-1,
@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
merged_latents = torch.zeros_like(latents)
merged_pred_original: torch.Tensor | None = None
for region_conditioning in multi_diffusion_conditioning:
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_idx]
# Run a denoising step on the region.
step_output = self._region_step(
region_conditioning=region_conditioning,
@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
)
# Normalize the merged results.
latents = merged_latents / region_weight_mask
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = merged_pred_original / region_weight_mask
predicted_original = torch.where(
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
)
callback(
PipelineIntermediateState(