from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None

    # return_dict: bool = True


@dataclass
class DenoiseInputs:
    """Initial variables passed to denoise. Supposed to be unchanged.

    Variables:
        orig_latents: The latent-space image to denoise.
            Shape: [batch, channels, latent_height, latent_width]
            - If we are inpainting, this is the initial latent image before noise has been added.
            - If we are generating a new image, this should be initialized to zeros.
            - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).

        scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.

        conditioning_data: Text conditioning data.

        noise: Noise used for two purposes:
            Shape: [1 or batch, channels, latent_height, latent_width]
            1. Used by the scheduler to noise the initial `latents` before denoising.
            2. Used to noise the `masked_latents` when inpainting.
            `noise` should be None if the `latents` tensor has already been noised.

        seed: The seed used to generate the noise for the denoising process.
            HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate
            the same noise used earlier in the pipeline. This should really be handled in a clearer way.

        timesteps: The timestep schedule for the denoising process.

        init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
            it should be populated if you want noise applied *even* if `timesteps` is empty.

        attention_processor_cls: Class of the attention processor that is used.
    """

    orig_latents: torch.Tensor
    scheduler_step_kwargs: dict[str, Any]
    conditioning_data: TextConditioningData
    noise: Optional[torch.Tensor]
    seed: int
    timesteps: torch.Tensor
    init_timestep: torch.Tensor
    attention_processor_cls: Type[Any]
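
# A minimal sketch (not an InvokeAI helper) of how the scheduler-related fields of
# `DenoiseInputs` (`timesteps`, `init_timestep`, `noise`) could be derived from a concrete
# diffusers scheduler and a seed. The function name and signature are hypothetical, and
# details such as partial-denoise trimming of the schedule are intentionally omitted.
def _make_schedule_sketch(
    scheduler: SchedulerMixin,
    num_inference_steps: int,
    seed: int,
    latents_shape: Tuple[int, ...],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # `set_timesteps` exists on concrete schedulers (e.g. DDIMScheduler), not on SchedulerMixin itself.
    scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = scheduler.timesteps

    # The first timestep in the schedule determines the initial noise level.
    init_timestep = timesteps[:1]

    # Generate noise from the seed so the same noise can be re-created later
    # (see the HACK note on `seed` in the docstring above).
    generator = torch.Generator(device="cpu").manual_seed(seed)
    noise = torch.randn(latents_shape, generator=generator).to(device)

    return timesteps, init_timestep, noise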


@dataclass
class DenoiseContext:
    """Context with all variables of the denoise process.

    Variables:
        inputs: Initial variables passed to denoise. Supposed to be unchanged.

        scheduler: Scheduler used to apply noise predictions.

        unet: UNet model.

        latents: Current state of the latent-space image in the denoising process.
            None until the `pre_denoise_loop` callback.
            Shape: [batch, channels, latent_height, latent_width]

        step_index: Index of the current denoising step. None until the `pre_step` callback.

        timestep: Timestep of the current denoising step. None until the `pre_step` callback.

        unet_kwargs: Arguments which will be passed to the UNet model. Available in the `pre_unet`/`post_unet`
            callbacks, otherwise None.

        step_output: SchedulerOutput returned from the scheduler's step function. Supposed to be used only in
            the `post_step` callback, otherwise can be None.

        latent_model_input: Scaled version of `latents`, which will be used for `unet_kwargs` initialization.
            Available in events inside step (between `pre_step` and `post_step`).
            Shape: [batch, channels, latent_height, latent_width]

        conditioning_mode: [TMP] Defines on which conditionings the current UNet call will be run. Available in
            the `pre_unet`/`post_unet` callbacks, otherwise None. Can be "negative", "positive" or "both".

        negative_noise_pred: [TMP] Noise prediction from negative conditioning. Available in the `apply_cfg` and
            `post_apply_cfg` callbacks, otherwise None.
            Shape: [batch, channels, latent_height, latent_width]

        positive_noise_pred: [TMP] Noise prediction from positive conditioning. Available in the `apply_cfg` and
            `post_apply_cfg` callbacks, otherwise None.
            Shape: [batch, channels, latent_height, latent_width]

        noise_pred: Combined noise prediction from the passed conditionings. Available in the `apply_cfg` and
            `post_apply_cfg` callbacks, otherwise None.
            Shape: [batch, channels, latent_height, latent_width]

        extra: Dictionary for extensions to pass extra info about the denoise process to other extensions.
    """

    inputs: DenoiseInputs
    scheduler: SchedulerMixin
    unet: Optional[UNet2DConditionModel] = None
    latents: Optional[torch.Tensor] = None
    step_index: Optional[int] = None
    timestep: Optional[torch.Tensor] = None
    unet_kwargs: Optional[UNetKwargs] = None
    step_output: Optional[SchedulerOutput] = None
    latent_model_input: Optional[torch.Tensor] = None
    conditioning_mode: Optional[str] = None
    negative_noise_pred: Optional[torch.Tensor] = None
    positive_noise_pred: Optional[torch.Tensor] = None
    noise_pred: Optional[torch.Tensor] = None
    extra: dict = field(default_factory=dict)
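

# A minimal usage sketch (illustration only, not part of the InvokeAI API) showing how the
# fields of `DenoiseContext` are expected to be populated over the denoising loop. The
# function name, the `positive_embeds`/`negative_embeds` placeholders (which the real
# pipeline derives from `conditioning_data`), and the direct CFG arithmetic are assumptions;
# the real pipeline drives these steps through extension callbacks and may batch UNet calls.
# Assumes `ctx.inputs.noise` is not None and `ctx.unet` is a loaded model.
def _denoise_loop_sketch(
    ctx: DenoiseContext,
    positive_embeds: torch.Tensor,
    negative_embeds: torch.Tensor,
    guidance_scale: float = 7.5,
) -> torch.Tensor:
    # Noise the initial latents once before the loop (mirrors `pre_denoise_loop`).
    ctx.latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, ctx.inputs.noise, ctx.inputs.init_timestep)

    for i, t in enumerate(ctx.inputs.timesteps):
        ctx.step_index = i
        ctx.timestep = t

        # Scale the latents for this timestep before the UNet calls (`pre_step` -> `pre_unet`).
        ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, t)

        # One UNet call per conditioning.
        ctx.conditioning_mode = "negative"
        ctx.unet_kwargs = UNetKwargs(sample=ctx.latent_model_input, timestep=t, encoder_hidden_states=negative_embeds)
        ctx.negative_noise_pred = ctx.unet(**vars(ctx.unet_kwargs)).sample

        ctx.conditioning_mode = "positive"
        ctx.unet_kwargs = UNetKwargs(sample=ctx.latent_model_input, timestep=t, encoder_hidden_states=positive_embeds)
        ctx.positive_noise_pred = ctx.unet(**vars(ctx.unet_kwargs)).sample

        # Combine the predictions with classifier-free guidance (`apply_cfg`).
        ctx.noise_pred = ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

        # Advance the latents one step (`post_apply_cfg` -> `post_step`).
        ctx.step_output = ctx.scheduler.step(ctx.noise_pred, t, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
        ctx.latents = ctx.step_output.prev_sample

    return ctx.latents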