Fix handling of stateful schedulers in MultiDiffusionPipeline.

This commit is contained in:
Ryan Dick 2024-06-18 11:33:29 -04:00
parent 35adaf1c17
commit 72be7e71e3
2 changed files with 20 additions and 19 deletions

View File

@ -20,7 +20,6 @@ from invokeai.app.invocations.fields import (
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
@ -166,21 +165,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# If noise is None, populate it here.
# TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move
# that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this
# particular case, we want to make sure that the noise is generated globally rather than per-tile so that
# overlapping tile regions use the same noise.
if noise is None:
noise = get_noise(
width=latent_width * LATENT_SCALE_FACTOR,
height=latent_height * LATENT_SCALE_FACTOR,
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
@ -204,6 +188,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
region_batch_schedulers: list[SchedulerMixin] = [
copy.copy(self.scheduler) for _ in multi_diffusion_conditioning
]
callback(
PipelineIntermediateState(
step=-1,
@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
merged_latents = torch.zeros_like(latents)
merged_pred_original: torch.Tensor | None = None
for region_conditioning in multi_diffusion_conditioning:
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_idx]
# Run a denoising step on the region.
step_output = self._region_step(
region_conditioning=region_conditioning,
@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
)
# Normalize the merged results.
latents = merged_latents / region_weight_mask
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = merged_pred_original / region_weight_mask
predicted_original = torch.where(
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
)
callback(
PipelineIntermediateState(