Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-22 22:17:29 +03:00
parent 9a1420280e
commit 1b359b55cb
2 changed files with 13 additions and 13 deletions

View File

@ -83,47 +83,47 @@ class DenoiseContext:
unet: Optional[UNet2DConditionModel] = None unet: Optional[UNet2DConditionModel] = None
# Current state of latent-space image in denoising process. # Current state of latent-space image in denoising process.
# None until `pre_denoise_loop` callback. # None until `PRE_DENOISE_LOOP` callback.
# Shape: [batch, channels, latent_height, latent_width] # Shape: [batch, channels, latent_height, latent_width]
latents: Optional[torch.Tensor] = None latents: Optional[torch.Tensor] = None
# Current denoising step index. # Current denoising step index.
# None until `pre_step` callback. # None until `PRE_STEP` callback.
step_index: Optional[int] = None step_index: Optional[int] = None
# Current denoising step timestep. # Current denoising step timestep.
# None until `pre_step` callback. # None until `PRE_STEP` callback.
timestep: Optional[torch.Tensor] = None timestep: Optional[torch.Tensor] = None
# Arguments which will be passed to UNet model. # Arguments which will be passed to UNet model.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
unet_kwargs: Optional[UNetKwargs] = None unet_kwargs: Optional[UNetKwargs] = None
# SchedulerOutput class returned from step function(normally, generated by scheduler). # SchedulerOutput class returned from step function(normally, generated by scheduler).
# Supposed to be used only in `post_step` callback, otherwise can be None. # Supposed to be used only in `POST_STEP` callback, otherwise can be None.
step_output: Optional[SchedulerOutput] = None step_output: Optional[SchedulerOutput] = None
# Scaled version of `latents`, which will be passed to unet_kwargs initialization. # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
# Available in events inside step(between `pre_step` and `post_stop`). # Available in events inside step(between `PRE_STEP` and `POST_STEP`).
# Shape: [batch, channels, latent_height, latent_width] # Shape: [batch, channels, latent_height, latent_width]
latent_model_input: Optional[torch.Tensor] = None latent_model_input: Optional[torch.Tensor] = None
# [TMP] Defines on which conditionings current unet call will be runned. # [TMP] Defines on which conditionings current unet call will be runned.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
conditioning_mode: Optional[ConditioningMode] = None conditioning_mode: Optional[ConditioningMode] = None
# [TMP] Noise predictions from negative conditioning. # [TMP] Noise predictions from negative conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width] # Shape: [batch, channels, latent_height, latent_width]
negative_noise_pred: Optional[torch.Tensor] = None negative_noise_pred: Optional[torch.Tensor] = None
# [TMP] Noise predictions from positive conditioning. # [TMP] Noise predictions from positive conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width] # Shape: [batch, channels, latent_height, latent_width]
positive_noise_pred: Optional[torch.Tensor] = None positive_noise_pred: Optional[torch.Tensor] = None
# Combined noise prediction from passed conditionings. # Combined noise prediction from passed conditionings.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width] # Shape: [batch, channels, latent_height, latent_width]
noise_pred: Optional[torch.Tensor] = None noise_pred: Optional[torch.Tensor] = None

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
class RescaleCFGExt(ExtensionBase): class RescaleCFGExt(ExtensionBase):
def __init__(self, rescale_multiplier: float): def __init__(self, rescale_multiplier: float):
super().__init__() super().__init__()
self.rescale_multiplier = rescale_multiplier self._rescale_multiplier = rescale_multiplier
@staticmethod @staticmethod
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7): def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
@ -28,9 +28,9 @@ class RescaleCFGExt(ExtensionBase):
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS) @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
def rescale_noise_pred(self, ctx: DenoiseContext): def rescale_noise_pred(self, ctx: DenoiseContext):
if self.rescale_multiplier > 0: if self._rescale_multiplier > 0:
ctx.noise_pred = self._rescale_cfg( ctx.noise_pred = self._rescale_cfg(
ctx.noise_pred, ctx.noise_pred,
ctx.positive_noise_pred, ctx.positive_noise_pred,
self.rescale_multiplier, self._rescale_multiplier,
) )