mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Base code from draft PR
This commit is contained in:
60
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
60
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetKwargs:
|
||||
sample: torch.Tensor
|
||||
timestep: Union[torch.Tensor, float, int]
|
||||
encoder_hidden_states: torch.Tensor
|
||||
|
||||
class_labels: Optional[torch.Tensor] = None
|
||||
timestep_cond: Optional[torch.Tensor] = None
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None
|
||||
# return_dict: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseContext:
|
||||
latents: torch.Tensor
|
||||
scheduler_step_kwargs: dict[str, Any]
|
||||
conditioning_data: TextConditioningData
|
||||
noise: Optional[torch.Tensor]
|
||||
seed: int
|
||||
timesteps: torch.Tensor
|
||||
init_timestep: torch.Tensor
|
||||
|
||||
scheduler: SchedulerMixin
|
||||
unet: Optional[UNet2DConditionModel] = None
|
||||
|
||||
orig_latents: Optional[torch.Tensor] = None
|
||||
step_index: Optional[int] = None
|
||||
timestep: Optional[torch.Tensor] = None
|
||||
unet_kwargs: Optional[UNetKwargs] = None
|
||||
step_output: Optional[SchedulerOutput] = None
|
||||
|
||||
latent_model_input: Optional[torch.Tensor] = None
|
||||
conditioning_mode: Optional[str] = None
|
||||
negative_noise_pred: Optional[torch.Tensor] = None
|
||||
positive_noise_pred: Optional[torch.Tensor] = None
|
||||
noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
extra: dict = field(default_factory=dict)
|
||||
|
||||
def __delattr__(self, name: str):
|
||||
setattr(self, name, None)
|
Reference in New Issue
Block a user