Add CFG Rescale option for supporting zero-terminal SNR models (#4335)

* add support for CFG rescale

* fix typo

* move rescale position and tweak docs

* move input position

* implement suggestions from github and discord

* cleanup unused code

* add back dropped FieldDescription

* fix(ui): revert unrelated UI changes

* chore(nodes): bump denoise_latents version 1.4.0 -> 1.5.0

* feat(nodes): add cfg_rescale_multiplier to metadata node

* feat(ui): add cfg rescale multiplier to linear UI

- add param to state
- update graph builders
- add UI under advanced
- add metadata handling & recall
- regen types

* chore: black

* fix(backend): make `StableDiffusionGeneratorPipeline._rescale_cfg()` staticmethod

This doesn't need access to class.

* feat(backend): add docstring for `_rescale_cfg()` method

* feat(ui): update cfg rescale mult translation string

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Damian Stewart 2023-11-30 10:55:20 +01:00 committed by GitHub
parent 693c6cf5e4
commit 0beb08686c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 249 additions and 34 deletions

View File

@ -215,7 +215,7 @@ def get_scheduler(
title="Denoise Latents", title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents", category="latents",
version="1.4.0", version="1.5.0",
) )
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
@ -273,6 +273,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
ui_order=7, ui_order=7,
) )
cfg_rescale_multiplier: float = InputField(
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField( latents: Optional[LatentsField] = InputField(
default=None, default=None,
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
@ -332,6 +335,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
text_embeddings=c, text_embeddings=c,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info, extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold, threshold=0.0, # threshold,

View File

@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation") seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation") rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter") cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
cfg_rescale_multiplier: Optional[float] = InputField(
default=None, description=FieldDescriptions.cfg_rescale_multiplier
)
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference") steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference") scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis") seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")

View File

@ -2,6 +2,7 @@ class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps" denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps" denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale" cfg_scale = "Classifier-Free Guidance scale"
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
scheduler = "Scheduler to use during inference" scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor" positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor" negative_cond = "Negative conditioning tensor"

View File

@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if isinstance(guidance_scale, list): if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[step_index] guidance_scale = guidance_scale[step_index]
noise_pred = self.invokeai_diffuser._combine( noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
uc_noise_pred, guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
c_noise_pred, if guidance_rescale_multiplier > 0:
guidance_scale, noise_pred = self._rescale_cfg(
) noise_pred,
c_noise_pred,
guidance_rescale_multiplier,
)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output return step_output
@staticmethod
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final
def _unet_forward( def _unet_forward(
self, self,
latents, latents,

View File

@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
class ConditioningData: class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
""" """
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
""" """
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict) scheduler_args: dict[str, Any] = field(default_factory=dict)
""" """

View File

@ -599,6 +599,7 @@
}, },
"metadata": { "metadata": {
"cfgScale": "CFG scale", "cfgScale": "CFG scale",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"createdBy": "Created By", "createdBy": "Created By",
"fit": "Image to image fit", "fit": "Image to image fit",
"generationMode": "Generation Mode", "generationMode": "Generation Mode",
@ -1032,6 +1033,8 @@
"setType": "Set cancel type" "setType": "Set cancel type"
}, },
"cfgScale": "CFG Scale", "cfgScale": "CFG Scale",
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
"cfgRescale": "CFG Rescale",
"clipSkip": "CLIP Skip", "clipSkip": "CLIP Skip",
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}", "clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
"closeViewer": "Close Viewer", "closeViewer": "Close Viewer",
@ -1470,6 +1473,12 @@
"Controls how much your prompt influences the generation process." "Controls how much your prompt influences the generation process."
] ]
}, },
"paramCFGRescaleMultiplier": {
"heading": "CFG Rescale Multiplier",
"paragraphs": [
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
]
},
"paramDenoisingStrength": { "paramDenoisingStrength": {
"heading": "Denoising Strength", "heading": "Denoising Strength",
"paragraphs": [ "paragraphs": [

View File

@ -25,6 +25,7 @@ export type Feature =
| 'lora' | 'lora'
| 'noiseUseCPU' | 'noiseUseCPU'
| 'paramCFGScale' | 'paramCFGScale'
| 'paramCFGRescaleMultiplier'
| 'paramDenoisingStrength' | 'paramDenoisingStrength'
| 'paramIterations' | 'paramIterations'
| 'paramModel' | 'paramModel'

View File

@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
recallNegativePrompt, recallNegativePrompt,
recallSeed, recallSeed,
recallCfgScale, recallCfgScale,
recallCfgRescaleMultiplier,
recallModel, recallModel,
recallScheduler, recallScheduler,
recallVaeModel, recallVaeModel,
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
recallCfgScale(metadata?.cfg_scale); recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]); }, [metadata?.cfg_scale, recallCfgScale]);
const handleRecallCfgRescaleMultiplier = useCallback(() => {
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
const handleRecallStrength = useCallback(() => { const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength); recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]); }, [metadata?.strength, recallStrength]);
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallCfgScale} onClick={handleRecallCfgScale}
/> />
)} )}
{metadata.cfg_rescale_multiplier !== undefined &&
metadata.cfg_rescale_multiplier !== null && (
<ImageMetadataItem
label={t('metadata.cfgRescaleMultiplier')}
value={metadata.cfg_rescale_multiplier}
onClick={handleRecallCfgRescaleMultiplier}
/>
)}
{metadata.strength && ( {metadata.strength && (
<ImageMetadataItem <ImageMetadataItem
label={t('metadata.strength')} label={t('metadata.strength')}

View File

@ -51,6 +51,7 @@ export const zCoreMetadata = z
seed: z.number().int().nullish().catch(null), seed: z.number().int().nullish().catch(null),
rand_device: z.string().nullish().catch(null), rand_device: z.string().nullish().catch(null),
cfg_scale: z.number().nullish().catch(null), cfg_scale: z.number().nullish().catch(null),
cfg_rescale_multiplier: z.number().nullish().catch(null),
steps: z.number().int().nullish().catch(null), steps: z.number().int().nullish().catch(null),
scheduler: z.string().nullish().catch(null), scheduler: z.string().nullish().catch(null),
clip_skip: z.number().int().nullish().catch(null), clip_skip: z.number().int().nullish().catch(null),

View File

@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
{ {
generation_mode: 'img2img', generation_mode: 'img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
{ {
generation_mode: 'img2img', generation_mode: 'img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
{ {
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
{ {
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
{ {
generation_mode: 'img2img', generation_mode: 'img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
{ {
generation_mode: 'sdxl_img2img', generation_mode: 'sdxl_img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
{ {
generation_mode: 'sdxl_txt2img', generation_mode: 'sdxl_txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
steps, steps,
width, width,
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
id: DENOISE_LATENTS, id: DENOISE_LATENTS,
is_intermediate, is_intermediate,
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
scheduler, scheduler,
steps, steps,
denoising_start: 0, denoising_start: 0,
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
{ {
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise'; import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless'; import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
import ParamClipSkip from './ParamClipSkip'; import ParamClipSkip from './ParamClipSkip';
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state: RootState) => { (state: RootState) => {
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } = const {
state.generation; clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
} = state.generation;
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise }; return {
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
};
}, },
defaultSelectorOptions defaultSelectorOptions
); );
export default function ParamAdvancedCollapse() { export default function ParamAdvancedCollapse() {
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } = const {
useAppSelector(selector); clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
} = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
const activeLabel = useMemo(() => { const activeLabel = useMemo(() => {
const activeLabel: string[] = []; const activeLabel: string[] = [];
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
activeLabel.push(t('parameters.seamlessY')); activeLabel.push(t('parameters.seamlessY'));
} }
if (cfgRescaleMultiplier) {
activeLabel.push(t('parameters.cfgRescale'));
}
return activeLabel.join(', '); return activeLabel.join(', ');
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]); }, [
cfgRescaleMultiplier,
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
t,
]);
return ( return (
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}> <IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
</> </>
)} )}
<ParamCpuNoiseToggle /> <ParamCpuNoiseToggle />
<Divider />
<ParamCFGRescaleMultiplier />
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -0,0 +1,60 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider';
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
({ generation, hotkeys }) => {
const { cfgRescaleMultiplier } = generation;
const { shift } = hotkeys;
return {
cfgRescaleMultiplier,
shift,
};
},
defaultSelectorOptions
);
const ParamCFGRescaleMultiplier = () => {
const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setCfgRescaleMultiplier(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setCfgRescaleMultiplier(0)),
[dispatch]
);
return (
<IAIInformationalPopover feature="paramCFGRescaleMultiplier">
<IAISlider
label={t('parameters.cfgRescaleMultiplier')}
step={shift ? 0.01 : 0.05}
min={0}
max={0.99}
onChange={handleChange}
handleReset={handleReset}
value={cfgRescaleMultiplier}
sliderNumberInputProps={{ max: 0.99 }}
withInput
withReset
withSliderMarks
isInteger={false}
/>
</IAIInformationalPopover>
);
};
export default memo(ParamCFGRescaleMultiplier);

View File

@ -57,6 +57,7 @@ import {
modelSelected, modelSelected,
} from 'features/parameters/store/actions'; } from 'features/parameters/store/actions';
import { import {
setCfgRescaleMultiplier,
setCfgScale, setCfgScale,
setHeight, setHeight,
setHrfEnabled, setHrfEnabled,
@ -94,6 +95,7 @@ import {
isParameterStrength, isParameterStrength,
isParameterVAEModel, isParameterVAEModel,
isParameterWidth, isParameterWidth,
isParameterCFGRescaleMultiplier,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
const selector = createSelector( const selector = createSelector(
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall CFG rescale multiplier with toast
*/
const recallCfgRescaleMultiplier = useCallback(
(cfgRescaleMultiplier: unknown) => {
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
parameterNotSetToast();
return;
}
dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall model with toast * Recall model with toast
*/ */
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
const { const {
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
model, model,
positive_prompt, positive_prompt,
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
dispatch(setCfgScale(cfg_scale)); dispatch(setCfgScale(cfg_scale));
} }
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
if (isParameterModel(model)) { if (isParameterModel(model)) {
dispatch(modelSelected(model)); dispatch(modelSelected(model));
} }
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
recallSDXLNegativeStylePrompt, recallSDXLNegativeStylePrompt,
recallSeed, recallSeed,
recallCfgScale, recallCfgScale,
recallCfgRescaleMultiplier,
recallModel, recallModel,
recallScheduler, recallScheduler,
recallVaeModel, recallVaeModel,

View File

@ -24,6 +24,7 @@ import {
ParameterVAEModel, ParameterVAEModel,
ParameterWidth, ParameterWidth,
zParameterModel, zParameterModel,
ParameterCFGRescaleMultiplier,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
export interface GenerationState { export interface GenerationState {
@ -31,6 +32,7 @@ export interface GenerationState {
hrfStrength: ParameterStrength; hrfStrength: ParameterStrength;
hrfMethod: ParameterHRFMethod; hrfMethod: ParameterHRFMethod;
cfgScale: ParameterCFGScale; cfgScale: ParameterCFGScale;
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
height: ParameterHeight; height: ParameterHeight;
img2imgStrength: ParameterStrength; img2imgStrength: ParameterStrength;
infillMethod: string; infillMethod: string;
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
hrfEnabled: false, hrfEnabled: false,
hrfMethod: 'ESRGAN', hrfMethod: 'ESRGAN',
cfgScale: 7.5, cfgScale: 7.5,
cfgRescaleMultiplier: 0,
height: 512, height: 512,
img2imgStrength: 0.75, img2imgStrength: 0.75,
infillMethod: 'patchmatch', infillMethod: 'patchmatch',
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
state.steps state.steps
); );
}, },
setCfgScale: (state, action: PayloadAction<number>) => { setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
state.cfgScale = action.payload; state.cfgScale = action.payload;
}, },
setCfgRescaleMultiplier: (
state,
action: PayloadAction<ParameterCFGRescaleMultiplier>
) => {
state.cfgRescaleMultiplier = action.payload;
},
setThreshold: (state, action: PayloadAction<number>) => { setThreshold: (state, action: PayloadAction<number>) => {
state.threshold = action.payload; state.threshold = action.payload;
}, },
@ -336,6 +345,7 @@ export const {
resetParametersState, resetParametersState,
resetSeed, resetSeed,
setCfgScale, setCfgScale,
setCfgRescaleMultiplier,
setWidth, setWidth,
setHeight, setHeight,
toggleSize, toggleSize,

View File

@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
zParameterCFGScale.safeParse(val).success; zParameterCFGScale.safeParse(val).success;
// #endregion // #endregion
// #region CFG Rescale Multiplier
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
export type ParameterCFGRescaleMultiplier = z.infer<
typeof zParameterCFGRescaleMultiplier
>;
export const isParameterCFGRescaleMultiplier = (
val: unknown
): val is ParameterCFGRescaleMultiplier =>
zParameterCFGRescaleMultiplier.safeParse(val).success;
// #endregion
// #region Scheduler // #region Scheduler
export const zParameterScheduler = zSchedulerField; export const zParameterScheduler = zSchedulerField;
export type ParameterScheduler = z.infer<typeof zParameterScheduler>; export type ParameterScheduler = z.infer<typeof zParameterScheduler>;

File diff suppressed because one or more lines are too long