InvokeAI/invokeai/backend/stable_diffusion/denoise_context.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

61 lines
2.0 KiB
Python
Raw Normal View History

2024-07-12 17:31:26 +00:00
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
@dataclass
class UNetKwargs:
    """Keyword arguments for a single UNet forward pass, bundled as one object.

    The field names mirror the keyword parameters of diffusers'
    ``UNet2DConditionModel.forward``; presumably an instance is forwarded to
    that call by the denoising code (confirm at the call site).

    Only ``sample``, ``timestep`` and ``encoder_hidden_states`` are required;
    every other field defaults to ``None``.
    """

    # Required inputs.
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    # Optional inputs; None means "not supplied".
    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    # Variable-length tuple of residual tensors. `Tuple[torch.Tensor]` would
    # mean a tuple of exactly one tensor, hence the `...`.
    down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    # Variable-length tuple, same shape convention as the field above.
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # NOTE: forward()'s `return_dict` parameter is deliberately not modelled
    # as a field; pass it explicitly at the call site if needed.
@dataclass
class DenoiseContext:
    """Mutable state object shared across one denoising run.

    Fields without defaults must be supplied when the context is created.
    Fields defaulting to ``None`` appear to be working state that is filled in
    and cleared as denoising progresses (assumption based on the defaults —
    confirm against the denoise loop and anything else that mutates it).

    ``del ctx.<name>`` does not remove the attribute. It resets it to ``None``
    (see ``__delattr__``), which drops this object's reference to the old value
    while keeping the attribute readable afterwards.
    """

    # --- Supplied at construction (no defaults) ---
    latents: torch.Tensor
    # Extra keyword arguments, presumably forwarded to the scheduler's step
    # call — confirm at the call site.
    scheduler_step_kwargs: dict[str, Any]
    conditioning_data: TextConditioningData
    # Typed Optional but has no default: it must still be passed explicitly,
    # with None allowed.
    noise: Optional[torch.Tensor]
    seed: int
    timesteps: torch.Tensor
    init_timestep: torch.Tensor
    scheduler: SchedulerMixin

    # --- Optional state; each defaults to None ---
    unet: Optional[UNet2DConditionModel] = None
    orig_latents: Optional[torch.Tensor] = None
    step_index: Optional[int] = None
    timestep: Optional[torch.Tensor] = None
    unet_kwargs: Optional[UNetKwargs] = None
    step_output: Optional[SchedulerOutput] = None

    latent_model_input: Optional[torch.Tensor] = None
    # Free-form string; the set of valid values is not defined in this class.
    conditioning_mode: Optional[str] = None
    negative_noise_pred: Optional[torch.Tensor] = None
    positive_noise_pred: Optional[torch.Tensor] = None
    noise_pred: Optional[torch.Tensor] = None

    # Open-ended storage for data without a dedicated field. default_factory
    # gives every instance its own dict rather than one shared mutable default.
    extra: dict = field(default_factory=dict)

    def __delattr__(self, name: str) -> None:
        """Reset attribute ``name`` to ``None`` instead of removing it.

        Existing attributes (dataclass fields, or attributes set later on the
        instance) are overwritten with ``None``, so they stay readable after
        ``del``. Deleting a name that is not set on the instance raises
        ``AttributeError``, as ordinary Python deletion does. This prevents a
        misspelled ``del ctx.<typo>`` from silently creating a new attribute.

        Raises:
            AttributeError: if ``name`` is not an attribute of this instance.
        """
        # Dataclass __init__ stores every field in the instance __dict__, so
        # this membership test covers all declared fields.
        if name not in self.__dict__:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        setattr(self, name, None)