mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
9a1420280e
commit
1b359b55cb
@ -83,47 +83,47 @@ class DenoiseContext:
|
|||||||
unet: Optional[UNet2DConditionModel] = None
|
unet: Optional[UNet2DConditionModel] = None
|
||||||
|
|
||||||
# Current state of latent-space image in denoising process.
|
# Current state of latent-space image in denoising process.
|
||||||
# None until `pre_denoise_loop` callback.
|
# None until `PRE_DENOISE_LOOP` callback.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
latents: Optional[torch.Tensor] = None
|
latents: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Current denoising step index.
|
# Current denoising step index.
|
||||||
# None until `pre_step` callback.
|
# None until `PRE_STEP` callback.
|
||||||
step_index: Optional[int] = None
|
step_index: Optional[int] = None
|
||||||
|
|
||||||
# Current denoising step timestep.
|
# Current denoising step timestep.
|
||||||
# None until `pre_step` callback.
|
# None until `PRE_STEP` callback.
|
||||||
timestep: Optional[torch.Tensor] = None
|
timestep: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Arguments which will be passed to UNet model.
|
# Arguments which will be passed to UNet model.
|
||||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||||
unet_kwargs: Optional[UNetKwargs] = None
|
unet_kwargs: Optional[UNetKwargs] = None
|
||||||
|
|
||||||
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
||||||
# Supposed to be used only in `post_step` callback, otherwise can be None.
|
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
|
||||||
step_output: Optional[SchedulerOutput] = None
|
step_output: Optional[SchedulerOutput] = None
|
||||||
|
|
||||||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||||
# Available in events inside step(between `pre_step` and `post_stop`).
|
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
latent_model_input: Optional[torch.Tensor] = None
|
latent_model_input: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# [TMP] Defines on which conditionings current unet call will be runned.
|
# [TMP] Defines on which conditionings current unet call will be runned.
|
||||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||||
conditioning_mode: Optional[ConditioningMode] = None
|
conditioning_mode: Optional[ConditioningMode] = None
|
||||||
|
|
||||||
# [TMP] Noise predictions from negative conditioning.
|
# [TMP] Noise predictions from negative conditioning.
|
||||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
negative_noise_pred: Optional[torch.Tensor] = None
|
negative_noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# [TMP] Noise predictions from positive conditioning.
|
# [TMP] Noise predictions from positive conditioning.
|
||||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
positive_noise_pred: Optional[torch.Tensor] = None
|
positive_noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Combined noise prediction from passed conditionings.
|
# Combined noise prediction from passed conditionings.
|
||||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
noise_pred: Optional[torch.Tensor] = None
|
noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
class RescaleCFGExt(ExtensionBase):
|
class RescaleCFGExt(ExtensionBase):
|
||||||
def __init__(self, rescale_multiplier: float):
|
def __init__(self, rescale_multiplier: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rescale_multiplier = rescale_multiplier
|
self._rescale_multiplier = rescale_multiplier
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
|
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
|
||||||
@ -28,9 +28,9 @@ class RescaleCFGExt(ExtensionBase):
|
|||||||
|
|
||||||
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
|
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
|
||||||
def rescale_noise_pred(self, ctx: DenoiseContext):
|
def rescale_noise_pred(self, ctx: DenoiseContext):
|
||||||
if self.rescale_multiplier > 0:
|
if self._rescale_multiplier > 0:
|
||||||
ctx.noise_pred = self._rescale_cfg(
|
ctx.noise_pred = self._rescale_cfg(
|
||||||
ctx.noise_pred,
|
ctx.noise_pred,
|
||||||
ctx.positive_noise_pred,
|
ctx.positive_noise_pred,
|
||||||
self.rescale_multiplier,
|
self._rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user