mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add CFG Rescale option for supporting zero-terminal SNR models (#4335)
* add support for CFG rescale * fix typo * move rescale position and tweak docs * move input position * implement suggestions from github and discord * cleanup unused code * add back dropped FieldDescription * fix(ui): revert unrelated UI changes * chore(nodes): bump denoise_latents version 1.4.0 -> 1.5.0 * feat(nodes): add cfg_rescale_multiplier to metadata node * feat(ui): add cfg rescale multiplier to linear UI - add param to state - update graph builders - add UI under advanced - add metadata handling & recall - regen types * chore: black * fix(backend): make `StableDiffusionGeneratorPipeline._rescale_cfg()` staticmethod This doesn't need access to class. * feat(backend): add docstring for `_rescale_cfg()` method * feat(ui): update cfg rescale mult translation string --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
693c6cf5e4
commit
0beb08686c
@ -215,7 +215,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.4.0",
|
||||
version="1.5.0",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@ -273,6 +273,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
ui_order=7,
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
@ -332,6 +335,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
|
@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
|
||||
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
|
||||
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
||||
cfg_rescale_multiplier: Optional[float] = InputField(
|
||||
default=None, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
||||
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
||||
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
||||
|
@ -2,6 +2,7 @@ class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
|
@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = guidance_scale[step_index]
|
||||
|
||||
noise_pred = self.invokeai_diffuser._combine(
|
||||
uc_noise_pred,
|
||||
c_noise_pred,
|
||||
guidance_scale,
|
||||
)
|
||||
noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
|
||||
guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
|
||||
if guidance_rescale_multiplier > 0:
|
||||
noise_pred = self._rescale_cfg(
|
||||
noise_pred,
|
||||
c_noise_pred,
|
||||
guidance_rescale_multiplier,
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
return step_output
|
||||
|
||||
@staticmethod
|
||||
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
|
||||
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
|
||||
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||
|
||||
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
|
||||
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
|
||||
return x_final
|
||||
|
||||
def _unet_forward(
|
||||
self,
|
||||
latents,
|
||||
|
@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
guidance_scale: Union[float, List[float]]
|
||||
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
|
||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||
"""
|
||||
guidance_rescale_multiplier: float = 0
|
||||
extra: Optional[ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
|
@ -599,6 +599,7 @@
|
||||
},
|
||||
"metadata": {
|
||||
"cfgScale": "CFG scale",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"createdBy": "Created By",
|
||||
"fit": "Image to image fit",
|
||||
"generationMode": "Generation Mode",
|
||||
@ -1032,6 +1033,8 @@
|
||||
"setType": "Set cancel type"
|
||||
},
|
||||
"cfgScale": "CFG Scale",
|
||||
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
|
||||
"cfgRescale": "CFG Rescale",
|
||||
"clipSkip": "CLIP Skip",
|
||||
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
||||
"closeViewer": "Close Viewer",
|
||||
@ -1470,6 +1473,12 @@
|
||||
"Controls how much your prompt influences the generation process."
|
||||
]
|
||||
},
|
||||
"paramCFGRescaleMultiplier": {
|
||||
"heading": "CFG Rescale Multiplier",
|
||||
"paragraphs": [
|
||||
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
|
||||
]
|
||||
},
|
||||
"paramDenoisingStrength": {
|
||||
"heading": "Denoising Strength",
|
||||
"paragraphs": [
|
||||
|
@ -25,6 +25,7 @@ export type Feature =
|
||||
| 'lora'
|
||||
| 'noiseUseCPU'
|
||||
| 'paramCFGScale'
|
||||
| 'paramCFGRescaleMultiplier'
|
||||
| 'paramDenoisingStrength'
|
||||
| 'paramIterations'
|
||||
| 'paramModel'
|
||||
|
@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallNegativePrompt,
|
||||
recallSeed,
|
||||
recallCfgScale,
|
||||
recallCfgRescaleMultiplier,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallVaeModel,
|
||||
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallCfgScale(metadata?.cfg_scale);
|
||||
}, [metadata?.cfg_scale, recallCfgScale]);
|
||||
|
||||
const handleRecallCfgRescaleMultiplier = useCallback(() => {
|
||||
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
|
||||
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
|
||||
|
||||
const handleRecallStrength = useCallback(() => {
|
||||
recallStrength(metadata?.strength);
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallCfgScale}
|
||||
/>
|
||||
)}
|
||||
{metadata.cfg_rescale_multiplier !== undefined &&
|
||||
metadata.cfg_rescale_multiplier !== null && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.cfgRescaleMultiplier')}
|
||||
value={metadata.cfg_rescale_multiplier}
|
||||
onClick={handleRecallCfgRescaleMultiplier}
|
||||
/>
|
||||
)}
|
||||
{metadata.strength && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.strength')}
|
||||
|
@ -51,6 +51,7 @@ export const zCoreMetadata = z
|
||||
seed: z.number().int().nullish().catch(null),
|
||||
rand_device: z.string().nullish().catch(null),
|
||||
cfg_scale: z.number().nullish().catch(null),
|
||||
cfg_rescale_multiplier: z.number().nullish().catch(null),
|
||||
steps: z.number().int().nullish().catch(null),
|
||||
scheduler: z.string().nullish().catch(null),
|
||||
clip_skip: z.number().int().nullish().catch(null),
|
||||
|
@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
width: !isUsingScaledDimensions
|
||||
? width
|
||||
: scaledBoundingBoxDimensions.width,
|
||||
|
@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
{
|
||||
generation_mode: 'sdxl_img2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'sdxl_txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
width,
|
||||
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
|
||||
id: DENOISE_LATENTS,
|
||||
is_intermediate,
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: 0,
|
||||
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
|
@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
|
||||
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
|
||||
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
|
||||
import ParamClipSkip from './ParamClipSkip';
|
||||
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state: RootState) => {
|
||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
state.generation;
|
||||
const {
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
cfgRescaleMultiplier,
|
||||
} = state.generation;
|
||||
|
||||
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
|
||||
return {
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
cfgRescaleMultiplier,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
export default function ParamAdvancedCollapse() {
|
||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
useAppSelector(selector);
|
||||
const {
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
cfgRescaleMultiplier,
|
||||
} = useAppSelector(selector);
|
||||
const { t } = useTranslation();
|
||||
const activeLabel = useMemo(() => {
|
||||
const activeLabel: string[] = [];
|
||||
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
|
||||
activeLabel.push(t('parameters.seamlessY'));
|
||||
}
|
||||
|
||||
if (cfgRescaleMultiplier) {
|
||||
activeLabel.push(t('parameters.cfgRescale'));
|
||||
}
|
||||
|
||||
return activeLabel.join(', ');
|
||||
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
|
||||
}, [
|
||||
cfgRescaleMultiplier,
|
||||
clipSkip,
|
||||
model,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
shouldUseCpuNoise,
|
||||
t,
|
||||
]);
|
||||
|
||||
return (
|
||||
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
||||
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
|
||||
</>
|
||||
)}
|
||||
<ParamCpuNoiseToggle />
|
||||
<Divider />
|
||||
<ParamCFGRescaleMultiplier />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -0,0 +1,60 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ generation, hotkeys }) => {
|
||||
const { cfgRescaleMultiplier } = generation;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
cfgRescaleMultiplier,
|
||||
shift,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamCFGRescaleMultiplier = () => {
|
||||
const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setCfgRescaleMultiplier(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setCfgRescaleMultiplier(0)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIInformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<IAISlider
|
||||
label={t('parameters.cfgRescaleMultiplier')}
|
||||
step={shift ? 0.01 : 0.05}
|
||||
min={0}
|
||||
max={0.99}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={cfgRescaleMultiplier}
|
||||
sliderNumberInputProps={{ max: 0.99 }}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamCFGRescaleMultiplier);
|
@ -57,6 +57,7 @@ import {
|
||||
modelSelected,
|
||||
} from 'features/parameters/store/actions';
|
||||
import {
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setHrfEnabled,
|
||||
@ -94,6 +95,7 @@ import {
|
||||
isParameterStrength,
|
||||
isParameterVAEModel,
|
||||
isParameterWidth,
|
||||
isParameterCFGRescaleMultiplier,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall CFG rescale multiplier with toast
|
||||
*/
|
||||
const recallCfgRescaleMultiplier = useCallback(
|
||||
(cfgRescaleMultiplier: unknown) => {
|
||||
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall model with toast
|
||||
*/
|
||||
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
model,
|
||||
positive_prompt,
|
||||
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
|
||||
if (isParameterModel(model)) {
|
||||
dispatch(modelSelected(model));
|
||||
}
|
||||
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
|
||||
recallSDXLNegativeStylePrompt,
|
||||
recallSeed,
|
||||
recallCfgScale,
|
||||
recallCfgRescaleMultiplier,
|
||||
recallModel,
|
||||
recallScheduler,
|
||||
recallVaeModel,
|
||||
|
@ -24,6 +24,7 @@ import {
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
zParameterModel,
|
||||
ParameterCFGRescaleMultiplier,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export interface GenerationState {
|
||||
@ -31,6 +32,7 @@ export interface GenerationState {
|
||||
hrfStrength: ParameterStrength;
|
||||
hrfMethod: ParameterHRFMethod;
|
||||
cfgScale: ParameterCFGScale;
|
||||
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||
height: ParameterHeight;
|
||||
img2imgStrength: ParameterStrength;
|
||||
infillMethod: string;
|
||||
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
|
||||
hrfEnabled: false,
|
||||
hrfMethod: 'ESRGAN',
|
||||
cfgScale: 7.5,
|
||||
cfgRescaleMultiplier: 0,
|
||||
height: 512,
|
||||
img2imgStrength: 0.75,
|
||||
infillMethod: 'patchmatch',
|
||||
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
|
||||
state.steps
|
||||
);
|
||||
},
|
||||
setCfgScale: (state, action: PayloadAction<number>) => {
|
||||
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
|
||||
state.cfgScale = action.payload;
|
||||
},
|
||||
setCfgRescaleMultiplier: (
|
||||
state,
|
||||
action: PayloadAction<ParameterCFGRescaleMultiplier>
|
||||
) => {
|
||||
state.cfgRescaleMultiplier = action.payload;
|
||||
},
|
||||
setThreshold: (state, action: PayloadAction<number>) => {
|
||||
state.threshold = action.payload;
|
||||
},
|
||||
@ -336,6 +345,7 @@ export const {
|
||||
resetParametersState,
|
||||
resetSeed,
|
||||
setCfgScale,
|
||||
setCfgRescaleMultiplier,
|
||||
setWidth,
|
||||
setHeight,
|
||||
toggleSize,
|
||||
|
@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
||||
zParameterCFGScale.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region CFG Rescale Multiplier
|
||||
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||
export type ParameterCFGRescaleMultiplier = z.infer<
|
||||
typeof zParameterCFGRescaleMultiplier
|
||||
>;
|
||||
export const isParameterCFGRescaleMultiplier = (
|
||||
val: unknown
|
||||
): val is ParameterCFGRescaleMultiplier =>
|
||||
zParameterCFGRescaleMultiplier.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Scheduler
|
||||
export const zParameterScheduler = zSchedulerField;
|
||||
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user