Document plan for the rest of the MultiDiffusion implementation.

Ryan Dick 2024-06-14 18:08:11 -04:00 committed by Kent Keirsey
parent 605f460c7d
commit fc187c9253


@@ -1,12 +1,18 @@
from __future__ import annotations

import math
from contextlib import nullcontext
from typing import Any, Callable, List, Optional

import torch

from invokeai.backend.stable_diffusion.diffusers_pipeline import (
    AddsMaskGuidance,
    ControlNetData,
    PipelineIntermediateState,
    StableDiffusionGeneratorPipeline,
    T2IAdapterData,
    is_inpainting_model,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
@@ -14,6 +20,32 @@ from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import U
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
    """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
    # Plan:
    # - latents_from_embeddings(...) will accept all of the same global params, but the "local" params will be bundled
    #   together with tile locations.
    # - What is "local"?:
    #   - conditioning_data could be local, but for upscaling will be global
    #   - control_data makes more sense as global, then we split it up as we split up the latents
    #   - ip_adapter_data sort of has 3 modes to consider:
    #     - global style: applied in the same way to all tiles
    #     - local style: apply different IP-Adapters to each tile
    #     - global structure: we want to crop the input image and run the IP-Adapter on each separately
    #   - t2i_adapter_data won't be supported at first - it's not popular enough
    #   - All the inpainting params are global and need to be cropped accordingly
    # - Local:
    #   - latents
    #   - conditioning_data
    #   - noise
    #   - control_data
    #   - ip_adapter_data (skip for now)
    #   - t2i_adapter_data (skip for now)
    #   - mask
    #   - masked_latents
    #   - is_gradient_mask ???
    # - Can we support inpainting models in this node?
    #   - TBD, need to think about this more
    # - step(...) remains mostly unmodified and is not overridden in this sub-class.
    # - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see.
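The bundling of "local" params with tile locations described in the plan could look roughly like the following minimal sketch. All names here (Tile, TileLocalParams, crop_to_tile) are hypothetical and not part of this commit:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Tile:
    # Tile location in latent space (top/left inclusive, bottom/right exclusive).
    top: int
    left: int
    bottom: int
    right: int


@dataclass
class TileLocalParams:
    # Bundle of the "local" params from the plan, cropped to a single tile.
    tile: Tile
    latents: torch.Tensor
    noise: torch.Tensor
    mask: Optional[torch.Tensor] = None
    masked_latents: Optional[torch.Tensor] = None


def crop_to_tile(t: torch.Tensor, tile: Tile) -> torch.Tensor:
    # Crop a (B, C, H, W) latent-space tensor to the tile's region.
    return t[:, :, tile.top : tile.bottom, tile.left : tile.right]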
    def latents_from_embeddings(
        self,
        latents: torch.Tensor,
@@ -142,141 +174,3 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
        )
        return latents
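For context, the fusion step that Multi-Diffusion is named for combines per-tile denoising results by averaging the regions where tiles overlap (the paper uses a weighted per-pixel mean; an unweighted mean is sketched here). This sketch reuses the hypothetical Tile type from the sketch above:

import torch


def merge_tile_results(
    tile_results: list[torch.Tensor],  # per-tile latents, each (B, C, h, w)
    tiles: list[Tile],  # matching tile locations
    full_shape: tuple[int, int, int, int],
) -> torch.Tensor:
    # Accumulate each tile's result onto a full-size canvas, then divide by the
    # per-pixel tile count: an unweighted mean over overlapping tiles.
    merged = torch.zeros(full_shape)
    counts = torch.zeros(full_shape)
    for result, tile in zip(tile_results, tiles):
        merged[:, :, tile.top : tile.bottom, tile.left : tile.right] += result
        counts[:, :, tile.top : tile.bottom, tile.left : tile.right] += 1.0
    return merged / counts.clamp(min=1.0)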
    @torch.inference_mode()
    def step(
        self,
        t: torch.Tensor,
        latents: torch.Tensor,
        conditioning_data: TextConditioningData,
        step_index: int,
        total_step_count: int,
        scheduler_step_kwargs: dict[str, Any],
        mask_guidance: AddsMaskGuidance | None,
        mask: torch.Tensor | None,
        masked_latents: torch.Tensor | None,
        control_data: list[ControlNetData] | None = None,
        ip_adapter_data: Optional[list[IPAdapterData]] = None,
        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]
        # Handle masked image-to-image (a.k.a. inpainting).
        if mask_guidance is not None:
            # NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
            latents = mask_guidance(latents, timestep)

        # TODO: should this scaling happen here or inside self._unet_forward?
        # i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # Handle ControlNet(s)
        down_block_additional_residuals = None
        mid_block_additional_residual = None
        if control_data is not None:
            down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
                control_data=control_data,
                sample=latent_model_input,
                timestep=timestep,
                step_index=step_index,
                total_step_count=total_step_count,
                conditioning_data=conditioning_data,
            )
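        # NOTE: these residuals condition the UNet in the standard ControlNet way:
        # the down-block residuals are added to the UNet's skip connections and the
        # mid-block residual to its mid-block output. The actual wiring presumably
        # happens inside invokeai_diffuser.do_unet_step(...) below (assumption; not
        # shown in this diff).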
        # Handle T2I-Adapter(s)
        down_intrablock_additional_residuals = None
        if t2i_adapter_data is not None:
            accum_adapter_state = None
            for single_t2i_adapter_data in t2i_adapter_data:
                # Determine the T2I-Adapter weights for the current denoising step.
                first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count)
                last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count)
                t2i_adapter_weight = (
                    single_t2i_adapter_data.weight[step_index]
                    if isinstance(single_t2i_adapter_data.weight, list)
                    else single_t2i_adapter_data.weight
                )
                if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step:
                    # If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0
                    # so it has no effect.
                    t2i_adapter_weight = 0.0

                # Apply the t2i_adapter_weight, and accumulate.
                if accum_adapter_state is None:
                    # Handle the first T2I-Adapter.
                    accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state]
                else:
                    # Add to the previous adapter states.
                    for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
                        accum_adapter_state[idx] += value * t2i_adapter_weight

            down_intrablock_additional_residuals = accum_adapter_state
        # Handle inpainting models.
        if is_inpainting_model(self.unet):
            # NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
            # self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
            # latents.
            if mask is not None:
                if masked_latents is None:
                    raise ValueError("Source image required for inpaint mask when inpaint model used!")
                latent_model_input = self.add_inpainting_channels_to_latents(
                    latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
                )
            else:
                # We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
                # We generate a global mask and empty original image so that we can still generate in this
                # configuration.
                # TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
                # to do this.
                # TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
                # mask and original image once rather than on every denoising step.
                latent_model_input = self.add_inpainting_channels_to_latents(
                    latents=latent_model_input,
                    masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
                    inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
                )
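        # NOTE: a standard Stable Diffusion inpainting UNet expects 9 input channels
        # (4 noisy latents + 1 mask + 4 masked-image latents), so
        # add_inpainting_channels_to_latents(...) presumably concatenates its three
        # inputs along the channel dim (assumption; the helper is not shown in this
        # diff).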
        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
            sample=latent_model_input,
            timestep=t,  # TODO: debug how batched and non-batched timesteps are handled
            step_index=step_index,
            total_step_count=total_step_count,
            conditioning_data=conditioning_data,
            ip_adapter_data=ip_adapter_data,
            down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
            mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
        )
        guidance_scale = conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[step_index]

        noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
        guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
        if guidance_rescale_multiplier > 0:
            noise_pred = self._rescale_cfg(
                noise_pred,
                c_noise_pred,
                guidance_rescale_multiplier,
            )

        # compute the previous noisy sample x_t -> x_t-1
        step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)

        # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
        # again.
        if mask_guidance is not None:
            # Apply the mask to any "denoised" or "pred_original_sample" fields.
            if hasattr(step_output, "denoised"):
                step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
            elif hasattr(step_output, "pred_original_sample"):
                step_output.pred_original_sample = mask_guidance(
                    step_output.pred_original_sample, self.scheduler.timesteps[-1]
                )
            else:
                step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])

        return step_output
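The _combine and _rescale_cfg helpers referenced above are not shown in this diff. A sketch of the standard classifier-free-guidance formulation they presumably follow (the rescale matching diffusers' rescale_noise_cfg, from Lin et al. 2023, "Common Diffusion Noise Schedules and Sample Steps are Flawed"; the actual InvokeAI helpers may differ):

import torch


def combine(uc_noise_pred: torch.Tensor, c_noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: extrapolate from the unconditioned prediction
    # toward the conditioned one.
    return uc_noise_pred + guidance_scale * (c_noise_pred - uc_noise_pred)


def rescale_cfg(noise_pred: torch.Tensor, c_noise_pred: torch.Tensor, multiplier: float) -> torch.Tensor:
    # Rescale the combined prediction toward the std of the conditioned prediction
    # to counteract over-saturation at high guidance scales, then blend by the
    # rescale multiplier.
    std_c = c_noise_pred.std(dim=list(range(1, c_noise_pred.ndim)), keepdim=True)
    std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
    rescaled = noise_pred * (std_c / std_pred)
    return multiplier * rescaled + (1.0 - multiplier) * noise_pred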