2024-07-12 17:31:26 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
2024-07-16 17:03:29 +00:00
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
|
2024-07-12 17:31:26 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
2024-07-17 01:20:31 +00:00
|
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
|
2024-07-12 17:31:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class UNetKwargs:
    """Container for the arguments that will be passed to the UNet model.

    The three leading fields are required; every other field defaults to
    ``None``.

    NOTE(review): the field names look like they mirror the keyword parameters
    of ``UNet2DConditionModel.forward`` -- confirm against the installed
    diffusers version before adding or renaming fields.
    """

    # Required arguments.
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    # Optional arguments.
    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    # `Tuple[torch.Tensor]` would describe a tuple of exactly one tensor; the
    # residual collections are variable-length, hence the `...` form.
    down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # Deliberately left disabled in the original source:
    # return_dict: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DenoiseInputs:
    """Initial variables passed to denoise. Supposed to be unchanged.

    Every field is required (none has a default). The dataclass is not frozen,
    so the "unchanged" contract is a convention rather than something enforced
    at runtime.
    """

    # The latent-space image to denoise.
    # Shape: [batch, channels, latent_height, latent_width]
    # - If we are inpainting, this is the initial latent image before noise has been added.
    # - If we are generating a new image, this should be initialized to zeros.
    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
    orig_latents: torch.Tensor

    # kwargs forwarded to the scheduler.step() method.
    scheduler_step_kwargs: dict[str, Any]

    # Text conditioning data.
    conditioning_data: TextConditioningData

    # Noise used for two purposes:
    # 1. Used by the scheduler to noise the initial `latents` before denoising.
    # 2. Used to noise the `masked_latents` when inpainting.
    # `noise` should be None if the `latents` tensor has already been noised.
    # Shape: [1 or batch, channels, latent_height, latent_width]
    noise: Optional[torch.Tensor]

    # The seed used to generate the noise for the denoising process.
    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
    seed: int

    # The timestep schedule for the denoising process.
    timesteps: torch.Tensor

    # The first timestep in the schedule. This is used to determine the initial noise level, so
    # should be populated if you want noise applied *even* if timesteps is empty.
    init_timestep: torch.Tensor

    # Class of attention processor that is used.
    attention_processor_cls: Type[Any]
|
2024-07-12 17:31:26 +00:00
|
|
|
|
2024-07-16 16:30:29 +00:00
|
|
|
|
|
|
|
@dataclass
class DenoiseContext:
    """Context with all variables in denoise.

    Only `inputs` and `scheduler` are required. The remaining fields start as
    `None` (or an empty dict for `extra`); the per-field comments below say in
    which callbacks each one is expected to be populated.
    """

    # Initial variables passed to denoise. Supposed to be unchanged.
    inputs: DenoiseInputs

    # Scheduler used to apply noise predictions.
    scheduler: SchedulerMixin

    # UNet model.
    unet: Optional[UNet2DConditionModel] = None

    # Current state of latent-space image in denoising process.
    # None until `pre_denoise_loop` callback.
    # Shape: [batch, channels, latent_height, latent_width]
    latents: Optional[torch.Tensor] = None

    # Current denoising step index.
    # None until `pre_step` callback.
    step_index: Optional[int] = None

    # Current denoising step timestep.
    # None until `pre_step` callback.
    timestep: Optional[torch.Tensor] = None

    # Arguments which will be passed to UNet model.
    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
    unet_kwargs: Optional[UNetKwargs] = None

    # SchedulerOutput class returned from step function (normally, generated by scheduler).
    # Supposed to be used only in `post_step` callback, otherwise can be None.
    step_output: Optional[SchedulerOutput] = None

    # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
    # Available in events inside step (between `pre_step` and `post_step`).
    # Shape: [batch, channels, latent_height, latent_width]
    latent_model_input: Optional[torch.Tensor] = None

    # [TMP] Defines on which conditionings the current unet call will be run.
    # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
    conditioning_mode: Optional[ConditioningMode] = None

    # [TMP] Noise predictions from negative conditioning.
    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    negative_noise_pred: Optional[torch.Tensor] = None

    # [TMP] Noise predictions from positive conditioning.
    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    positive_noise_pred: Optional[torch.Tensor] = None

    # Combined noise prediction from passed conditionings.
    # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    noise_pred: Optional[torch.Tensor] = None

    # Dictionary for extensions to pass extra info about denoise process to other extensions.
    # A default_factory is used so each context gets its own dict instance.
    extra: dict = field(default_factory=dict)
|