Compare commits

...

51 Commits

Author SHA1 Message Date
6bcf48aa37 WIP - Started working towards MultiDiffusion batching. 2024-06-18 15:44:39 -04:00
b1bb1511fe Delete rough notes. 2024-06-18 15:36:36 -04:00
99046a8145 Fix advanced scheduler behaviour in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
72be7e71e3 Fix handling of stateful schedulers in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
35adaf1c17 Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. 2024-06-18 15:36:34 -04:00
865c2335de Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. 2024-06-18 15:35:52 -04:00
49ca42f84a Initial (untested) implementation of MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
493fcd8660 Remove inpainting support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
20322d781e Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
889d13e02a Document plan for the rest of the MultiDiffusion implementation. 2024-06-18 15:35:52 -04:00
6ccd2a867b Add detailed docstring to latents_from_embeddings(). 2024-06-18 15:35:52 -04:00
5861fa1719 Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
dfd4beb62b Simplify handling of inpainting models. Improve the in-code documentation around inpainting. 2024-06-18 15:35:52 -04:00
83df0c0df5 Minor tidying of latents_from_embeddings(...). 2024-06-18 15:35:52 -04:00
c58c4069a7 Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function. 2024-06-18 15:35:52 -04:00
3937fffa94 Fix invocation name of tiled_multi_diffusion_denoise_latents. 2024-06-18 15:35:52 -04:00
bbf5f67691 Improve clarity of comments regarded when 'noise' and 'latents' are expected to be set. 2024-06-18 15:35:52 -04:00
2f5c147b84 Fix static check errors on imports in diffusers_pipeline.py. 2024-06-18 15:35:52 -04:00
bd2839b748 Remove a condition for handling inpainting models that never resolves to True. The same logic is already applied earlier by AddsMaskLatents. 2024-06-18 15:35:52 -04:00
4f70dd7ce1 Add clarifying comment to explain why noise might be None in latents_from_embedding(). 2024-06-18 15:35:52 -04:00
066672fbfd Remove unused are_like_tensors() function. 2024-06-18 15:35:52 -04:00
abefaee4d1 Remove unused StableDiffusionGeneratorPipeline.use_ip_adapter member. 2024-06-18 15:35:52 -04:00
3254ba5904 Remove unused StableDiffusionGeneratorPipeline.control_model. 2024-06-18 15:35:52 -04:00
73a8c55852 Stricter typing for the is_gradient_mask: bool. 2024-06-18 15:35:52 -04:00
f82af7c22d Fix typing of control_data to reflect that it can be None. 2024-06-18 15:35:52 -04:00
3aef717ef4 Fix typing of timesteps and init_timestep. 2024-06-18 15:35:52 -04:00
c2cf1137e9 Fix typing to reflect that the callback arg to latents_from_embeddings is never None. 2024-06-18 15:35:52 -04:00
803a24bc0a Move seed above optional params. 2024-06-18 15:35:52 -04:00
7d24ad8ccd Simplify handling of AddsMaskGuidance, and fix some related type errors. 2024-06-18 15:35:52 -04:00
cb389063b2 Remove unused num_inference_steps. 2024-06-18 15:35:52 -04:00
81b8a69e1a WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic. 2024-06-18 15:35:50 -04:00
7ee5db87ad Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. 2024-06-18 15:34:30 -04:00
66cf2c59bd Make DenoiseLatentsInvocation.prep_control_data(...) a staticmethod so that it can be called externally. 2024-06-18 15:34:30 -04:00
3bad1367e9 Copy TiledStableDiffusionRefineInvocation as a starting point for TiledMultiDiffusionDenoiseLatents.py 2024-06-18 15:34:22 -04:00
867a7642a6 Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps. 2024-06-18 15:31:58 -04:00
d9d1c8f9cb Expose a few more params from TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
e03eb7fb45 Add support for LoRA models in TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
85db33bc7e Add naive ControlNet support to TiledStableDiffusionRefineInvocation 2024-06-18 15:31:58 -04:00
93e3a2b504 Fix ControlNetModel type hint import source. 2024-06-18 15:31:58 -04:00
6a7a26f1bf Rough prototype of TiledStableDiffusionRefineInvocation is working. 2024-06-18 15:31:58 -04:00
08ca03ef9f WIP - TiledStableDiffusionRefine 2024-06-18 15:31:54 -04:00
ccf90b6bd6 Minor improvements to LatentsToImageInvocation type hints. 2024-06-18 15:31:21 -04:00
753239b48d Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation. 2024-06-18 15:31:21 -04:00
65fa4664c9 Fix return type of prepare_noise_and_latents(...). 2024-06-18 15:31:21 -04:00
297570ded3 Make init_scheduler() a staticmethod on DenoiseLatentsInvocation so that it can be called externally. 2024-06-18 15:31:21 -04:00
680fdcf293 Only allow a single positive/negative prompt conditioning input for tiled refine. 2024-06-18 15:31:21 -04:00
5ff91f2c44 WIP on TiledStableDiffusionRefine 2024-06-18 15:31:14 -04:00
69aa7057e7 Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally. 2024-06-18 15:25:08 -04:00
d3932f40de Simplify the logic in prepare_noise_and_latents(...). 2024-06-18 15:25:08 -04:00
ee74cd7fab Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. 2024-06-18 15:25:08 -04:00
bda25b40c9 (minor) Add a TODO note to get_scheduler(...). 2024-06-18 15:25:08 -04:00
7 changed files with 1202 additions and 256 deletions

View File

@ -55,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -65,6 +66,9 @@ def get_scheduler(
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
@ -182,8 +186,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
@ -203,8 +207,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
@ -228,8 +233,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
resized_mask = tf(mask)
return resized_mask
@staticmethod
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
@ -279,7 +284,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@ -301,36 +308,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
return BasicConditioningInfo(embeds=text_embedding), regions
@staticmethod
def get_conditioning_data(
self,
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
@ -338,23 +350,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data
@staticmethod
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
@ -377,38 +387,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
requires_safety_checker=False,
)
@staticmethod
def prep_control_data(
self,
context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]],
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
if len(control_list) == 0:
return None
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
@ -429,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
@ -583,15 +593,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]:
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
@ -617,7 +627,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
@ -639,7 +648,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
return timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
@ -656,31 +665,52 @@ class DenoiseLatentsInvocation(BaseInvocation):
return 1 - mask, masked_latents, self.denoise_mask.gradient
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed = None
@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.
Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if self.noise is not None:
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)
if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")
raise ValueError("'latents' or 'noise' must be provided!")
if seed is None:
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0
return seed, noise, latents
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
@ -754,7 +784,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = self.prep_control_data(
@ -776,7 +814,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
@ -793,8 +831,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,

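The net effect of the changes above is that `prepare_noise_and_latents(...)`, `get_conditioning_data(...)`, `prep_control_data(...)`, and `init_scheduler(...)` become staticmethods with explicit parameters, so other invocations can reuse them. A minimal sketch of that call pattern, assuming the signatures shown in the hunks above (`context`, `unet`, `scheduler`, and the `self.*` fields belong to a hypothetical calling invocation):

# Illustrative sketch only (not part of the diff): reusing the refactored staticmethods
# from another invocation.
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape

conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
    context=context,
    positive_conditioning_field=self.positive_conditioning,
    negative_conditioning_field=self.negative_conditioning,
    unet=unet,
    latent_height=latent_height,
    latent_width=latent_width,
    cfg_scale=self.cfg_scale,
    steps=self.steps,
    cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)

timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
    scheduler,
    device=unet.device,
    steps=self.steps,
    denoising_start=self.denoising_start,
    denoising_end=self.denoising_end,
    seed=seed,
)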
View File

@ -8,7 +8,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
@ -23,6 +23,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice
@ -48,16 +49,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
if use_fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
@ -82,7 +87,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
if use_tiling or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
@ -102,6 +107,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

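With `vae_decode(...)` exposed as a staticmethod above, other invocations can decode latents without going through `LatentsToImageInvocation.invoke()`. A minimal sketch, assuming the signature shown in this file (`vae_field` and `latents` are placeholders):

# Illustrative sketch only (not part of the diff).
vae_info = context.models.load(vae_field.vae)
image = LatentsToImageInvocation.vae_decode(
    context=context,
    vae_info=vae_info,
    seamless_axes=vae_field.seamless_axes,
    latents=latents,
    use_fp32=False,
    use_tiling=False,
)
image_dto = context.images.save(image=image)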
View File

@ -0,0 +1,268 @@
import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
# TODO(ryand): Reset to 1.0.0 right before release.
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image. Future iterations of
this node should allow the user to specify custom regions with different parameters for each region to harness the
full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# TODO(ryand): Add multiple-of validation.
# TODO(ryand): Smaller defaults might make more sense.
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
tile_min_overlap: int = InputField(
default=16,
gt=0,
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
"this to evenly cover the entire image.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# TODO(ryand): The default here should probably be 0.0.
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_min_overlap,
)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=self.tile_height,
latent_width=self.tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile.coords,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
# TODO(ryand): Add proper callback.
callback=lambda x: None,
)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

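`crop_controlnet_data(...)` above works entirely in latent-space coordinates and scales them by `LATENT_SCALE_FACTOR` to index the image-space ControlNet reference tensor. A short worked example, assuming `LATENT_SCALE_FACTOR` is 8 and that `TBLR` accepts keyword construction:

# Illustrative sketch only (not part of the diff).
# A latent-space tile spanning rows 0..112 and columns 16..128 maps to image-space
# pixels 0..896 and 128..1024 (latent coords * 8).
tile_region = TBLR(top=0, bottom=112, left=16, right=128)
cropped = crop_controlnet_data(control_data, tile_region)
# cropped.image_tensor is control_data.image_tensor[:, :, 0:896, 128:1024]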
View File

@ -0,0 +1,380 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
ImageField,
Input,
InputField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation(
"tiled_stable_diffusion_refine",
title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"],
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_overlap: int = InputField(
default=16,
gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we awkwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_with_overlap(
image_height=input_image.height,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
refined_latent_tiles: list[torch.Tensor] = []
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Load the ControlNet model.
# TODO(ryand): Support multiple ControlNet models.
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
assert isinstance(controlnet_model, ControlNetModel)
# Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
# Prepare a PIL Image for ControlNet processing.
# TODO(ryand): It is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile.
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=[controlnet_data],
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)

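Two details above are worth calling out: image-space tile coordinates are converted to latent space by dividing by `LATENT_SCALE_FACTOR` (hence the multiple-of constraint), and noise is generated once for the full image and then cropped per tile so that overlapping regions of adjacent tiles are denoised against identical noise. A short worked example, assuming `LATENT_SCALE_FACTOR` is 8:

# Illustrative sketch only (not part of the diff).
# Image-space tile: top=0, bottom=512, left=384, right=896 (all multiples of 8).
# crop_latents_to_tile() maps this to the latent-space slice rows 0..64, cols 48..112:
noise_tile = global_noise[..., 0:64, 48:112]
# Every tile's noise is cut from the same global tensor, so the overlap between
# adjacent tiles sees identical noise values.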
View File

@ -289,7 +289,7 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device: str = "cuda",
device: str | torch.device = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@ -304,7 +304,7 @@ def prepare_control_image(
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
device (str | torch.device, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.

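The widened `device` annotation matches how the tiled invocations above call this helper: they pass the loaded ControlNet model's `torch.device` directly instead of a string. A minimal call sketch using only the parameters shown in this file and in the callers above (`tile_pil` and `controlnet_model` are placeholders):

# Illustrative sketch only (not part of the diff).
control_image = prepare_control_image(
    image=tile_pil,
    do_classifier_free_guidance=True,
    width=tile_pil.width,
    height=tile_pil.height,
    device=controlnet_model.device,  # a torch.device is now accepted directly
    dtype=controlnet_model.dtype,
    control_mode="balanced",
    resize_mode="just_resize_simple",
)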
View File

@ -10,12 +10,11 @@ import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass
@ -38,56 +38,18 @@ class PipelineIntermediateState:
predicted_original: Optional[torch.Tensor] = None
@dataclass
class AddsMaskLatents:
"""Add the channels required for inpainting model input.
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
and the latent encoding of the base image.
This class assumes the same mask and base image should apply to all items in the batch.
"""
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.Tensor
initial_image_latents: torch.Tensor
def __call__(
self,
latents: torch.Tensor,
t: torch.Tensor,
text_embeddings: torch.Tensor,
**kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return isinstance(b, torch.Tensor) and (a.size() == b.size())
@dataclass
class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
mask: torch.Tensor
mask_latents: torch.Tensor
scheduler: SchedulerMixin
noise: torch.Tensor
gradient_mask: bool
is_gradient_mask: bool
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t)
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
@ -100,7 +62,7 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self.gradient_mask:
if self.is_gradient_mask:
threshold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshold # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents)
@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
control_model: ControlNetModel = None,
):
super().__init__(
vae=vae,
@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@ -280,116 +239,131 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
def add_inpainting_channels_to_latents(
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
):
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
input of shape (N, 9, H, W). The 9 channels are defined as follows:
- Channel 0-3: The latents being denoised.
- Channel 4: The mask indicating which parts of the image are being inpainted.
- Channel 5-8: The latent representation of the masked reference image being inpainted.
This function assumes that the same mask and base image should apply to all items in the batch.
"""
# Validate assumptions about input tensor shapes.
batch_size, latent_channels, latent_height, latent_width = latents.shape
assert latent_channels == 4
assert masked_ref_image_latents.shape == (1, 4, latent_height, latent_width)
assert inpainting_mask.shape == (1, 1, latent_height, latent_width)
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
# Concatenate along the channel dimension.
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
def latents_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
seed: int,
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
callback: Callable[[PipelineIntermediateState], None],
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: int,
is_gradient_mask: bool = False,
) -> torch.Tensor:
if init_timestep.shape[0] == 0:
return latents
"""Denoise the latents.
if additional_guidance is None:
additional_guidance = []
Args:
latents: The latent-space image to denoise.
- If we are inpainting, this is the initial latent image before noise has been added.
- If we are generating a new image, this should be initialized to zeros.
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
conditioning_data: Text conditioning data.
noise: Noise used for two purposes:
1. Used by the scheduler to noise the initial `latents` before denoising.
2. Used to noise the `masked_latents` when inpainting.
`noise` should be None if the `latents` tensor has already been noised.
seed: The seed used to generate the noise for the denoising process.
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule.
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
case, and remove this duplicate param.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
t2i_adapter_data: T2I-Adapter data.
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
`latents` tensor.
TODO(ryand): Check and document the expected dtype, range, and values used to represent
foreground/background.
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
orig_latents = latents.clone()
batch_size = latents.shape[0]
batched_t = init_timestep.expand(batch_size)
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
if mask is not None:
if is_inpainting_model(self.unet):
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask))
try:
latents = self.generate_latents_from_embeddings(
latents,
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback,
)
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if is_gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
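# A minimal, hypothetical caller sketch (names and shapes are illustrative, not taken from this diff)
# showing the two `noise`/`latents` conventions documented in the docstring above:
#
#     # Text-to-image: start from zero latents and let the scheduler add the provided noise.
#     latents = torch.zeros((1, 4, 64, 64), device=unet_device, dtype=unet_dtype)
#     noise = torch.randn_like(latents)
#     out = pipeline.latents_from_embeddings(
#         latents=latents, noise=noise, seed=seed, timesteps=timesteps, init_timestep=timesteps[0:1],
#         conditioning_data=conditioning_data, scheduler_step_kwargs={}, callback=on_progress,
#     )
#
#     # Refiner-style pass: `latents` are already partially noised, so `noise` is None.
#     out = pipeline.latents_from_embeddings(
#         latents=noised_latents, noise=None, seed=seed, timesteps=timesteps, init_timestep=timesteps[0:1],
#         conditioning_data=conditioning_data, scheduler_step_kwargs={}, callback=on_progress,
#     )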
def generate_latents_from_embeddings(
self,
latents: torch.Tensor,
timesteps: torch.Tensor,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: Optional[List[Callable]] = None,
control_data: Optional[List[ControlNetData]] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Optional[Callable[[PipelineIntermediateState], None]] = None,
) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
batch_size = latents.shape[0]
# Handle mask guidance (a.k.a. inpainting).
mask_guidance: AddsMaskGuidance | None = None
if mask is not None and not is_inpainting_model(self.unet):
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
# apply mask guidance to the latents.
if timesteps.shape[0] == 0:
return latents
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
# We still need noise for inpainting, so we generate it from the seed here.
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
mask_guidance = AddsMaskGuidance(
mask=mask,
mask_latents=orig_latents,
scheduler=self.scheduler,
noise=noise,
is_gradient_mask=is_gradient_mask,
)
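# In rough terms (a hedged summary; see AddsMaskGuidance for the authoritative behaviour): before each
# step, the area *outside* the mask is replaced with `orig_latents` re-noised to the current timestep,
# so only the masked area is actually denoised while the rest stays consistent with the source image.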
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
@@ -402,28 +376,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
step_output = self.step(
t=batched_t,
latents=latents,
conditioning_data=conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=mask_guidance,
mask=mask,
masked_latents=masked_latents,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
@@ -431,19 +405,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None)
if callback is not None:
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if is_gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
@torch.inference_mode()
def step(
@@ -454,19 +437,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
mask_guidance: AddsMaskGuidance | None,
mask: torch.Tensor | None,
masked_latents: torch.Tensor | None,
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
# Handle masked image-to-image (a.k.a. inpainting).
if mask_guidance is not None:
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
latents = mask_guidance(latents, timestep)
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
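# For context (hedged): the stock diffusers pipelines call `scheduler.scale_model_input(latents, timestep)`
# inside the denoising loop, immediately before the UNet forward pass, so scaling here mirrors that
# convention rather than pushing it down into self._unet_forward.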
@@ -514,6 +498,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_intrablock_additional_residuals = accum_adapter_state
# Handle inpainting models.
if is_inpainting_model(self.unet):
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
# latents.
if mask is not None:
if masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!")
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
)
else:
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
# We generate a global mask and empty original image so that we can still generate in this
# configuration.
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
# to do this.
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
# mask and original image once rather than on every denoising step.
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input,
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
)
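# Rough sketch of the resulting UNet input (hedged; the exact layout is defined by
# add_inpainting_channels_to_latents): SD inpainting UNets typically expect 9 input channels, i.e.
# roughly torch.cat([noisy_latents (4ch), inpainting_mask (1ch), masked_ref_image_latents (4ch)], dim=1),
# whereas a standard SD UNet only receives the 4 latent channels.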
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
timestep=t, # TODO: debug how batched and non-batched timesteps are handled
@@ -542,17 +551,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
# again.
if mask_guidance is not None:
# Apply the mask to any "denoised" or "pred_original_sample" fields.
if hasattr(step_output, "denoised"):
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
elif hasattr(step_output, "pred_original_sample"):
step_output.pred_original_sample = mask_guidance(
step_output.pred_original_sample, self.scheduler.timesteps[-1]
)
else:
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
return step_output
@@ -575,17 +585,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**kwargs,
):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
# FIXME: There are too many layers of functions and we have too many different ways of
# overriding things! This should get handled in a way more consistent with the other
# use of AddsMaskLatents.
latents = AddsMaskLatents(
self._unet_forward,
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
).add_mask_channels(latents)
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
latents,

@@ -0,0 +1,242 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import TBLR
# The maximum number of regions with compatible sizes that will be batched together.
# Larger batch sizes improve speed, but require more device memory.
MAX_REGION_BATCH_SIZE = 4
@dataclass
class MultiDiffusionRegionConditioning:
# Region coords in latent space.
region: TBLR
text_conditioning_data: TextConditioningData
control_data: list[ControlNetData]
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _split_into_region_batches(
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
) -> list[list[MultiDiffusionRegionConditioning]]:
# Group the regions by shape. Only regions with the same shape can be batched together.
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
for region_conditioning in multi_diffusion_conditioning:
shape_hw = (
region_conditioning.region.bottom - region_conditioning.region.top,
region_conditioning.region.right - region_conditioning.region.left,
)
# In Python, a tuple of hashable objects is itself hashable, so it can be used as a dict key.
if shape_hw not in conditioning_by_shape:
conditioning_by_shape[shape_hw] = []
conditioning_by_shape[shape_hw].append(region_conditioning)
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
region_conditioning_batches = []
for region_conditioning_batch in conditioning_by_shape.values():
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
return region_conditioning_batches
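# Illustrative example (hypothetical sizes): five 96x96 regions and one 64x96 region with
# MAX_REGION_BATCH_SIZE = 4 produce three batches: [4 x 96x96], [1 x 96x96], [1 x 64x96].
# Only same-shaped regions are ever batched, so their latent crops can be concatenated along the batch dim.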
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Check the input conditioning and confirm that regional prompting is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
or region_conditioning.text_conditioning_data.uncond_regions is not None
):
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None],
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
batch_size, _, latent_height, latent_width = latents.shape
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
# cropping into regions.
self._adjust_memory_efficient_attention(latents)
# Populate a weighted mask that will be used to combine the results from each region after every step.
# For now, we assume that each region has the same weight (1.0).
region_weight_mask = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
for region_conditioning in multi_diffusion_conditioning:
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
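# Worked example (hypothetical tile layout): two 96-wide regions overlapping by 32 latent columns give a
# weight of 1.0 in the exclusive areas and 2.0 in the overlap, so dividing the accumulated results by this
# mask after each step averages the two predictions where the tiles meet.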
# Group the region conditioning into batches for faster processing.
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
region_batch_schedulers: list[SchedulerMixin] = [
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
]
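# For example (hedged): schedulers such as EulerDiscreteScheduler advance an internal step index on every
# call to step(). With a single shared scheduler, stepping once per region batch at the same timestep would
# advance that index several times per step, corrupting the sigma schedule for every batch after the first.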
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
merged_latents = torch.zeros_like(latents)
merged_pred_original: torch.Tensor | None = None
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_batch_idx]
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
# Prepare the latents for the region batch.
batch_latents = torch.cat(
[
latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
for region_conditioning in region_conditioning_batch
],
)
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
# handle broadcasting properly?
# TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=batch_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Run a denoising step on the region.
# step_output = self._region_step(
# region_conditioning=region_conditioning,
# t=batched_t,
# latents=latents,
# step_index=i,
# total_step_count=len(timesteps),
# scheduler_step_kwargs=scheduler_step_kwargs,
# )
# Store the results from the region.
region = region_conditioning.region
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
pred_orig_sample
)
# Normalize the merged results.
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = torch.where(
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
@torch.inference_mode()
def _region_batch_step(
self,
region_conditioning: MultiDiffusionRegionConditioning,
t: torch.Tensor,
latents: torch.Tensor,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
):
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
return self.step(
t=t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=step_index,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
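# A minimal, hypothetical usage sketch (all names below are illustrative; TBLR field names are assumed to
# be top/bottom/left/right, matching how they are accessed above):
#
#     regions = [
#         MultiDiffusionRegionConditioning(
#             region=TBLR(top=0, bottom=96, left=0, right=96),
#             text_conditioning_data=conditioning_data,
#             control_data=[],
#         ),
#         MultiDiffusionRegionConditioning(
#             region=TBLR(top=0, bottom=96, left=64, right=160),
#             text_conditioning_data=conditioning_data,
#             control_data=[],
#         ),
#     ]
#     latents = pipeline.multi_diffusion_denoise(
#         multi_diffusion_conditioning=regions,
#         latents=torch.zeros((1, 4, 96, 160)),
#         scheduler_step_kwargs={},
#         noise=torch.randn((1, 4, 96, 160)),
#         timesteps=timesteps,
#         init_timestep=timesteps[0:1],
#         callback=on_progress,
#     )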