mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add CFG Rescale option for supporting zero-terminal SNR models (#4335)
* add support for CFG rescale * fix typo * move rescale position and tweak docs * move input position * implement suggestions from github and discord * cleanup unused code * add back dropped FieldDescription * fix(ui): revert unrelated UI changes * chore(nodes): bump denoise_latents version 1.4.0 -> 1.5.0 * feat(nodes): add cfg_rescale_multiplier to metadata node * feat(ui): add cfg rescale multiplier to linear UI - add param to state - update graph builders - add UI under advanced - add metadata handling & recall - regen types * chore: black * fix(backend): make `StableDiffusionGeneratorPipeline._rescale_cfg()` staticmethod This doesn't need access to class. * feat(backend): add docstring for `_rescale_cfg()` method * feat(ui): update cfg rescale mult translation string --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
693c6cf5e4
commit
0beb08686c
@ -215,7 +215,7 @@ def get_scheduler(
|
|||||||
title="Denoise Latents",
|
title="Denoise Latents",
|
||||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.4.0",
|
version="1.5.0",
|
||||||
)
|
)
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
@ -273,6 +273,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
ui_order=7,
|
ui_order=7,
|
||||||
)
|
)
|
||||||
|
cfg_rescale_multiplier: float = InputField(
|
||||||
|
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
|
)
|
||||||
latents: Optional[LatentsField] = InputField(
|
latents: Optional[LatentsField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
@ -332,6 +335,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
text_embeddings=c,
|
text_embeddings=c,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
extra=extra_conditioning_info,
|
extra=extra_conditioning_info,
|
||||||
postprocessing_settings=PostprocessingSettings(
|
postprocessing_settings=PostprocessingSettings(
|
||||||
threshold=0.0, # threshold,
|
threshold=0.0, # threshold,
|
||||||
|
@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
|
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
|
||||||
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
|
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
|
||||||
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
||||||
|
cfg_rescale_multiplier: Optional[float] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
|
)
|
||||||
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
||||||
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
||||||
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
||||||
|
@ -2,6 +2,7 @@ class FieldDescriptions:
|
|||||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
cfg_scale = "Classifier-Free Guidance scale"
|
cfg_scale = "Classifier-Free Guidance scale"
|
||||||
|
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||||
scheduler = "Scheduler to use during inference"
|
scheduler = "Scheduler to use during inference"
|
||||||
positive_cond = "Positive conditioning tensor"
|
positive_cond = "Positive conditioning tensor"
|
||||||
negative_cond = "Negative conditioning tensor"
|
negative_cond = "Negative conditioning tensor"
|
||||||
|
@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = guidance_scale[step_index]
|
guidance_scale = guidance_scale[step_index]
|
||||||
|
|
||||||
noise_pred = self.invokeai_diffuser._combine(
|
noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
|
||||||
uc_noise_pred,
|
guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
|
||||||
c_noise_pred,
|
if guidance_rescale_multiplier > 0:
|
||||||
guidance_scale,
|
noise_pred = self._rescale_cfg(
|
||||||
)
|
noise_pred,
|
||||||
|
c_noise_pred,
|
||||||
|
guidance_rescale_multiplier,
|
||||||
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||||
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
return step_output
|
return step_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
|
||||||
|
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
|
||||||
|
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||||
|
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||||
|
|
||||||
|
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
|
||||||
|
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
|
||||||
|
return x_final
|
||||||
|
|
||||||
def _unet_forward(
|
def _unet_forward(
|
||||||
self,
|
self,
|
||||||
latents,
|
latents,
|
||||||
|
@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
|
|||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: BasicConditioningInfo
|
unconditioned_embeddings: BasicConditioningInfo
|
||||||
text_embeddings: BasicConditioningInfo
|
text_embeddings: BasicConditioningInfo
|
||||||
guidance_scale: Union[float, List[float]]
|
|
||||||
"""
|
"""
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||||
"""
|
"""
|
||||||
|
guidance_scale: Union[float, List[float]]
|
||||||
|
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
|
||||||
|
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||||
|
"""
|
||||||
|
guidance_rescale_multiplier: float = 0
|
||||||
extra: Optional[ExtraConditioningInfo] = None
|
extra: Optional[ExtraConditioningInfo] = None
|
||||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||||
"""
|
"""
|
||||||
|
@ -599,6 +599,7 @@
|
|||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"cfgScale": "CFG scale",
|
"cfgScale": "CFG scale",
|
||||||
|
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||||
"createdBy": "Created By",
|
"createdBy": "Created By",
|
||||||
"fit": "Image to image fit",
|
"fit": "Image to image fit",
|
||||||
"generationMode": "Generation Mode",
|
"generationMode": "Generation Mode",
|
||||||
@ -1032,6 +1033,8 @@
|
|||||||
"setType": "Set cancel type"
|
"setType": "Set cancel type"
|
||||||
},
|
},
|
||||||
"cfgScale": "CFG Scale",
|
"cfgScale": "CFG Scale",
|
||||||
|
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
|
||||||
|
"cfgRescale": "CFG Rescale",
|
||||||
"clipSkip": "CLIP Skip",
|
"clipSkip": "CLIP Skip",
|
||||||
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
||||||
"closeViewer": "Close Viewer",
|
"closeViewer": "Close Viewer",
|
||||||
@ -1470,6 +1473,12 @@
|
|||||||
"Controls how much your prompt influences the generation process."
|
"Controls how much your prompt influences the generation process."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"paramCFGRescaleMultiplier": {
|
||||||
|
"heading": "CFG Rescale Multiplier",
|
||||||
|
"paragraphs": [
|
||||||
|
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
|
||||||
|
]
|
||||||
|
},
|
||||||
"paramDenoisingStrength": {
|
"paramDenoisingStrength": {
|
||||||
"heading": "Denoising Strength",
|
"heading": "Denoising Strength",
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
|
@ -25,6 +25,7 @@ export type Feature =
|
|||||||
| 'lora'
|
| 'lora'
|
||||||
| 'noiseUseCPU'
|
| 'noiseUseCPU'
|
||||||
| 'paramCFGScale'
|
| 'paramCFGScale'
|
||||||
|
| 'paramCFGRescaleMultiplier'
|
||||||
| 'paramDenoisingStrength'
|
| 'paramDenoisingStrength'
|
||||||
| 'paramIterations'
|
| 'paramIterations'
|
||||||
| 'paramModel'
|
| 'paramModel'
|
||||||
|
@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallNegativePrompt,
|
recallNegativePrompt,
|
||||||
recallSeed,
|
recallSeed,
|
||||||
recallCfgScale,
|
recallCfgScale,
|
||||||
|
recallCfgRescaleMultiplier,
|
||||||
recallModel,
|
recallModel,
|
||||||
recallScheduler,
|
recallScheduler,
|
||||||
recallVaeModel,
|
recallVaeModel,
|
||||||
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallCfgScale(metadata?.cfg_scale);
|
recallCfgScale(metadata?.cfg_scale);
|
||||||
}, [metadata?.cfg_scale, recallCfgScale]);
|
}, [metadata?.cfg_scale, recallCfgScale]);
|
||||||
|
|
||||||
|
const handleRecallCfgRescaleMultiplier = useCallback(() => {
|
||||||
|
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
|
||||||
|
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
|
||||||
|
|
||||||
const handleRecallStrength = useCallback(() => {
|
const handleRecallStrength = useCallback(() => {
|
||||||
recallStrength(metadata?.strength);
|
recallStrength(metadata?.strength);
|
||||||
}, [metadata?.strength, recallStrength]);
|
}, [metadata?.strength, recallStrength]);
|
||||||
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallCfgScale}
|
onClick={handleRecallCfgScale}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{metadata.cfg_rescale_multiplier !== undefined &&
|
||||||
|
metadata.cfg_rescale_multiplier !== null && (
|
||||||
|
<ImageMetadataItem
|
||||||
|
label={t('metadata.cfgRescaleMultiplier')}
|
||||||
|
value={metadata.cfg_rescale_multiplier}
|
||||||
|
onClick={handleRecallCfgRescaleMultiplier}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{metadata.strength && (
|
{metadata.strength && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label={t('metadata.strength')}
|
label={t('metadata.strength')}
|
||||||
|
@ -51,6 +51,7 @@ export const zCoreMetadata = z
|
|||||||
seed: z.number().int().nullish().catch(null),
|
seed: z.number().int().nullish().catch(null),
|
||||||
rand_device: z.string().nullish().catch(null),
|
rand_device: z.string().nullish().catch(null),
|
||||||
cfg_scale: z.number().nullish().catch(null),
|
cfg_scale: z.number().nullish().catch(null),
|
||||||
|
cfg_rescale_multiplier: z.number().nullish().catch(null),
|
||||||
steps: z.number().int().nullish().catch(null),
|
steps: z.number().int().nullish().catch(null),
|
||||||
scheduler: z.string().nullish().catch(null),
|
scheduler: z.string().nullish().catch(null),
|
||||||
clip_skip: z.number().int().nullish().catch(null),
|
clip_skip: z.number().int().nullish().catch(null),
|
||||||
|
@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'sdxl_img2img',
|
generation_mode: 'sdxl_img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'sdxl_txt2img',
|
generation_mode: 'sdxl_txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
width,
|
width,
|
||||||
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
id: DENOISE_LATENTS,
|
id: DENOISE_LATENTS,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
denoising_start: 0,
|
denoising_start: 0,
|
||||||
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
|
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
|
||||||
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
|
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
|
||||||
import ParamClipSkip from './ParamClipSkip';
|
import ParamClipSkip from './ParamClipSkip';
|
||||||
|
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
(state: RootState) => {
|
(state: RootState) => {
|
||||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
const {
|
||||||
state.generation;
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
|
return {
|
||||||
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
export default function ParamAdvancedCollapse() {
|
export default function ParamAdvancedCollapse() {
|
||||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
const {
|
||||||
useAppSelector(selector);
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
} = useAppSelector(selector);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const activeLabel = useMemo(() => {
|
const activeLabel = useMemo(() => {
|
||||||
const activeLabel: string[] = [];
|
const activeLabel: string[] = [];
|
||||||
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
|
|||||||
activeLabel.push(t('parameters.seamlessY'));
|
activeLabel.push(t('parameters.seamlessY'));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cfgRescaleMultiplier) {
|
||||||
|
activeLabel.push(t('parameters.cfgRescale'));
|
||||||
|
}
|
||||||
|
|
||||||
return activeLabel.join(', ');
|
return activeLabel.join(', ');
|
||||||
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
|
}, [
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
t,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
||||||
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
<ParamCpuNoiseToggle />
|
<ParamCpuNoiseToggle />
|
||||||
|
<Divider />
|
||||||
|
<ParamCFGRescaleMultiplier />
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAICollapse>
|
</IAICollapse>
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ generation, hotkeys }) => {
|
||||||
|
const { cfgRescaleMultiplier } = generation;
|
||||||
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
|
return {
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
shift,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamCFGRescaleMultiplier = () => {
|
||||||
|
const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => dispatch(setCfgRescaleMultiplier(v)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(
|
||||||
|
() => dispatch(setCfgRescaleMultiplier(0)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIInformationalPopover feature="paramCFGRescaleMultiplier">
|
||||||
|
<IAISlider
|
||||||
|
label={t('parameters.cfgRescaleMultiplier')}
|
||||||
|
step={shift ? 0.01 : 0.05}
|
||||||
|
min={0}
|
||||||
|
max={0.99}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={cfgRescaleMultiplier}
|
||||||
|
sliderNumberInputProps={{ max: 0.99 }}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
isInteger={false}
|
||||||
|
/>
|
||||||
|
</IAIInformationalPopover>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamCFGRescaleMultiplier);
|
@ -57,6 +57,7 @@ import {
|
|||||||
modelSelected,
|
modelSelected,
|
||||||
} from 'features/parameters/store/actions';
|
} from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
|
setCfgRescaleMultiplier,
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
setHrfEnabled,
|
setHrfEnabled,
|
||||||
@ -94,6 +95,7 @@ import {
|
|||||||
isParameterStrength,
|
isParameterStrength,
|
||||||
isParameterVAEModel,
|
isParameterVAEModel,
|
||||||
isParameterWidth,
|
isParameterWidth,
|
||||||
|
isParameterCFGRescaleMultiplier,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recall CFG rescale multiplier with toast
|
||||||
|
*/
|
||||||
|
const recallCfgRescaleMultiplier = useCallback(
|
||||||
|
(cfgRescaleMultiplier: unknown) => {
|
||||||
|
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
||||||
|
parameterNotSetToast();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
|
||||||
|
parameterSetToast();
|
||||||
|
},
|
||||||
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Recall model with toast
|
* Recall model with toast
|
||||||
*/
|
*/
|
||||||
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
const {
|
const {
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
model,
|
model,
|
||||||
positive_prompt,
|
positive_prompt,
|
||||||
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
|
|||||||
dispatch(setCfgScale(cfg_scale));
|
dispatch(setCfgScale(cfg_scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||||
|
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||||
|
}
|
||||||
|
|
||||||
if (isParameterModel(model)) {
|
if (isParameterModel(model)) {
|
||||||
dispatch(modelSelected(model));
|
dispatch(modelSelected(model));
|
||||||
}
|
}
|
||||||
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
|
|||||||
recallSDXLNegativeStylePrompt,
|
recallSDXLNegativeStylePrompt,
|
||||||
recallSeed,
|
recallSeed,
|
||||||
recallCfgScale,
|
recallCfgScale,
|
||||||
|
recallCfgRescaleMultiplier,
|
||||||
recallModel,
|
recallModel,
|
||||||
recallScheduler,
|
recallScheduler,
|
||||||
recallVaeModel,
|
recallVaeModel,
|
||||||
|
@ -24,6 +24,7 @@ import {
|
|||||||
ParameterVAEModel,
|
ParameterVAEModel,
|
||||||
ParameterWidth,
|
ParameterWidth,
|
||||||
zParameterModel,
|
zParameterModel,
|
||||||
|
ParameterCFGRescaleMultiplier,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
@ -31,6 +32,7 @@ export interface GenerationState {
|
|||||||
hrfStrength: ParameterStrength;
|
hrfStrength: ParameterStrength;
|
||||||
hrfMethod: ParameterHRFMethod;
|
hrfMethod: ParameterHRFMethod;
|
||||||
cfgScale: ParameterCFGScale;
|
cfgScale: ParameterCFGScale;
|
||||||
|
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||||
height: ParameterHeight;
|
height: ParameterHeight;
|
||||||
img2imgStrength: ParameterStrength;
|
img2imgStrength: ParameterStrength;
|
||||||
infillMethod: string;
|
infillMethod: string;
|
||||||
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
hrfEnabled: false,
|
hrfEnabled: false,
|
||||||
hrfMethod: 'ESRGAN',
|
hrfMethod: 'ESRGAN',
|
||||||
cfgScale: 7.5,
|
cfgScale: 7.5,
|
||||||
|
cfgRescaleMultiplier: 0,
|
||||||
height: 512,
|
height: 512,
|
||||||
img2imgStrength: 0.75,
|
img2imgStrength: 0.75,
|
||||||
infillMethod: 'patchmatch',
|
infillMethod: 'patchmatch',
|
||||||
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
|
|||||||
state.steps
|
state.steps
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
setCfgScale: (state, action: PayloadAction<number>) => {
|
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
|
||||||
state.cfgScale = action.payload;
|
state.cfgScale = action.payload;
|
||||||
},
|
},
|
||||||
|
setCfgRescaleMultiplier: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<ParameterCFGRescaleMultiplier>
|
||||||
|
) => {
|
||||||
|
state.cfgRescaleMultiplier = action.payload;
|
||||||
|
},
|
||||||
setThreshold: (state, action: PayloadAction<number>) => {
|
setThreshold: (state, action: PayloadAction<number>) => {
|
||||||
state.threshold = action.payload;
|
state.threshold = action.payload;
|
||||||
},
|
},
|
||||||
@ -336,6 +345,7 @@ export const {
|
|||||||
resetParametersState,
|
resetParametersState,
|
||||||
resetSeed,
|
resetSeed,
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
|
setCfgRescaleMultiplier,
|
||||||
setWidth,
|
setWidth,
|
||||||
setHeight,
|
setHeight,
|
||||||
toggleSize,
|
toggleSize,
|
||||||
|
@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
|||||||
zParameterCFGScale.safeParse(val).success;
|
zParameterCFGScale.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region CFG Rescale Multiplier
|
||||||
|
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||||
|
export type ParameterCFGRescaleMultiplier = z.infer<
|
||||||
|
typeof zParameterCFGRescaleMultiplier
|
||||||
|
>;
|
||||||
|
export const isParameterCFGRescaleMultiplier = (
|
||||||
|
val: unknown
|
||||||
|
): val is ParameterCFGRescaleMultiplier =>
|
||||||
|
zParameterCFGRescaleMultiplier.safeParse(val).success;
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region Scheduler
|
// #region Scheduler
|
||||||
export const zParameterScheduler = zSchedulerField;
|
export const zParameterScheduler = zSchedulerField;
|
||||||
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user