mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix handling of stateful schedulers in MultiDiffusionPipeline.
This commit is contained in:
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
region = region_conditioning.region
|
||||
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
||||
|
||||
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
||||
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
||||
# separate scheduler state for each region batch.
|
||||
region_batch_schedulers: list[SchedulerMixin] = [
|
||||
copy.copy(self.scheduler) for _ in multi_diffusion_conditioning
|
||||
]
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
merged_pred_original: torch.Tensor | None = None
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
|
||||
# Switch to the scheduler for the region batch.
|
||||
self.scheduler = region_batch_schedulers[region_idx]
|
||||
|
||||
# Run a denoising step on the region.
|
||||
step_output = self._region_step(
|
||||
region_conditioning=region_conditioning,
|
||||
@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
)
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = merged_latents / region_weight_mask
|
||||
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
|
||||
predicted_original = None
|
||||
if merged_pred_original is not None:
|
||||
predicted_original = merged_pred_original / region_weight_mask
|
||||
predicted_original = torch.where(
|
||||
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
|
||||
)
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
|
Reference in New Issue
Block a user