mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add CFG Rescale option for supporting zero-terminal SNR models (#4335)
* add support for CFG rescale * fix typo * move rescale position and tweak docs * move input position * implement suggestions from github and discord * cleanup unused code * add back dropped FieldDescription * fix(ui): revert unrelated UI changes * chore(nodes): bump denoise_latents version 1.4.0 -> 1.5.0 * feat(nodes): add cfg_rescale_multiplier to metadata node * feat(ui): add cfg rescale multiplier to linear UI - add param to state - update graph builders - add UI under advanced - add metadata handling & recall - regen types * chore: black * fix(backend): make `StableDiffusionGeneratorPipeline._rescale_cfg()` staticmethod This doesn't need access to class. * feat(backend): add docstring for `_rescale_cfg()` method * feat(ui): update cfg rescale mult translation string --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
@ -599,6 +599,7 @@
|
||||
},
|
||||
"metadata": {
|
||||
"cfgScale": "CFG scale",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"createdBy": "Created By",
|
||||
"fit": "Image to image fit",
|
||||
"generationMode": "Generation Mode",
|
||||
@ -1032,6 +1033,8 @@
|
||||
"setType": "Set cancel type"
|
||||
},
|
||||
"cfgScale": "CFG Scale",
|
||||
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
|
||||
"cfgRescale": "CFG Rescale",
|
||||
"clipSkip": "CLIP Skip",
|
||||
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
||||
"closeViewer": "Close Viewer",
|
||||
@ -1470,6 +1473,12 @@
|
||||
"Controls how much your prompt influences the generation process."
|
||||
]
|
||||
},
|
||||
"paramCFGRescaleMultiplier": {
|
||||
"heading": "CFG Rescale Multiplier",
|
||||
"paragraphs": [
|
||||
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
|
||||
]
|
||||
},
|
||||
"paramDenoisingStrength": {
|
||||
"heading": "Denoising Strength",
|
||||
"paragraphs": [
|
||||
|
@ -25,6 +25,7 @@ export type Feature =
|
||||
| 'lora'
|
||||
| 'noiseUseCPU'
|
||||
| 'paramCFGScale'
|
||||
| 'paramCFGRescaleMultiplier'
|
||||
| 'paramDenoisingStrength'
|
||||
| 'paramIterations'
|
||||
| 'paramModel'
|
||||
|
@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallNegativePrompt,
|
||||
recallSeed,
|
||||
recallCfgScale,
|
||||
recallCfgRescaleMultiplier,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallVaeModel,
|
||||
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallCfgScale(metadata?.cfg_scale);
|
||||
}, [metadata?.cfg_scale, recallCfgScale]);
|
||||
|
||||
const handleRecallCfgRescaleMultiplier = useCallback(() => {
|
||||
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
|
||||
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
|
||||
|
||||
const handleRecallStrength = useCallback(() => {
|
||||
recallStrength(metadata?.strength);
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallCfgScale}
|
||||
/>
|
||||
)}
|
||||
{metadata.cfg_rescale_multiplier !== undefined &&
|
||||
metadata.cfg_rescale_multiplier !== null && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.cfgRescaleMultiplier')}
|
||||
value={metadata.cfg_rescale_multiplier}
|
||||
onClick={handleRecallCfgRescaleMultiplier}
|
||||
/>
|
||||
)}
|
||||
{metadata.strength && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.strength')}
|
||||
|
@ -51,6 +51,7 @@ export const zCoreMetadata = z
|
||||
seed: z.number().int().nullish().catch(null),
|
||||
rand_device: z.string().nullish().catch(null),
|
||||
cfg_scale: z.number().nullish().catch(null),
|
||||
cfg_rescale_multiplier: z.number().nullish().catch(null),
|
||||
steps: z.number().int().nullish().catch(null),
|
||||
scheduler: z.string().nullish().catch(null),
|
||||
clip_skip: z.number().int().nullish().catch(null),
|
||||
|
@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'sdxl_img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'sdxl_txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
width,
|
||||
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
|
||||
id: DENOISE_LATENTS,
|
||||
is_intermediate,
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: 0,
|
||||
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
|
||||
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
|
||||
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
|
||||
import ParamClipSkip from './ParamClipSkip';
|
||||
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state: RootState) => {
|
||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
state.generation;
|
||||
const {
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
cfgRescaleMultiplier,
|
||||
} = state.generation;
|
||||
|
||||
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
|
||||
return {
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
cfgRescaleMultiplier,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
export default function ParamAdvancedCollapse() {
|
||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
useAppSelector(selector);
|
||||
const {
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
cfgRescaleMultiplier,
|
||||
} = useAppSelector(selector);
|
||||
const { t } = useTranslation();
|
||||
const activeLabel = useMemo(() => {
|
||||
const activeLabel: string[] = [];
|
||||
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
|
||||
activeLabel.push(t('parameters.seamlessY'));
|
||||
}
|
||||
|
||||
if (cfgRescaleMultiplier) {
|
||||
activeLabel.push(t('parameters.cfgRescale'));
|
||||
}
|
||||
|
||||
return activeLabel.join(', ');
|
||||
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
|
||||
}, [
|
||||
cfgRescaleMultiplier,
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
t,
|
||||
]);
|
||||
|
||||
return (
|
||||
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
||||
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
|
||||
</>
|
||||
)}
|
||||
<ParamCpuNoiseToggle />
|
||||
<Divider />
|
||||
<ParamCFGRescaleMultiplier />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -0,0 +1,60 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ generation, hotkeys }) => {
|
||||
const { cfgRescaleMultiplier } = generation;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
cfgRescaleMultiplier,
|
||||
shift,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamCFGRescaleMultiplier = () => {
|
||||
const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setCfgRescaleMultiplier(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setCfgRescaleMultiplier(0)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIInformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<IAISlider
|
||||
label={t('parameters.cfgRescaleMultiplier')}
|
||||
step={shift ? 0.01 : 0.05}
|
||||
min={0}
|
||||
max={0.99}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={cfgRescaleMultiplier}
|
||||
sliderNumberInputProps={{ max: 0.99 }}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamCFGRescaleMultiplier);
|
@ -57,6 +57,7 @@ import {
|
||||
modelSelected,
|
||||
} from 'features/parameters/store/actions';
|
||||
import {
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setHrfEnabled,
|
||||
@ -94,6 +95,7 @@ import {
|
||||
isParameterStrength,
|
||||
isParameterVAEModel,
|
||||
isParameterWidth,
|
||||
isParameterCFGRescaleMultiplier,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall CFG rescale multiplier with toast
|
||||
*/
|
||||
const recallCfgRescaleMultiplier = useCallback(
|
||||
(cfgRescaleMultiplier: unknown) => {
|
||||
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall model with toast
|
||||
*/
|
||||
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
model,
|
||||
positive_prompt,
|
||||
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
|
||||
if (isParameterModel(model)) {
|
||||
dispatch(modelSelected(model));
|
||||
}
|
||||
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
|
||||
recallSDXLNegativeStylePrompt,
|
||||
recallSeed,
|
||||
recallCfgScale,
|
||||
recallCfgRescaleMultiplier,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallVaeModel,
|
||||
|
@ -24,6 +24,7 @@ import {
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
zParameterModel,
|
||||
ParameterCFGRescaleMultiplier,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export interface GenerationState {
|
||||
@ -31,6 +32,7 @@ export interface GenerationState {
|
||||
hrfStrength: ParameterStrength;
|
||||
hrfMethod: ParameterHRFMethod;
|
||||
cfgScale: ParameterCFGScale;
|
||||
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||
height: ParameterHeight;
|
||||
img2imgStrength: ParameterStrength;
|
||||
infillMethod: string;
|
||||
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
|
||||
hrfEnabled: false,
|
||||
hrfMethod: 'ESRGAN',
|
||||
cfgScale: 7.5,
|
||||
cfgRescaleMultiplier: 0,
|
||||
height: 512,
|
||||
img2imgStrength: 0.75,
|
||||
infillMethod: 'patchmatch',
|
||||
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
|
||||
state.steps
|
||||
);
|
||||
},
|
||||
setCfgScale: (state, action: PayloadAction<number>) => {
|
||||
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
|
||||
state.cfgScale = action.payload;
|
||||
},
|
||||
setCfgRescaleMultiplier: (
|
||||
state,
|
||||
action: PayloadAction<ParameterCFGRescaleMultiplier>
|
||||
) => {
|
||||
state.cfgRescaleMultiplier = action.payload;
|
||||
},
|
||||
setThreshold: (state, action: PayloadAction<number>) => {
|
||||
state.threshold = action.payload;
|
||||
},
|
||||
@ -336,6 +345,7 @@ export const {
|
||||
resetParametersState,
|
||||
resetSeed,
|
||||
setCfgScale,
|
||||
setCfgRescaleMultiplier,
|
||||
setWidth,
|
||||
setHeight,
|
||||
toggleSize,
|
||||
|
@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
||||
zParameterCFGScale.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region CFG Rescale Multiplier
|
||||
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||
export type ParameterCFGRescaleMultiplier = z.infer<
|
||||
typeof zParameterCFGRescaleMultiplier
|
||||
>;
|
||||
export const isParameterCFGRescaleMultiplier = (
|
||||
val: unknown
|
||||
): val is ParameterCFGRescaleMultiplier =>
|
||||
zParameterCFGRescaleMultiplier.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Scheduler
|
||||
export const zParameterScheduler = zSchedulerField;
|
||||
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user