Fix handling of stateful schedulers in MultiDiffusionPipeline.

Author: Ryan Dick, 2024-06-18 11:33:29 -04:00 (committed by Kent Keirsey)
Parent: c881882f73
Commit: c22526b9d0
2 changed files with 20 additions and 19 deletions
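The hunks that follow touch two places: the TiledMultiDiffusionDenoiseLatents invocation (noise handling and moving tensors onto the UNet's device), and MultiDiffusionPipeline itself, where the scheduler fix lives. Many diffusers schedulers are stateful: multistep schedulers such as DPMSolverMultistepScheduler cache recent model outputs and a step index inside step(), so reusing a single scheduler for several region batches at the same timestep lets later calls see state already advanced by earlier ones. The sketch below is illustrative only (not InvokeAI code; shapes and step count are arbitrary) and shows the per-region-batch copy pattern the commit adopts.

# Illustrative sketch only -- not the InvokeAI implementation.
import copy

import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=20)
t = scheduler.timesteps[0]

sample = torch.randn(1, 4, 64, 64)       # stand-in latents
noise_pred_a = torch.randn_like(sample)  # stand-in UNet output for region batch A
noise_pred_b = torch.randn_like(sample)  # stand-in UNet output for region batch B

# Buggy pattern: one scheduler for both region batches. The second step() call at the
# same timestep sees internal state (cached model outputs, step index) advanced by the first.
# a = scheduler.step(noise_pred_a, t, sample).prev_sample
# b = scheduler.step(noise_pred_b, t, sample).prev_sample

# Pattern used by the fix: give each region batch its own isolated scheduler state.
region_batch_schedulers = [copy.deepcopy(scheduler) for _ in range(2)]
a = region_batch_schedulers[0].step(noise_pred_a, t, sample).prev_sample
b = region_batch_schedulers[1].step(noise_pred_b, t, sample).prev_sample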

@@ -20,7 +20,6 @@ from invokeai.app.invocations.fields import (
     UIType,
 )
 from invokeai.app.invocations.model import UNetField
-from invokeai.app.invocations.noise import get_noise
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.lora import LoRAModelRaw
@@ -166,21 +165,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
         _, _, latent_height, latent_width = latents.shape

-        # If noise is None, populate it here.
-        # TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move
-        # that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this
-        # particular case, we want to make sure that the noise is generated globally rather than per-tile so that
-        # overlapping tile regions use the same noise.
-        if noise is None:
-            noise = get_noise(
-                width=latent_width * LATENT_SCALE_FACTOR,
-                height=latent_height * LATENT_SCALE_FACTOR,
-                device=TorchDevice.choose_torch_device(),
-                seed=seed,
-                downsampling_factor=LATENT_SCALE_FACTOR,
-                use_cpu=True,
-            )
-
         # Calculate the tile locations to cover the latent-space image.
         # TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
         tiles = calc_tiles_min_overlap(
@@ -204,6 +188,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):

         with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
             assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,

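In the invocation above, the added lines cast both the latents and the optional noise tensor to the loaded UNet's device and dtype before the scheduler is built, so the tile-by-tile denoising does not mix, say, CPU/float32 tensors with a CUDA/float16 UNet. A rough sketch of that idea follows; the helper name move_inputs_to_unet and its signature are hypothetical, not part of InvokeAI.

# Hypothetical helper illustrating the device/dtype handling added above.
import torch


def move_inputs_to_unet(
    latents: torch.Tensor,
    noise: torch.Tensor | None,
    unet: torch.nn.Module,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Match the UNet's parameters so scheduler math and UNet calls agree on device/dtype.
    param = next(unet.parameters())
    latents = latents.to(device=param.device, dtype=param.dtype)
    if noise is not None:
        noise = noise.to(device=param.device, dtype=param.dtype)
    return latents, noise

The remaining hunks are in MultiDiffusionPipeline itself, which owns the per-region-batch scheduler state and the weight-mask normalization.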
@@ -1,9 +1,11 @@
 from __future__ import annotations

+import copy
 from dataclasses import dataclass
 from typing import Any, Callable, Optional

 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin

 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
@@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             region = region_conditioning.region
             region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0

+        # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
+        # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
+        # separate scheduler state for each region batch.
+        region_batch_schedulers: list[SchedulerMixin] = [
+            copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
+        ]
+
         callback(
             PipelineIntermediateState(
                 step=-1,
@@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):

             merged_latents = torch.zeros_like(latents)
             merged_pred_original: torch.Tensor | None = None
-            for region_conditioning in multi_diffusion_conditioning:
+            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
+                # Switch to the scheduler for the region batch.
+                self.scheduler = region_batch_schedulers[region_idx]
+
                 # Run a denoising step on the region.
                 step_output = self._region_step(
                     region_conditioning=region_conditioning,
@@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                 )

             # Normalize the merged results.
-            latents = merged_latents / region_weight_mask
+            latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
             predicted_original = None
             if merged_pred_original is not None:
-                predicted_original = merged_pred_original / region_weight_mask
+                predicted_original = torch.where(
+                    region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
+                )

             callback(
                 PipelineIntermediateState(
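The final hunk swaps plain division by the region weight mask for a masked division with torch.where, so any pixel that no region covered keeps its unnormalized value instead of becoming NaN from 0 / 0. A small self-contained demonstration, with toy shapes rather than real latents:

# Toy demonstration of the masked normalization -- shapes are illustrative only.
import torch

merged_latents = torch.zeros(1, 4, 8, 8)
region_weight_mask = torch.zeros(1, 1, 8, 8)

# Two overlapping "regions" write into the tensor and bump the weight mask.
merged_latents[:, :, 0:6, 0:6] += 1.0
region_weight_mask[:, :, 0:6, 0:6] += 1.0
merged_latents[:, :, 2:8, 2:8] += 3.0
region_weight_mask[:, :, 2:8, 2:8] += 1.0

# Old behaviour: 0 / 0 = NaN wherever no region covered a pixel.
naive = merged_latents / region_weight_mask
assert torch.isnan(naive).any()

# New behaviour: normalize only covered pixels; uncovered pixels pass through unchanged.
safe = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
assert not torch.isnan(safe).any()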