Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Fix handling of stateful schedulers in MultiDiffusionPipeline.
commit c22526b9d0 (parent c881882f73)
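A minimal sketch (not InvokeAI code; EulerDiscreteScheduler and the step_index attribute stand in for whichever stateful diffusers scheduler is configured) of the failure mode the title describes: many diffusers schedulers advance internal state on every step() call, so calling step() once per region batch at the same timestep over-advances that state.

    # Sketch only: demonstrates the stateful-scheduler bug class this commit fixes.
    import torch
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=4)

    sample = torch.zeros(1, 4, 8, 8)
    noise_pred = torch.zeros(1, 4, 8, 8)
    t = scheduler.timesteps[0]

    scheduler.scale_model_input(sample, t)  # initializes the internal step index
    scheduler.step(noise_pred, t, sample)   # region batch 1 at timestep t
    scheduler.step(noise_pred, t, sample)   # region batch 2 at the SAME timestep
    # In recent diffusers releases, step_index now reads 2: the second region
    # batch was denoised with the sigma for the *next* timestep.
    print(scheduler.step_index)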
@@ -20,7 +20,6 @@ from invokeai.app.invocations.fields import (
     UIType,
 )
 from invokeai.app.invocations.model import UNetField
-from invokeai.app.invocations.noise import get_noise
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.lora import LoRAModelRaw
@@ -166,21 +165,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
         _, _, latent_height, latent_width = latents.shape

-        # If noise is None, populate it here.
-        # TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move
-        # that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this
-        # particular case, we want to make sure that the noise is generated globally rather than per-tile so that
-        # overlapping tile regions use the same noise.
-        if noise is None:
-            noise = get_noise(
-                width=latent_width * LATENT_SCALE_FACTOR,
-                height=latent_height * LATENT_SCALE_FACTOR,
-                device=TorchDevice.choose_torch_device(),
-                seed=seed,
-                downsampling_factor=LATENT_SCALE_FACTOR,
-                use_cpu=True,
-            )
-
         # Calculate the tile locations to cover the latent-space image.
         # TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
         tiles = calc_tiles_min_overlap(
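The removed TODO records a constraint worth keeping in mind wherever this logic lands: noise must be sampled once for the whole latent image and sliced per tile, so that overlapping tile regions see identical values. A standalone illustration in plain torch (shapes illustrative):

    # Sample noise globally, then slice per tile; overlaps are bit-identical.
    import torch

    latents = torch.zeros(1, 4, 96, 96)       # full latent-space image
    global_noise = torch.randn_like(latents)  # sampled once, globally

    # Two vertically overlapping tiles; rows 32..64 are covered by both.
    tile_a = global_noise[:, :, 0:64, :]
    tile_b = global_noise[:, :, 32:96, :]
    assert torch.equal(tile_a[:, :, 32:64, :], tile_b[:, :, 0:32, :])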
@@ -204,6 +188,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):

         with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
             assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
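A side note on why the two added .to() calls are safe to run unconditionally: torch returns the tensor itself when device and dtype already match, so the conversion is a no-op in that case. A quick check:

    # Tensor.to returns self when no device/dtype change is needed.
    import torch

    x = torch.zeros(2, 2, dtype=torch.float16)
    y = x.to(device=x.device, dtype=torch.float16)
    assert y is x  # no copy when device and dtype already match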
@@ -1,9 +1,11 @@
 from __future__ import annotations

+import copy
 from dataclasses import dataclass
 from typing import Any, Callable, Optional

 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin

 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
@@ -89,6 +91,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             region = region_conditioning.region
             region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0

+        # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
+        # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
+        # separate scheduler state for each region batch.
+        region_batch_schedulers: list[SchedulerMixin] = [
+            copy.copy(self.scheduler) for _ in multi_diffusion_conditioning
+        ]
+
         callback(
             PipelineIntermediateState(
                 step=-1,
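A hedged sketch of what the shallow copies buy, assuming (per the new comment above) that step() only rebinds small per-instance state such as the step counter, so copy.copy isolates each region batch while the large read-only tensors stay shared:

    # EulerDiscreteScheduler stands in for whichever scheduler is configured.
    import copy

    from diffusers import EulerDiscreteScheduler

    base = EulerDiscreteScheduler()
    base.set_timesteps(num_inference_steps=10)

    # One shallow copy per region batch; sigmas/timesteps tensors are shared
    # read-only, while each copy rebinds its own step counter on step().
    num_region_batches = 3
    region_batch_schedulers = [copy.copy(base) for _ in range(num_region_batches)]
    assert all(s is not base for s in region_batch_schedulers)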
@@ -104,7 +113,10 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):

             merged_latents = torch.zeros_like(latents)
             merged_pred_original: torch.Tensor | None = None
-            for region_conditioning in multi_diffusion_conditioning:
+            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
+                # Switch to the scheduler for the region batch.
+                self.scheduler = region_batch_schedulers[region_idx]
+
                 # Run a denoising step on the region.
                 step_output = self._region_step(
                     region_conditioning=region_conditioning,
@@ -129,10 +141,12 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             )

             # Normalize the merged results.
-            latents = merged_latents / region_weight_mask
+            latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
             predicted_original = None
             if merged_pred_original is not None:
-                predicted_original = merged_pred_original / region_weight_mask
+                predicted_original = torch.where(
+                    region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
+                )

             callback(
                 PipelineIntermediateState(
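The switch from plain division to torch.where guards latent positions (if any) with zero tile coverage, where dividing by the weight mask would produce NaN. A toy illustration:

    # Where the weight mask is zero, division yields NaN, so the merged value
    # is passed through unnormalized instead.
    import torch

    merged = torch.tensor([[2.0, 0.0], [4.0, 0.0]])
    weights = torch.tensor([[2.0, 0.0], [2.0, 0.0]])

    naive = merged / weights
    safe = torch.where(weights > 0, merged / weights, merged)
    print(naive)  # tensor([[1., nan], [2., nan]])
    print(safe)   # tensor([[1., 0.], [2., 0.]])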