Add CFG Rescale option for supporting zero-terminal SNR models (#4335)

* add support for CFG rescale

* fix typo

* move rescale position and tweak docs

* move input position

* implement suggestions from github and discord

* cleanup unused code

* add back dropped FieldDescription

* fix(ui): revert unrelated UI changes

* chore(nodes): bump denoise_latents version 1.4.0 -> 1.5.0

* feat(nodes): add cfg_rescale_multiplier to metadata node

* feat(ui): add cfg rescale multiplier to linear UI

- add param to state
- update graph builders
- add UI under advanced
- add metadata handling & recall
- regen types

* chore: black

* fix(backend): make `StableDiffusionGeneratorPipeline._rescale_cfg()` staticmethod

This doesn't need access to class.

* feat(backend): add docstring for `_rescale_cfg()` method

* feat(ui): update cfg rescale mult translation string

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Damian Stewart
2023-11-30 10:55:20 +01:00
committed by GitHub
parent 693c6cf5e4
commit 0beb08686c
23 changed files with 249 additions and 34 deletions

View File

@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[step_index]
noise_pred = self.invokeai_diffuser._combine(
uc_noise_pred,
c_noise_pred,
guidance_scale,
)
noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
if guidance_rescale_multiplier > 0:
noise_pred = self._rescale_cfg(
noise_pred,
c_noise_pred,
guidance_rescale_multiplier,
)
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output
@staticmethod
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final
def _unet_forward(
self,
latents,

View File

@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""