Add CFG Rescale option for supporting zero-terminal SNR models (#4335)

* add support for CFG rescale

* fix typo

* move rescale position and tweak docs

* move input position

* implement suggestions from github and discord

* cleanup unused code

* add back dropped FieldDescription

* fix(ui): revert unrelated UI changes

* chore(nodes): bump denoise_latents version 1.4.0 -> 1.5.0

* feat(nodes): add cfg_rescale_multiplier to metadata node

* feat(ui): add cfg rescale multiplier to linear UI

- add param to state
- update graph builders
- add UI under advanced
- add metadata handling & recall
- regen types

* chore: black

* fix(backend): make `StableDiffusionGeneratorPipeline._rescale_cfg()` staticmethod

This doesn't need access to class.

* feat(backend): add docstring for `_rescale_cfg()` method

* feat(ui): update cfg rescale mult translation string

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Damian Stewart 2023-11-30 10:55:20 +01:00 committed by GitHub
parent 693c6cf5e4
commit 0beb08686c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 249 additions and 34 deletions

View File

@ -215,7 +215,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.4.0",
version="1.5.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@ -273,6 +273,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=7,
)
cfg_rescale_multiplier: float = InputField(
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
@ -332,6 +335,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,

View File

@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
cfg_rescale_multiplier: Optional[float] = InputField(
default=None, description=FieldDescriptions.cfg_rescale_multiplier
)
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")

View File

@ -2,6 +2,7 @@ class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"

View File

@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[step_index]
noise_pred = self.invokeai_diffuser._combine(
uc_noise_pred,
c_noise_pred,
guidance_scale,
)
noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
if guidance_rescale_multiplier > 0:
noise_pred = self._rescale_cfg(
noise_pred,
c_noise_pred,
guidance_rescale_multiplier,
)
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output
@staticmethod
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final
def _unet_forward(
self,
latents,

View File

@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
guidance_scale: Union[float, List[float]]
"""For models trained using zero-terminal SNR ("ztsnr"), a guidance_rescale_multiplier of 0.7 is suggested.
Ref: [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""

View File

@ -599,6 +599,7 @@
},
"metadata": {
"cfgScale": "CFG scale",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"createdBy": "Created By",
"fit": "Image to image fit",
"generationMode": "Generation Mode",
@ -1032,6 +1033,8 @@
"setType": "Set cancel type"
},
"cfgScale": "CFG Scale",
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
"cfgRescale": "CFG Rescale",
"clipSkip": "CLIP Skip",
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
"closeViewer": "Close Viewer",
@ -1470,6 +1473,12 @@
"Controls how much your prompt influences the generation process."
]
},
"paramCFGRescaleMultiplier": {
"heading": "CFG Rescale Multiplier",
"paragraphs": [
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
]
},
"paramDenoisingStrength": {
"heading": "Denoising Strength",
"paragraphs": [

View File

@ -25,6 +25,7 @@ export type Feature =
| 'lora'
| 'noiseUseCPU'
| 'paramCFGScale'
| 'paramCFGRescaleMultiplier'
| 'paramDenoisingStrength'
| 'paramIterations'
| 'paramModel'

View File

@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
recallNegativePrompt,
recallSeed,
recallCfgScale,
recallCfgRescaleMultiplier,
recallModel,
recallScheduler,
recallVaeModel,
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]);
// Forwards the image metadata's CFG rescale multiplier to the recall handler.
const handleRecallCfgRescaleMultiplier = useCallback(() => {
  const multiplier = metadata?.cfg_rescale_multiplier;
  recallCfgRescaleMultiplier(multiplier);
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallCfgScale}
/>
)}
{metadata.cfg_rescale_multiplier !== undefined &&
metadata.cfg_rescale_multiplier !== null && (
<ImageMetadataItem
label={t('metadata.cfgRescaleMultiplier')}
value={metadata.cfg_rescale_multiplier}
onClick={handleRecallCfgRescaleMultiplier}
/>
)}
{metadata.strength && (
<ImageMetadataItem
label={t('metadata.strength')}

View File

@ -51,6 +51,7 @@ export const zCoreMetadata = z
seed: z.number().int().nullish().catch(null),
rand_device: z.string().nullish().catch(null),
cfg_scale: z.number().nullish().catch(null),
cfg_rescale_multiplier: z.number().nullish().catch(null),
steps: z.number().int().nullish().catch(null),
scheduler: z.string().nullish().catch(null),
clip_skip: z.number().int().nullish().catch(null),

View File

@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
{
generation_mode: 'img2img',
cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,

View File

@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
{
generation_mode: 'img2img',
cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,

View File

@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
{
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,

View File

@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
{
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,

View File

@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
{
generation_mode: 'img2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,

View File

@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
{
generation_mode: 'sdxl_img2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,

View File

@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
{
generation_mode: 'sdxl_txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,

View File

@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
negativePrompt,
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
steps,
width,
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
id: DENOISE_LATENTS,
is_intermediate,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
{
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,

View File

@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
import ParamClipSkip from './ParamClipSkip';
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
const selector = createSelector(
stateSelector,
(state: RootState) => {
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
state.generation;
const {
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
} = state.generation;
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
return {
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
};
},
defaultSelectorOptions
);
export default function ParamAdvancedCollapse() {
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
useAppSelector(selector);
const {
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
} = useAppSelector(selector);
const { t } = useTranslation();
const activeLabel = useMemo(() => {
const activeLabel: string[] = [];
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
activeLabel.push(t('parameters.seamlessY'));
}
if (cfgRescaleMultiplier) {
activeLabel.push(t('parameters.cfgRescale'));
}
return activeLabel.join(', ');
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
}, [
cfgRescaleMultiplier,
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
t,
]);
return (
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
</>
)}
<ParamCpuNoiseToggle />
<Divider />
<ParamCFGRescaleMultiplier />
</Flex>
</IAICollapse>
);

View File

@ -0,0 +1,60 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider';
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
/**
 * Picks `cfgRescaleMultiplier` from the generation slice and the `shift`
 * flag from the hotkeys slice.
 */
const selector = createSelector(
  [stateSelector],
  (state) => ({
    cfgRescaleMultiplier: state.generation.cfgRescaleMultiplier,
    shift: state.hotkeys.shift,
  }),
  defaultSelectorOptions
);
/**
 * Slider control for the CFG rescale multiplier held in generation state.
 *
 * Changes are dispatched through `setCfgRescaleMultiplier`; resetting
 * dispatches 0. The step is 0.01 while the `shift` hotkey flag is set and
 * 0.05 otherwise.
 */
const ParamCFGRescaleMultiplier = () => {
  const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const handleChange = useCallback(
    (value: number) => {
      return dispatch(setCfgRescaleMultiplier(value));
    },
    [dispatch]
  );

  const handleReset = useCallback(() => {
    return dispatch(setCfgRescaleMultiplier(0));
  }, [dispatch]);

  // Finer step while the shift flag is set.
  const step = shift ? 0.01 : 0.05;

  return (
    <IAIInformationalPopover feature="paramCFGRescaleMultiplier">
      <IAISlider
        label={t('parameters.cfgRescaleMultiplier')}
        value={cfgRescaleMultiplier}
        min={0}
        max={0.99}
        step={step}
        sliderNumberInputProps={{ max: 0.99 }}
        onChange={handleChange}
        handleReset={handleReset}
        isInteger={false}
        withInput
        withReset
        withSliderMarks
      />
    </IAIInformationalPopover>
  );
};
export default memo(ParamCFGRescaleMultiplier);

View File

@ -57,6 +57,7 @@ import {
modelSelected,
} from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
setHeight,
setHrfEnabled,
@ -94,6 +95,7 @@ import {
isParameterStrength,
isParameterVAEModel,
isParameterWidth,
isParameterCFGRescaleMultiplier,
} from 'features/parameters/types/parameterSchemas';
const selector = createSelector(
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
 * Recall CFG rescale multiplier with toast.
 *
 * A value that passes `isParameterCFGRescaleMultiplier` is dispatched and
 * followed by the "set" toast; anything else only triggers the "not set" toast.
 */
const recallCfgRescaleMultiplier = useCallback(
  (candidate: unknown) => {
    if (isParameterCFGRescaleMultiplier(candidate)) {
      dispatch(setCfgRescaleMultiplier(candidate));
      parameterSetToast();
    } else {
      parameterNotSetToast();
    }
  },
  [dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall model with toast
*/
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
const {
cfg_scale,
cfg_rescale_multiplier,
height,
model,
positive_prompt,
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
dispatch(setCfgScale(cfg_scale));
}
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
if (isParameterModel(model)) {
dispatch(modelSelected(model));
}
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
recallSDXLNegativeStylePrompt,
recallSeed,
recallCfgScale,
recallCfgRescaleMultiplier,
recallModel,
recallScheduler,
recallVaeModel,

View File

@ -24,6 +24,7 @@ import {
ParameterVAEModel,
ParameterWidth,
zParameterModel,
ParameterCFGRescaleMultiplier,
} from 'features/parameters/types/parameterSchemas';
export interface GenerationState {
@ -31,6 +32,7 @@ export interface GenerationState {
hrfStrength: ParameterStrength;
hrfMethod: ParameterHRFMethod;
cfgScale: ParameterCFGScale;
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
height: ParameterHeight;
img2imgStrength: ParameterStrength;
infillMethod: string;
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
hrfEnabled: false,
hrfMethod: 'ESRGAN',
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
height: 512,
img2imgStrength: 0.75,
infillMethod: 'patchmatch',
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
state.steps
);
},
setCfgScale: (state, action: PayloadAction<number>) => {
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
state.cfgScale = action.payload;
},
setCfgRescaleMultiplier: (
state,
action: PayloadAction<ParameterCFGRescaleMultiplier>
) => {
state.cfgRescaleMultiplier = action.payload;
},
setThreshold: (state, action: PayloadAction<number>) => {
state.threshold = action.payload;
},
@ -336,6 +345,7 @@ export const {
resetParametersState,
resetSeed,
setCfgScale,
setCfgRescaleMultiplier,
setWidth,
setHeight,
toggleSize,

View File

@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
zParameterCFGScale.safeParse(val).success;
// #endregion
// #region CFG Rescale Multiplier
/** Schema for the CFG rescale multiplier: a number in the half-open range [0, 1). */
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);

/** Static type inferred from `zParameterCFGRescaleMultiplier`. */
export type ParameterCFGRescaleMultiplier = z.infer<
  typeof zParameterCFGRescaleMultiplier
>;

/** Type guard: true when `val` parses successfully against the schema. */
export const isParameterCFGRescaleMultiplier = (
  val: unknown
): val is ParameterCFGRescaleMultiplier => {
  const { success } = zParameterCFGRescaleMultiplier.safeParse(val);
  return success;
};
// #endregion
// #region Scheduler
export const zParameterScheduler = zSchedulerField;
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;

File diff suppressed because one or more lines are too long