mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Comments, a bit refactor
This commit is contained in:
parent
79e35bd0d3
commit
2c2ec8f0bc
@ -8,7 +8,7 @@ from diffusers import UNet2DConditionModel
|
|||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -31,92 +31,101 @@ class UNetKwargs:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DenoiseInputs:
|
class DenoiseInputs:
|
||||||
"""Initial variables passed to denoise. Supposed to be unchanged.
|
"""Initial variables passed to denoise. Supposed to be unchanged."""
|
||||||
|
|
||||||
Variables:
|
|
||||||
orig_latents: The latent-space image to denoise.
|
|
||||||
Shape: [batch, channels, latent_height, latent_width]
|
|
||||||
- If we are inpainting, this is the initial latent image before noise has been added.
|
|
||||||
- If we are generating a new image, this should be initialized to zeros.
|
|
||||||
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
|
||||||
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
|
|
||||||
conditioning_data: Text conditionging data.
|
|
||||||
noise: Noise used for two purposes:
|
|
||||||
Shape: [1 or batch, channels, latent_height, latent_width]
|
|
||||||
1. Used by the scheduler to noise the initial `latents` before denoising.
|
|
||||||
2. Used to noise the `masked_latents` when inpainting.
|
|
||||||
`noise` should be None if the `latents` tensor has already been noised.
|
|
||||||
seed: The seed used to generate the noise for the denoising process.
|
|
||||||
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
|
||||||
same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
|
||||||
timesteps: The timestep schedule for the denoising process.
|
|
||||||
init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
|
|
||||||
should be populated if you want noise applied *even* if timesteps is empty.
|
|
||||||
attention_processor_cls: Class of attention processor that is used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# The latent-space image to denoise.
|
||||||
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
|
# - If we are inpainting, this is the initial latent image before noise has been added.
|
||||||
|
# - If we are generating a new image, this should be initialized to zeros.
|
||||||
|
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
||||||
orig_latents: torch.Tensor
|
orig_latents: torch.Tensor
|
||||||
|
|
||||||
|
# kwargs forwarded to the scheduler.step() method.
|
||||||
scheduler_step_kwargs: dict[str, Any]
|
scheduler_step_kwargs: dict[str, Any]
|
||||||
|
|
||||||
|
# Text conditioning data.
|
||||||
conditioning_data: TextConditioningData
|
conditioning_data: TextConditioningData
|
||||||
|
|
||||||
|
# Noise used for two purposes:
|
||||||
|
# 1. Used by the scheduler to noise the initial `latents` before denoising.
|
||||||
|
# 2. Used to noise the `masked_latents` when inpainting.
|
||||||
|
# `noise` should be None if the `latents` tensor has already been noised.
|
||||||
|
# Shape: [1 or batch, channels, latent_height, latent_width]
|
||||||
noise: Optional[torch.Tensor]
|
noise: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# The seed used to generate the noise for the denoising process.
|
||||||
|
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
||||||
|
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
||||||
seed: int
|
seed: int
|
||||||
|
|
||||||
|
# The timestep schedule for the denoising process.
|
||||||
timesteps: torch.Tensor
|
timesteps: torch.Tensor
|
||||||
|
|
||||||
|
# The first timestep in the schedule. This is used to determine the initial noise level, so
|
||||||
|
# should be populated if you want noise applied *even* if timesteps is empty.
|
||||||
init_timestep: torch.Tensor
|
init_timestep: torch.Tensor
|
||||||
|
|
||||||
|
# Class of attention processor that is used.
|
||||||
attention_processor_cls: Type[Any]
|
attention_processor_cls: Type[Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DenoiseContext:
|
class DenoiseContext:
|
||||||
"""Context with all variables in denoise
|
"""Context with all variables in denoise"""
|
||||||
|
|
||||||
Variables:
|
|
||||||
inputs: Initial variables passed to denoise. Supposed to be unchanged.
|
|
||||||
scheduler: Scheduler which used to apply noise predictions.
|
|
||||||
unet: UNet model.
|
|
||||||
latents: Current state of latent-space image in denoising process.
|
|
||||||
None until `pre_denoise_loop` callback.
|
|
||||||
Shape: [batch, channels, latent_height, latent_width]
|
|
||||||
step_index: Current denoising step index.
|
|
||||||
None until `pre_step` callback.
|
|
||||||
timestep: Current denoising step timestep.
|
|
||||||
None until `pre_step` callback.
|
|
||||||
unet_kwargs: Arguments which will be passed to U Net model.
|
|
||||||
Available in `pre_unet`/`post_unet` callbacks, otherwice will be None.
|
|
||||||
step_output: SchedulerOutput class returned from step function(normally, generated by scheduler).
|
|
||||||
Supposed to be used only in `post_step` callback, otherwice can be None.
|
|
||||||
latent_model_input: Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
|
||||||
Available in events inside step(between `pre_step` and `post_stop`).
|
|
||||||
Shape: [batch, channels, latent_height, latent_width]
|
|
||||||
conditioning_mode: [TMP] Defines on which conditionings current unet call will be runned.
|
|
||||||
Available in `pre_unet`/`post_unet` callbacks, otherwice will be None.
|
|
||||||
Can be "negative", "positive" or "both"
|
|
||||||
negative_noise_pred: [TMP] Noise predictions from negative conditioning.
|
|
||||||
Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
|
|
||||||
Shape: [batch, channels, latent_height, latent_width]
|
|
||||||
positive_noise_pred: [TMP] Noise predictions from positive conditioning.
|
|
||||||
Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
|
|
||||||
Shape: [batch, channels, latent_height, latent_width]
|
|
||||||
noise_pred: Combined noise prediction from passed conditionings.
|
|
||||||
Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
|
|
||||||
Shape: [batch, channels, latent_height, latent_width]
|
|
||||||
extra: Dictionary for extensions to pass extra info about denoise process to other extensions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# Initial variables passed to denoise. Supposed to be unchanged.
|
||||||
inputs: DenoiseInputs
|
inputs: DenoiseInputs
|
||||||
|
|
||||||
|
# Scheduler used to apply noise predictions.
|
||||||
scheduler: SchedulerMixin
|
scheduler: SchedulerMixin
|
||||||
|
|
||||||
|
# UNet model.
|
||||||
unet: Optional[UNet2DConditionModel] = None
|
unet: Optional[UNet2DConditionModel] = None
|
||||||
|
|
||||||
|
# Current state of latent-space image in denoising process.
|
||||||
|
# None until `pre_denoise_loop` callback.
|
||||||
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
latents: Optional[torch.Tensor] = None
|
latents: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# Current denoising step index.
|
||||||
|
# None until `pre_step` callback.
|
||||||
step_index: Optional[int] = None
|
step_index: Optional[int] = None
|
||||||
|
|
||||||
|
# Current denoising step timestep.
|
||||||
|
# None until `pre_step` callback.
|
||||||
timestep: Optional[torch.Tensor] = None
|
timestep: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# Arguments which will be passed to UNet model.
|
||||||
|
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||||
unet_kwargs: Optional[UNetKwargs] = None
|
unet_kwargs: Optional[UNetKwargs] = None
|
||||||
|
|
||||||
|
# SchedulerOutput returned from the step function (normally generated by the scheduler).
|
||||||
|
# Supposed to be used only in `post_step` callback, otherwise can be None.
|
||||||
step_output: Optional[SchedulerOutput] = None
|
step_output: Optional[SchedulerOutput] = None
|
||||||
|
|
||||||
|
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||||
|
# Available in events inside a step (between `pre_step` and `post_step`).
|
||||||
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
latent_model_input: Optional[torch.Tensor] = None
|
latent_model_input: Optional[torch.Tensor] = None
|
||||||
conditioning_mode: Optional[str] = None
|
|
||||||
|
# [TMP] Defines which conditionings the current unet call will be run on.
|
||||||
|
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||||
|
conditioning_mode: Optional[ConditioningMode] = None
|
||||||
|
|
||||||
|
# [TMP] Noise predictions from negative conditioning.
|
||||||
|
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||||
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
negative_noise_pred: Optional[torch.Tensor] = None
|
negative_noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# [TMP] Noise predictions from positive conditioning.
|
||||||
|
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||||
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
positive_noise_pred: Optional[torch.Tensor] = None
|
positive_noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# Combined noise prediction from passed conditionings.
|
||||||
|
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||||
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
noise_pred: Optional[torch.Tensor] = None
|
noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# Dictionary for extensions to pass extra info about denoise process to other extensions.
|
||||||
extra: dict = field(default_factory=dict)
|
extra: dict = field(default_factory=dict)
|
||||||
|
@ -137,6 +137,12 @@ class TextConditioningData:
|
|||||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
|
|
||||||
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
|
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
|
||||||
|
"""Fills unet arguments with data from provided conditionings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
|
||||||
|
conditioning_mode (ConditioningMode): Describes which conditionings should be used.
|
||||||
|
"""
|
||||||
_, _, h, w = unet_kwargs.sample.shape
|
_, _, h, w = unet_kwargs.sample.shape
|
||||||
device = unet_kwargs.sample.device
|
device = unet_kwargs.sample.device
|
||||||
dtype = unet_kwargs.sample.dtype
|
dtype = unet_kwargs.sample.dtype
|
||||||
@ -187,7 +193,7 @@ class TextConditioningData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
|
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
|
||||||
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
|
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -195,8 +201,13 @@ class TextConditioningData:
|
|||||||
cls,
|
cls,
|
||||||
cond: torch.Tensor,
|
cond: torch.Tensor,
|
||||||
target_len: int,
|
target_len: int,
|
||||||
encoder_attention_mask: Optional[torch.Tensor],
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
):
|
"""Pad the provided conditioning tensor with zeros to target_len and return a mask of the unpadded tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cond (torch.Tensor): Conditioning tensor to pad with zeros.
|
||||||
|
target_len (int): Length (token count) to pad the tensor to.
|
||||||
|
"""
|
||||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||||
|
|
||||||
if cond.shape[1] < target_len:
|
if cond.shape[1] < target_len:
|
||||||
@ -212,21 +223,28 @@ class TextConditioningData:
|
|||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
if encoder_attention_mask is None:
|
return cond, conditioning_attention_mask
|
||||||
encoder_attention_mask = conditioning_attention_mask
|
|
||||||
else:
|
|
||||||
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
|
|
||||||
|
|
||||||
return cond, encoder_attention_mask
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
|
def _concat_conditionings_for_batch(
|
||||||
|
cls,
|
||||||
|
conditionings: List[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""Concatenate provided conditioning tensors to one batched tensor.
|
||||||
|
If the tensors have different sizes, pad them with zeros and create an
|
||||||
|
encoder_attention_mask to exclude padding from attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
|
||||||
|
"""
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
max_len = max([c.shape[1] for c in conditionings])
|
max_len = max([c.shape[1] for c in conditionings])
|
||||||
if any(c.shape[1] != max_len for c in conditionings):
|
if any(c.shape[1] != max_len for c in conditionings):
|
||||||
|
encoder_attention_masks = [None] * len(conditionings)
|
||||||
for i in range(len(conditionings)):
|
for i in range(len(conditionings)):
|
||||||
conditionings[i], encoder_attention_mask = cls._pad_conditioning(
|
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(
|
||||||
conditionings[i], max_len, encoder_attention_mask
|
conditionings[i], max_len
|
||||||
)
|
)
|
||||||
|
encoder_attention_mask = torch.cat(encoder_attention_masks)
|
||||||
|
|
||||||
return torch.cat(conditionings), encoder_attention_mask
|
return torch.cat(conditionings), encoder_attention_mask
|
||||||
|
Loading…
Reference in New Issue
Block a user