# Mirror of https://github.com/invoke-ai/InvokeAI
# (synced 2024-08-30 20:32:17 +00:00; 61 lines, 2.0 KiB, Python)
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
    """Keyword arguments for one UNet forward call.

    The field names appear to mirror the keyword parameters of diffusers'
    ``UNet2DConditionModel.forward``. Presumably an instance is expanded into
    that call by the caller; confirm against the call site. Only ``sample``,
    ``timestep`` and ``encoder_hidden_states`` are required. Every other field
    defaults to ``None``.
    """

    # Required inputs.
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    # Optional inputs; all default to None.
    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    # The residual fields hold a variable-length tuple of tensors.
    # `Tuple[torch.Tensor]` would declare exactly one element, hence `...`.
    down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # NOTE(review): `return_dict` is deliberately not a field here (it was left
    # commented out as `return_dict: bool = True`); confirm how the caller sets it.


@dataclass
class DenoiseContext:
    """Mutable state carried through a denoising run.

    The fields without defaults (``latents`` through ``scheduler``) must be
    supplied when the context is constructed. The remaining fields default to
    ``None``; presumably they are filled in and overwritten as denoising
    proceeds (TODO confirm against the denoising loop). ``extra`` is a
    per-instance dict for values that have no dedicated field.

    ``del ctx.<name>`` does not remove the attribute; it resets it to ``None``
    (see ``__delattr__``).
    """

    # Supplied by the caller at construction time.
    latents: torch.Tensor
    scheduler_step_kwargs: dict[str, Any]
    conditioning_data: TextConditioningData
    noise: Optional[torch.Tensor]
    seed: int
    timesteps: torch.Tensor
    init_timestep: torch.Tensor

    scheduler: SchedulerMixin
    unet: Optional[UNet2DConditionModel] = None

    # Optional state, None until assigned.
    orig_latents: Optional[torch.Tensor] = None
    step_index: Optional[int] = None
    timestep: Optional[torch.Tensor] = None
    unet_kwargs: Optional[UNetKwargs] = None
    step_output: Optional[SchedulerOutput] = None

    latent_model_input: Optional[torch.Tensor] = None
    conditioning_mode: Optional[str] = None
    negative_noise_pred: Optional[torch.Tensor] = None
    positive_noise_pred: Optional[torch.Tensor] = None
    noise_pred: Optional[torch.Tensor] = None

    # default_factory gives every instance its own, independent dict.
    extra: dict[str, Any] = field(default_factory=dict)

    def __delattr__(self, name: str) -> None:
        """Reset attribute ``name`` to ``None`` instead of removing it.

        After ``del ctx.field`` the attribute still exists and reads as
        ``None``. This holds even for fields with no class-level default,
        which would otherwise raise ``AttributeError`` on the next read.
        Note that it applies to every name. ``del ctx.extra`` leaves ``extra``
        as ``None`` (not an empty dict). Deleting a name that is not a field
        creates it with value ``None`` rather than raising ``AttributeError``.

        Args:
            name: Attribute name to reset.
        """
        setattr(self, name, None)