From c22526b9d0af63596dcbd37683dcc9b05a679f0e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 18 Jun 2024 11:33:29 -0400 Subject: [PATCH] Fix handling of stateful schedulers in MultiDiffusionPipeline. --- .../tiled_multi_diffusion_denoise_latents.py | 19 +++--------------- .../multi_diffusion_pipeline.py | 20 ++++++++++++++++--- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 9d998f345f..213fc4567e 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -20,7 +20,6 @@ from invokeai.app.invocations.fields import ( UIType, ) from invokeai.app.invocations.model import UNetField -from invokeai.app.invocations.noise import get_noise from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.lora import LoRAModelRaw @@ -166,21 +165,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents) _, _, latent_height, latent_width = latents.shape - # If noise is None, populate it here. - # TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move - # that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this - # particular case, we want to make sure that the noise is generated globally rather than per-tile so that - # overlapping tile regions use the same noise. - if noise is None: - noise = get_noise( - width=latent_width * LATENT_SCALE_FACTOR, - height=latent_height * LATENT_SCALE_FACTOR, - device=TorchDevice.choose_torch_device(), - seed=seed, - downsampling_factor=LATENT_SCALE_FACTOR, - use_cpu=True, - ) - # Calculate the tile locations to cover the latent-space image. # TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint? tiles = calc_tiles_min_overlap( @@ -204,6 +188,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()): assert isinstance(unet, UNet2DConditionModel) + latents = latents.to(device=unet.device, dtype=unet.dtype) + if noise is not None: + noise = noise.to(device=unet.device, dtype=unet.dtype) scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index 2049a19733..68d4097004 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -1,9 +1,11 @@ from __future__ import annotations +import copy from dataclasses import dataclass from typing import Any, Callable, Optional import torch +from diffusers.schedulers.scheduling_utils import SchedulerMixin from invokeai.backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, @@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): region = region_conditioning.region region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0 + # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since + # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a + # separate scheduler state for each region batch. + region_batch_schedulers: list[SchedulerMixin] = [ + copy.copy(self.scheduler) for _ in multi_diffusion_conditioning + ] + callback( PipelineIntermediateState( step=-1, @@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): merged_latents = torch.zeros_like(latents) merged_pred_original: torch.Tensor | None = None - for region_conditioning in multi_diffusion_conditioning: + for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): + # Switch to the scheduler for the region batch. + self.scheduler = region_batch_schedulers[region_idx] + # Run a denoising step on the region. step_output = self._region_step( region_conditioning=region_conditioning, @@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): ) # Normalize the merged results. - latents = merged_latents / region_weight_mask + latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents) predicted_original = None if merged_pred_original is not None: - predicted_original = merged_pred_original / region_weight_mask + predicted_original = torch.where( + region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original + ) callback( PipelineIntermediateState(