diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py
index bcebb2945e..2b43d3fb0f 100644
--- a/invokeai/backend/stable_diffusion/denoise_context.py
+++ b/invokeai/backend/stable_diffusion/denoise_context.py
@@ -8,7 +8,7 @@ from diffusers import UNet2DConditionModel
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 
 if TYPE_CHECKING:
-    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
 
 
 @dataclass
@@ -31,92 +31,101 @@ class UNetKwargs:
 
 @dataclass
 class DenoiseInputs:
-    """Initial variables passed to denoise. Supposed to be unchanged.
-
-    Variables:
-        orig_latents: The latent-space image to denoise.
-            Shape: [batch, channels, latent_height, latent_width]
-            - If we are inpainting, this is the initial latent image before noise has been added.
-            - If we are generating a new image, this should be initialized to zeros.
-            - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
-        scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
-        conditioning_data: Text conditionging data.
-        noise: Noise used for two purposes:
-            Shape: [1 or batch, channels, latent_height, latent_width]
-            1. Used by the scheduler to noise the initial `latents` before denoising.
-            2. Used to noise the `masked_latents` when inpainting.
-            `noise` should be None if the `latents` tensor has already been noised.
-        seed: The seed used to generate the noise for the denoising process.
-            HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
-            same noise used earlier in the pipeline. This should really be handled in a clearer way.
-        timesteps: The timestep schedule for the denoising process.
-        init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
-            should be populated if you want noise applied *even* if timesteps is empty.
-        attention_processor_cls: Class of attention processor that is used.
-    """
+    """Initial variables passed to denoise. Supposed to be unchanged."""
+
+    # The latent-space image to denoise.
+    # Shape: [batch, channels, latent_height, latent_width]
+    # - If we are inpainting, this is the initial latent image before noise has been added.
+    # - If we are generating a new image, this should be initialized to zeros.
+    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
     orig_latents: torch.Tensor
+
+    # kwargs forwarded to the scheduler.step() method.
     scheduler_step_kwargs: dict[str, Any]
+
+    # Text conditioning data.
     conditioning_data: TextConditioningData
+
+    # Noise used for two purposes:
+    # 1. Used by the scheduler to noise the initial `latents` before denoising.
+    # 2. Used to noise the `masked_latents` when inpainting.
+    # `noise` should be None if the `latents` tensor has already been noised.
+    # Shape: [1 or batch, channels, latent_height, latent_width]
     noise: Optional[torch.Tensor]
+
+    # The seed used to generate the noise for the denoising process.
+    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
+    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
     seed: int
+
+    # The timestep schedule for the denoising process.
     timesteps: torch.Tensor
+
+    # The first timestep in the schedule. This is used to determine the initial noise level, so
+    # should be populated if you want noise applied *even* if timesteps is empty.
     init_timestep: torch.Tensor
+
+    # Class of attention processor that is used.
     attention_processor_cls: Type[Any]
 
 
 @dataclass
 class DenoiseContext:
-    """Context with all variables in denoise
-
-    Variables:
-        inputs: Initial variables passed to denoise. Supposed to be unchanged.
-        scheduler: Scheduler which used to apply noise predictions.
-        unet: UNet model.
-        latents: Current state of latent-space image in denoising process.
-            None until `pre_denoise_loop` callback.
-            Shape: [batch, channels, latent_height, latent_width]
-        step_index: Current denoising step index.
-            None until `pre_step` callback.
-        timestep: Current denoising step timestep.
-            None until `pre_step` callback.
-        unet_kwargs: Arguments which will be passed to U Net model.
-            Available in `pre_unet`/`post_unet` callbacks, otherwice will be None.
-        step_output: SchedulerOutput class returned from step function(normally, generated by scheduler).
-            Supposed to be used only in `post_step` callback, otherwice can be None.
-        latent_model_input: Scaled version of `latents`, which will be passed to unet_kwargs initialization.
-            Available in events inside step(between `pre_step` and `post_stop`).
-            Shape: [batch, channels, latent_height, latent_width]
-        conditioning_mode: [TMP] Defines on which conditionings current unet call will be runned.
-            Available in `pre_unet`/`post_unet` callbacks, otherwice will be None.
-            Can be "negative", "positive" or "both"
-        negative_noise_pred: [TMP] Noise predictions from negative conditioning.
-            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
-            Shape: [batch, channels, latent_height, latent_width]
-        positive_noise_pred: [TMP] Noise predictions from positive conditioning.
-            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
-            Shape: [batch, channels, latent_height, latent_width]
-        noise_pred: Combined noise prediction from passed conditionings.
-            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwice will be None.
-            Shape: [batch, channels, latent_height, latent_width]
-        extra: Dictionary for extensions to pass extra info about denoise process to other extensions.
-    """
+    """Context with all variables in denoise"""
+
+    # Initial variables passed to denoise. Supposed to be unchanged.
     inputs: DenoiseInputs
+
+    # Scheduler which is used to apply noise predictions.
     scheduler: SchedulerMixin
+
+    # UNet model.
     unet: Optional[UNet2DConditionModel] = None
+
+    # Current state of the latent-space image in the denoising process.
+    # None until the `pre_denoise_loop` callback.
+    # Shape: [batch, channels, latent_height, latent_width]
     latents: Optional[torch.Tensor] = None
+
+    # Current denoising step index.
+    # None until the `pre_step` callback.
     step_index: Optional[int] = None
+
+    # Current denoising step timestep.
+    # None until the `pre_step` callback.
     timestep: Optional[torch.Tensor] = None
+
+    # Arguments which will be passed to the UNet model.
+    # Available in the `pre_unet`/`post_unet` callbacks, otherwise None.
     unet_kwargs: Optional[UNetKwargs] = None
+
+    # SchedulerOutput returned from the step function (normally generated by the scheduler).
+    # Supposed to be used only in the `post_step` callback, otherwise can be None.
     step_output: Optional[SchedulerOutput] = None
+
+    # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
+    # Available in events inside step (between `pre_step` and `post_step`).
+    # Shape: [batch, channels, latent_height, latent_width]
     latent_model_input: Optional[torch.Tensor] = None
-    conditioning_mode: Optional[str] = None
+
+    # [TMP] Defines which conditionings the current UNet call will be run on.
+    # Available in the `pre_unet`/`post_unet` callbacks, otherwise None.
+    conditioning_mode: Optional[ConditioningMode] = None
+
+    # [TMP] Noise predictions from negative conditioning.
+    # Available in the `apply_cfg` and `post_apply_cfg` callbacks, otherwise None.
+    # Shape: [batch, channels, latent_height, latent_width]
    negative_noise_pred: Optional[torch.Tensor] = None
+
+    # [TMP] Noise predictions from positive conditioning.
+    # Available in the `apply_cfg` and `post_apply_cfg` callbacks, otherwise None.
+    # Shape: [batch, channels, latent_height, latent_width]
     positive_noise_pred: Optional[torch.Tensor] = None
+
+    # Combined noise prediction from the passed conditionings.
+    # Available in the `apply_cfg` and `post_apply_cfg` callbacks, otherwise None.
+    # Shape: [batch, channels, latent_height, latent_width]
     noise_pred: Optional[torch.Tensor] = None
+
+    # Dictionary for extensions to pass extra info about the denoise process to other extensions.
     extra: dict = field(default_factory=dict)
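Aside: the context fields above are only populated during specific phases of the denoise loop, and `conditioning_mode` is now a `ConditioningMode` value rather than the old raw "negative"/"positive"/"both" string. Below is a minimal sketch of a callback consuming these fields; the `log_unet_call` function and the idea of wiring it up at `pre_unet` time are hypothetical illustrations, while the import path and the `DenoiseContext` fields come from the diff itself.

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


def log_unet_call(ctx: DenoiseContext) -> None:
    # Per the field comments, `unet_kwargs` and `conditioning_mode` are only
    # populated between the `pre_unet` and `post_unet` callbacks, so guard
    # against None when reading them.
    if ctx.unet_kwargs is None or ctx.conditioning_mode is None:
        return
    print(f"step {ctx.step_index}: UNet call with conditioning_mode={ctx.conditioning_mode}")
    # `extra` is the designated place for extensions to share state with each other.
    ctx.extra["last_conditioning_mode"] = ctx.conditioning_mode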
+ """ conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) if cond.shape[1] < target_len: @@ -212,21 +223,28 @@ class TextConditioningData: dim=1, ) - if encoder_attention_mask is None: - encoder_attention_mask = conditioning_attention_mask - else: - encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask]) - - return cond, encoder_attention_mask + return cond, conditioning_attention_mask @classmethod - def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]): + def _concat_conditionings_for_batch( + cls, + conditionings: List[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Concatenate provided conditioning tensors to one batched tensor. + If tensors have different sizes then pad them by zeros and creates + encoder_attention_mask to exclude padding from attention. + + Args: + conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate. + """ encoder_attention_mask = None max_len = max([c.shape[1] for c in conditionings]) if any(c.shape[1] != max_len for c in conditionings): + encoder_attention_masks = [None] * len(conditionings) for i in range(len(conditionings)): - conditionings[i], encoder_attention_mask = cls._pad_conditioning( - conditionings[i], max_len, encoder_attention_mask + conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning( + conditionings[i], max_len ) + encoder_attention_mask = torch.cat(encoder_attention_masks) return torch.cat(conditionings), encoder_attention_mask