mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix handling of stateful schedulers in MultiDiffusionPipeline.
This commit is contained in:
parent
c881882f73
commit
c22526b9d0
@ -20,7 +20,6 @@ from invokeai.app.invocations.fields import (
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.noise import get_noise
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
@ -166,21 +165,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# If noise is None, populate it here.
|
||||
# TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move
|
||||
# that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this
|
||||
# particular case, we want to make sure that the noise is generated globally rather than per-tile so that
|
||||
# overlapping tile regions use the same noise.
|
||||
if noise is None:
|
||||
noise = get_noise(
|
||||
width=latent_width * LATENT_SCALE_FACTOR,
|
||||
height=latent_height * LATENT_SCALE_FACTOR,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
seed=seed,
|
||||
downsampling_factor=LATENT_SCALE_FACTOR,
|
||||
use_cpu=True,
|
||||
)
|
||||
|
||||
# Calculate the tile locations to cover the latent-space image.
|
||||
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
|
||||
tiles = calc_tiles_min_overlap(
|
||||
@ -204,6 +188,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
region = region_conditioning.region
|
||||
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
||||
|
||||
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
||||
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
||||
# separate scheduler state for each region batch.
|
||||
region_batch_schedulers: list[SchedulerMixin] = [
|
||||
copy.copy(self.scheduler) for _ in multi_diffusion_conditioning
|
||||
]
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
merged_pred_original: torch.Tensor | None = None
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
|
||||
# Switch to the scheduler for the region batch.
|
||||
self.scheduler = region_batch_schedulers[region_idx]
|
||||
|
||||
# Run a denoising step on the region.
|
||||
step_output = self._region_step(
|
||||
region_conditioning=region_conditioning,
|
||||
@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
)
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = merged_latents / region_weight_mask
|
||||
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
|
||||
predicted_original = None
|
||||
if merged_pred_original is not None:
|
||||
predicted_original = merged_pred_original / region_weight_mask
|
||||
predicted_original = torch.where(
|
||||
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
|
||||
)
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
|
Loading…
Reference in New Issue
Block a user