feat (ui, generation): High Resolution Fix- added automatic resolution toggle and replaced latent upscale with two improved methods (#4905)

* working

* added selector for method

* refactoring graph

* added ersgan method

* fixing yarn build

* add tooltips

* a conjuction

* rephrase

* removed manual sliders, set HRF to calculate dimensions automatically to match 512^2 pixels

* working

* working

* working

* fixed tooltip

* add hrf to use all parameters

* adding hrf method to parameters

* working on parameter recall

* working on parameter recall

* cleaning

* fix(ui): fix unnecessary casts in addHrfToGraph

* chore(ui): use camelCase in addHrfToGraph

* fix(ui): do not add HRF metadata unless HRF is added to graph

* fix(ui): remove unused imports in addHrfToGraph

* feat(ui): do not hide HRF params when disabled, only disable them

* fix(ui): remove unused vars in addHrfToGraph

* feat(ui): default HRF str to 0.35, method ESRGAN

* fix(ui): use isValidBoolean to check hrfEnabled param

* fix(nodes): update CoreMetadataInvocation fields for HRF

* feat(ui): set hrf strength default to 0.45

* fix(ui): set default hrf strength in configSlice

* feat(ui): use translations for HRF features

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Paul Curry 2023-11-10 16:11:46 -08:00 committed by GitHub
parent 9ccfa34e04
commit 1c7ea57492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 616 additions and 392 deletions

View File

@ -160,13 +160,14 @@ class CoreMetadataInvocation(BaseInvocation):
) )
# High resolution fix metadata. # High resolution fix metadata.
hrf_width: Optional[int] = InputField( hrf_enabled: Optional[float] = InputField(
default=None, default=None,
description="The high resolution fix height and width multipler.", description="Whether or not high resolution fix was enabled.",
) )
hrf_height: Optional[int] = InputField( # TODO: should this be stricter or do we just let the UI handle it?
hrf_method: Optional[str] = InputField(
default=None, default=None,
description="The high resolution fix height and width multipler.", description="The high resolution fix upscale method.",
) )
hrf_strength: Optional[float] = InputField( hrf_strength: Optional[float] = InputField(
default=None, default=None,

View File

@ -221,6 +221,19 @@
"resetIPAdapterImage": "Reset IP Adapter Image", "resetIPAdapterImage": "Reset IP Adapter Image",
"ipAdapterImageFallback": "No IP Adapter Image Selected" "ipAdapterImageFallback": "No IP Adapter Image Selected"
}, },
"hrf": {
"hrf": "High Resolution Fix",
"enableHrf": "Enable High Resolution Fix",
"enableHrfTooltip": "Generate with a lower initial resolution, upscale to the base resolution, then run Image-to-Image.",
"upscaleMethod": "Upscale Method",
"hrfStrength": "High Resolution Fix Strength",
"strengthTooltip": "Lower values result in fewer details, which may reduce potential artifacts.",
"metadata": {
"enabled": "High Resolution Fix Enabled",
"strength": "High Resolution Fix Strength",
"method": "High Resolution Fix Method"
}
},
"embedding": { "embedding": {
"addEmbedding": "Add Embedding", "addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:", "incompatibleModel": "Incompatible base model:",
@ -1258,15 +1271,11 @@
}, },
"compositingBlur": { "compositingBlur": {
"heading": "Blur", "heading": "Blur",
"paragraphs": [ "paragraphs": ["The blur radius of the mask."]
"The blur radius of the mask."
]
}, },
"compositingBlurMethod": { "compositingBlurMethod": {
"heading": "Blur Method", "heading": "Blur Method",
"paragraphs": [ "paragraphs": ["The method of blur applied to the masked area."]
"The method of blur applied to the masked area."
]
}, },
"compositingCoherencePass": { "compositingCoherencePass": {
"heading": "Coherence Pass", "heading": "Coherence Pass",
@ -1276,9 +1285,7 @@
}, },
"compositingCoherenceMode": { "compositingCoherenceMode": {
"heading": "Mode", "heading": "Mode",
"paragraphs": [ "paragraphs": ["The mode of the Coherence Pass."]
"The mode of the Coherence Pass."
]
}, },
"compositingCoherenceSteps": { "compositingCoherenceSteps": {
"heading": "Steps", "heading": "Steps",
@ -1296,9 +1303,7 @@
}, },
"compositingMaskAdjustments": { "compositingMaskAdjustments": {
"heading": "Mask Adjustments", "heading": "Mask Adjustments",
"paragraphs": [ "paragraphs": ["Adjust the mask."]
"Adjust the mask."
]
}, },
"controlNetBeginEnd": { "controlNetBeginEnd": {
"heading": "Begin / End Step Percentage", "heading": "Begin / End Step Percentage",
@ -1356,9 +1361,7 @@
}, },
"infillMethod": { "infillMethod": {
"heading": "Infill Method", "heading": "Infill Method",
"paragraphs": [ "paragraphs": ["Method to infill the selected area."]
"Method to infill the selected area."
]
}, },
"lora": { "lora": {
"heading": "LoRA Weight", "heading": "LoRA Weight",

View File

@ -35,6 +35,9 @@ const ImageMetadataActions = (props: Props) => {
recallWidth, recallWidth,
recallHeight, recallHeight,
recallStrength, recallStrength,
recallHrfEnabled,
recallHrfStrength,
recallHrfMethod,
recallLoRA, recallLoRA,
recallControlNet, recallControlNet,
recallIPAdapter, recallIPAdapter,
@ -81,6 +84,18 @@ const ImageMetadataActions = (props: Props) => {
recallStrength(metadata?.strength); recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]); }, [metadata?.strength, recallStrength]);
const handleRecallHrfEnabled = useCallback(() => {
recallHrfEnabled(metadata?.hrf_enabled);
}, [metadata?.hrf_enabled, recallHrfEnabled]);
const handleRecallHrfStrength = useCallback(() => {
recallHrfStrength(metadata?.hrf_strength);
}, [metadata?.hrf_strength, recallHrfStrength]);
const handleRecallHrfMethod = useCallback(() => {
recallHrfMethod(metadata?.hrf_method);
}, [metadata?.hrf_method, recallHrfMethod]);
const handleRecallLoRA = useCallback( const handleRecallLoRA = useCallback(
(lora: LoRAMetadataItem) => { (lora: LoRAMetadataItem) => {
recallLoRA(lora); recallLoRA(lora);
@ -225,6 +240,27 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallStrength} onClick={handleRecallStrength}
/> />
)} )}
{metadata.hrf_enabled && (
<ImageMetadataItem
label={t('hrf.metadata.enabled')}
value={metadata.hrf_enabled}
onClick={handleRecallHrfEnabled}
/>
)}
{metadata.hrf_enabled && metadata.hrf_strength && (
<ImageMetadataItem
label={t('hrf.metadata.strength')}
value={metadata.hrf_strength}
onClick={handleRecallHrfStrength}
/>
)}
{metadata.hrf_enabled && metadata.hrf_method && (
<ImageMetadataItem
label={t('hrf.metadata.method')}
value={metadata.hrf_method}
onClick={handleRecallHrfMethod}
/>
)}
{metadata.loras && {metadata.loras &&
metadata.loras.map((lora, index) => { metadata.loras.map((lora, index) => {
if (isValidLoRAModel(lora.lora)) { if (isValidLoRAModel(lora.lora)) {

View File

@ -1424,6 +1424,9 @@ export const zCoreMetadata = z
loras: z.array(zLoRAMetadataItem).nullish().catch(null), loras: z.array(zLoRAMetadataItem).nullish().catch(null),
vae: zVaeModelField.nullish().catch(null), vae: zVaeModelField.nullish().catch(null),
strength: z.number().nullish().catch(null), strength: z.number().nullish().catch(null),
hrf_enabled: z.boolean().nullish().catch(null),
hrf_strength: z.number().nullish().catch(null),
hrf_method: z.string().nullish().catch(null),
init_image: z.string().nullish().catch(null), init_image: z.string().nullish().catch(null),
positive_style_prompt: z.string().nullish().catch(null), positive_style_prompt: z.string().nullish().catch(null),
negative_style_prompt: z.string().nullish().catch(null), negative_style_prompt: z.string().nullish().catch(null),

View File

@ -1,22 +1,26 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
DenoiseLatentsInvocation, DenoiseLatentsInvocation,
ESRGANInvocation,
Edge, Edge,
LatentsToImageInvocation, LatentsToImageInvocation,
NoiseInvocation, NoiseInvocation,
ResizeLatentsInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { import {
DENOISE_LATENTS, DENOISE_LATENTS,
DENOISE_LATENTS_HRF, DENOISE_LATENTS_HRF,
ESRGAN_HRF,
IMAGE_TO_LATENTS_HRF,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF, LATENTS_TO_IMAGE_HRF_HR,
LATENTS_TO_IMAGE_HRF_LR,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
NOISE, NOISE,
NOISE_HRF, NOISE_HRF,
RESCALE_LATENTS, RESIZE_HRF,
VAE_LOADER, VAE_LOADER,
} from './constants'; } from './constants';
import { upsertMetadata } from './metadata'; import { upsertMetadata } from './metadata';
@ -56,6 +60,52 @@ function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
graph.edges = graph.edges.concat(newEdges); graph.edges = graph.edges.concat(newEdges);
} }
/**
* Calculates the new resolution for high-resolution features (HRF) based on base model type.
* Adjusts the width and height to maintain the aspect ratio and constrains them by the model's dimension limits,
* rounding down to the nearest multiple of 8.
*
* @param {string} baseModel The base model type, which determines the base dimension used in calculations.
* @param {number} width The current width to be adjusted for HRF.
* @param {number} height The current height to be adjusted for HRF.
* @return {{newWidth: number, newHeight: number}} The new width and height, adjusted and rounded as needed.
*/
function calculateHrfRes(
baseModel: string,
width: number,
height: number
): { newWidth: number; newHeight: number } {
const aspect = width / height;
let dimension;
if (baseModel == 'sdxl') {
dimension = 1024;
} else {
dimension = 512;
}
const minDimension = Math.floor(dimension * 0.5);
const modelArea = dimension * dimension; // Assuming square images for model_area
let initWidth;
let initHeight;
if (aspect > 1.0) {
initHeight = Math.max(minDimension, Math.sqrt(modelArea / aspect));
initWidth = initHeight * aspect;
} else {
initWidth = Math.max(minDimension, Math.sqrt(modelArea * aspect));
initHeight = initWidth / aspect;
}
// Cap initial height and width to final height and width.
initWidth = Math.min(width, initWidth);
initHeight = Math.min(height, initHeight);
const newWidth = roundToMultiple(Math.floor(initWidth), 8);
const newHeight = roundToMultiple(Math.floor(initHeight), 8);
return { newWidth, newHeight };
}
// Adds the high-res fix feature to the given graph. // Adds the high-res fix feature to the given graph.
export const addHrfToGraph = ( export const addHrfToGraph = (
state: RootState, state: RootState,
@ -71,151 +121,61 @@ export const addHrfToGraph = (
} }
const log = logger('txt2img'); const log = logger('txt2img');
const { vae, hrfWidth, hrfHeight, hrfStrength } = state.generation; const { vae, hrfStrength, hrfEnabled, hrfMethod } = state.generation;
const isAutoVae = !vae; const isAutoVae = !vae;
const width = state.generation.width;
const height = state.generation.height;
const baseModel = state.generation.model
? state.generation.model.base_model
: 'sd1';
const { newWidth: hrfWidth, newHeight: hrfHeight } = calculateHrfRes(
baseModel,
width,
height
);
// Pre-existing (original) graph nodes. // Pre-existing (original) graph nodes.
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as
| DenoiseLatentsInvocation | DenoiseLatentsInvocation
| undefined; | undefined;
const originalNoiseNode = graph.nodes[NOISE] as NoiseInvocation | undefined; const originalNoiseNode = graph.nodes[NOISE] as NoiseInvocation | undefined;
// Original latents to image should pick this up.
const originalLatentsToImageNode = graph.nodes[LATENTS_TO_IMAGE] as const originalLatentsToImageNode = graph.nodes[LATENTS_TO_IMAGE] as
| LatentsToImageInvocation | LatentsToImageInvocation
| undefined; | undefined;
// Check if originalDenoiseLatentsNode is undefined and log an error
if (!originalDenoiseLatentsNode) { if (!originalDenoiseLatentsNode) {
log.error('originalDenoiseLatentsNode is undefined'); log.error('originalDenoiseLatentsNode is undefined');
return; return;
} }
// Check if originalNoiseNode is undefined and log an error
if (!originalNoiseNode) { if (!originalNoiseNode) {
log.error('originalNoiseNode is undefined'); log.error('originalNoiseNode is undefined');
return; return;
} }
// Check if originalLatentsToImageNode is undefined and log an error
if (!originalLatentsToImageNode) { if (!originalLatentsToImageNode) {
log.error('originalLatentsToImageNode is undefined'); log.error('originalLatentsToImageNode is undefined');
return; return;
} }
// Change height and width of original noise node to initial resolution. // Change height and width of original noise node to initial resolution.
if (originalNoiseNode) { if (originalNoiseNode) {
originalNoiseNode.width = hrfWidth; originalNoiseNode.width = hrfWidth;
originalNoiseNode.height = hrfHeight; originalNoiseNode.height = hrfHeight;
} }
// Define new nodes. // Define new nodes and their connections, roughly in order of operations.
// Denoise latents node to be run on upscaled latents. graph.nodes[LATENTS_TO_IMAGE_HRF_LR] = {
const denoiseLatentsHrfNode: DenoiseLatentsInvocation = {
type: 'denoise_latents',
id: DENOISE_LATENTS_HRF,
is_intermediate: originalDenoiseLatentsNode?.is_intermediate,
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
scheduler: originalDenoiseLatentsNode?.scheduler,
steps: originalDenoiseLatentsNode?.steps,
denoising_start: 1 - hrfStrength,
denoising_end: 1,
};
// New base resolution noise node.
const hrfNoiseNode: NoiseInvocation = {
type: 'noise',
id: NOISE_HRF,
seed: originalNoiseNode?.seed,
use_cpu: originalNoiseNode?.use_cpu,
is_intermediate: originalNoiseNode?.is_intermediate,
};
const rescaleLatentsNode: ResizeLatentsInvocation = {
id: RESCALE_LATENTS,
type: 'lresize',
width: state.generation.width,
height: state.generation.height,
};
// New node to convert latents to image.
const latentsToImageHrfNode: LatentsToImageInvocation | undefined =
originalLatentsToImageNode
? {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE_HRF, id: LATENTS_TO_IMAGE_HRF_LR,
fp32: originalLatentsToImageNode?.fp32, fp32: originalLatentsToImageNode?.fp32,
is_intermediate: originalLatentsToImageNode?.is_intermediate, is_intermediate: true,
} };
: undefined;
// Add new nodes to graph.
graph.nodes[LATENTS_TO_IMAGE_HRF] =
latentsToImageHrfNode as LatentsToImageInvocation;
graph.nodes[DENOISE_LATENTS_HRF] =
denoiseLatentsHrfNode as DenoiseLatentsInvocation;
graph.nodes[NOISE_HRF] = hrfNoiseNode as NoiseInvocation;
graph.nodes[RESCALE_LATENTS] = rescaleLatentsNode as ResizeLatentsInvocation;
// Connect nodes.
graph.edges.push( graph.edges.push(
{ {
// Set up rescale latents.
source: { source: {
node_id: DENOISE_LATENTS, node_id: DENOISE_LATENTS,
field: 'latents', field: 'latents',
}, },
destination: { destination: {
node_id: RESCALE_LATENTS, node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'latents',
},
},
// Set up new noise node
{
source: {
node_id: RESCALE_LATENTS,
field: 'height',
},
destination: {
node_id: NOISE_HRF,
field: 'height',
},
},
{
source: {
node_id: RESCALE_LATENTS,
field: 'width',
},
destination: {
node_id: NOISE_HRF,
field: 'width',
},
},
// Set up new denoise node.
{
source: {
node_id: RESCALE_LATENTS,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
},
{
source: {
node_id: NOISE_HRF,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'noise',
},
},
// Set up new latents to image node.
{
source: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF,
field: 'latents', field: 'latents',
}, },
}, },
@ -225,17 +185,188 @@ export const addHrfToGraph = (
field: 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: LATENTS_TO_IMAGE_HRF, node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'vae', field: 'vae',
}, },
} }
); );
upsertMetadata(graph, { graph.nodes[RESIZE_HRF] = {
hrf_height: hrfHeight, id: RESIZE_HRF,
hrf_width: hrfWidth, type: 'img_resize',
hrf_strength: hrfStrength, is_intermediate: true,
width: width,
height: height,
};
if (hrfMethod == 'ESRGAN') {
let model_name: ESRGANInvocation['model_name'] = 'RealESRGAN_x2plus.pth';
if ((width * height) / (hrfWidth * hrfHeight) > 2) {
model_name = 'RealESRGAN_x4plus.pth';
}
graph.nodes[ESRGAN_HRF] = {
id: ESRGAN_HRF,
type: 'esrgan',
model_name,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: ESRGAN_HRF,
field: 'image',
},
},
{
source: {
node_id: ESRGAN_HRF,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
}
);
} else {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
}); });
}
graph.nodes[NOISE_HRF] = {
type: 'noise',
id: NOISE_HRF,
seed: originalNoiseNode?.seed,
use_cpu: originalNoiseNode?.use_cpu,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: RESIZE_HRF,
field: 'height',
},
destination: {
node_id: NOISE_HRF,
field: 'height',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'width',
},
destination: {
node_id: NOISE_HRF,
field: 'width',
},
}
);
graph.nodes[IMAGE_TO_LATENTS_HRF] = {
type: 'i2l',
id: IMAGE_TO_LATENTS_HRF,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'vae',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'image',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'image',
},
}
);
graph.nodes[DENOISE_LATENTS_HRF] = {
type: 'denoise_latents',
id: DENOISE_LATENTS_HRF,
is_intermediate: true,
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
scheduler: originalDenoiseLatentsNode?.scheduler,
steps: originalDenoiseLatentsNode?.steps,
denoising_start: 1 - state.generation.hrfStrength,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
},
{
source: {
node_id: NOISE_HRF,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'noise',
},
}
);
copyConnectionsToDenoiseLatentsHrf(graph); copyConnectionsToDenoiseLatentsHrf(graph);
graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_HR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'vae',
},
},
{
source: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'latents',
},
}
);
upsertMetadata(graph, {
hrf_strength: hrfStrength,
hrf_enabled: hrfEnabled,
hrf_method: hrfMethod,
});
}; };

View File

@ -5,7 +5,7 @@ import { SaveImageInvocation } from 'services/api/types';
import { import {
CANVAS_OUTPUT, CANVAS_OUTPUT,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF, LATENTS_TO_IMAGE_HRF_HR,
NSFW_CHECKER, NSFW_CHECKER,
SAVE_IMAGE, SAVE_IMAGE,
WATERMARKER, WATERMARKER,
@ -62,10 +62,10 @@ export const addSaveImageNode = (
}, },
destination, destination,
}); });
} else if (LATENTS_TO_IMAGE_HRF in graph.nodes) { } else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) {
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: LATENTS_TO_IMAGE_HRF, node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'image', field: 'image',
}, },
destination, destination,

View File

@ -4,7 +4,11 @@ export const NEGATIVE_CONDITIONING = 'negative_conditioning';
export const DENOISE_LATENTS = 'denoise_latents'; export const DENOISE_LATENTS = 'denoise_latents';
export const DENOISE_LATENTS_HRF = 'denoise_latents_hrf'; export const DENOISE_LATENTS_HRF = 'denoise_latents_hrf';
export const LATENTS_TO_IMAGE = 'latents_to_image'; export const LATENTS_TO_IMAGE = 'latents_to_image';
export const LATENTS_TO_IMAGE_HRF = 'latents_to_image_hrf'; export const LATENTS_TO_IMAGE_HRF_HR = 'latents_to_image_hrf_hr';
export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr';
export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf';
export const RESIZE_HRF = 'resize_hrf';
export const ESRGAN_HRF = 'esrgan_hrf';
export const SAVE_IMAGE = 'save_image'; export const SAVE_IMAGE = 'save_image';
export const NSFW_CHECKER = 'nsfw_checker'; export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark'; export const WATERMARKER = 'invisible_watermark';
@ -21,7 +25,6 @@ export const CLIP_SKIP = 'clip_skip';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';
export const RESCALE_LATENTS = 'rescale_latents';
export const IMG2IMG_RESIZE = 'img2img_resize'; export const IMG2IMG_RESIZE = 'img2img_resize';
export const CANVAS_OUTPUT = 'canvas_output'; export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT_IMAGE = 'inpaint_image'; export const INPAINT_IMAGE = 'inpaint_image';

View File

@ -7,10 +7,9 @@ import IAICollapse from 'common/components/IAICollapse';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ParamHrfHeight from './ParamHrfHeight';
import ParamHrfStrength from './ParamHrfStrength'; import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle'; import ParamHrfToggle from './ParamHrfToggle';
import ParamHrfWidth from './ParamHrfWidth'; import ParamHrfMethod from './ParamHrfMethod';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
@ -37,28 +36,11 @@ export default function ParamHrfCollapse() {
} }
return ( return (
<IAICollapse label="High Resolution Fix" activeLabel={activeLabel}> <IAICollapse label={t('hrf.hrf')} activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}> <Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamHrfToggle /> <ParamHrfToggle />
{hrfEnabled && ( <ParamHrfStrength />
<Flex <ParamHrfMethod />
sx={{
gap: 2,
p: 4,
borderRadius: 4,
flexDirection: 'column',
w: 'full',
bg: 'base.100',
_dark: {
bg: 'base.750',
},
}}
>
<ParamHrfWidth />
<ParamHrfHeight />
</Flex>
)}
{hrfEnabled && <ParamHrfStrength />}
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -1,87 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import {
setHrfHeight,
setHrfWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
function findPrevMultipleOfEight(n: number): number {
return Math.floor((n - 1) / 8) * 8;
}
const selector = createSelector(
[stateSelector],
({ generation, hotkeys, config }) => {
const { min, fineStep, coarseStep } = config.sd.height;
const { model, height, hrfHeight, aspectRatio, hrfEnabled } = generation;
const step = hotkeys.shift ? fineStep : coarseStep;
return {
model,
height,
hrfHeight,
min,
step,
aspectRatio,
hrfEnabled,
};
},
defaultSelectorOptions
);
type ParamHeightProps = Omit<
IAIFullSliderProps,
'label' | 'value' | 'onChange'
>;
const ParamHrfHeight = (props: ParamHeightProps) => {
const { height, hrfHeight, min, step, aspectRatio, hrfEnabled } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const maxHrfHeight = Math.max(findPrevMultipleOfEight(height), min);
const handleChange = useCallback(
(v: number) => {
dispatch(setHrfHeight(v));
if (aspectRatio) {
const newWidth = roundToMultiple(v * aspectRatio, 8);
dispatch(setHrfWidth(newWidth));
}
},
[dispatch, aspectRatio]
);
const handleReset = useCallback(() => {
dispatch(setHrfHeight(maxHrfHeight));
if (aspectRatio) {
const newWidth = roundToMultiple(maxHrfHeight * aspectRatio, 8);
dispatch(setHrfWidth(newWidth));
}
}, [dispatch, maxHrfHeight, aspectRatio]);
return (
<IAISlider
label="Initial Height"
value={hrfHeight}
min={min}
step={step}
max={maxHrfHeight}
onChange={handleChange}
handleReset={handleReset}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: maxHrfHeight }}
isDisabled={!hrfEnabled}
{...props}
/>
);
};
export default memo(ParamHrfHeight);

View File

@ -0,0 +1,49 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { setHrfMethod } from 'features/parameters/store/generationSlice';
import { HrfMethodParam } from 'features/parameters/types/parameterSchemas';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { hrfMethod, hrfEnabled } = generation;
return { hrfMethod, hrfEnabled };
},
defaultSelectorOptions
);
const DATA = ['ESRGAN', 'bilinear'];
// Dropdown selection for the type of high resolution fix method to use.
const ParamHrfMethodSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { hrfMethod, hrfEnabled } = useAppSelector(selector);
const handleChange = useCallback(
(v: HrfMethodParam | null) => {
if (!v) {
return;
}
dispatch(setHrfMethod(v));
},
[dispatch]
);
return (
<IAIMantineSelect
label={t('hrf.upscaleMethod')}
value={hrfMethod}
data={DATA}
onChange={handleChange}
disabled={!hrfEnabled}
/>
);
};
export default memo(ParamHrfMethodSelect);

View File

@ -5,6 +5,8 @@ import { memo, useCallback } from 'react';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { setHrfStrength } from 'features/parameters/store/generationSlice'; import { setHrfStrength } from 'features/parameters/store/generationSlice';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { Tooltip } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -31,6 +33,7 @@ const ParamHrfStrength = () => {
const { hrfStrength, initial, min, sliderMax, step, hrfEnabled } = const { hrfStrength, initial, min, sliderMax, step, hrfEnabled } =
useAppSelector(selector); useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleHrfStrengthReset = useCallback(() => { const handleHrfStrengthReset = useCallback(() => {
dispatch(setHrfStrength(initial)); dispatch(setHrfStrength(initial));
@ -44,9 +47,9 @@ const ParamHrfStrength = () => {
); );
return ( return (
<Tooltip label={t('hrf.strengthTooltip')} placement="right" hasArrow>
<IAISlider <IAISlider
label="Denoising Strength" label={t('parameters.denoisingStrength')}
aria-label="High Resolution Denoising Strength"
min={min} min={min}
max={sliderMax} max={sliderMax}
step={step} step={step}
@ -58,6 +61,7 @@ const ParamHrfStrength = () => {
handleReset={handleHrfStrengthReset} handleReset={handleHrfStrengthReset}
isDisabled={!hrfEnabled} isDisabled={!hrfEnabled}
/> />
</Tooltip>
); );
}; };

View File

@ -3,9 +3,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { setHrfEnabled } from 'features/parameters/store/generationSlice'; import { setHrfEnabled } from 'features/parameters/store/generationSlice';
import { ChangeEvent, useCallback } from 'react'; import { ChangeEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export default function ParamHrfToggle() { export default function ParamHrfToggle() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const hrfEnabled = useAppSelector( const hrfEnabled = useAppSelector(
(state: RootState) => state.generation.hrfEnabled (state: RootState) => state.generation.hrfEnabled
@ -19,9 +21,10 @@ export default function ParamHrfToggle() {
return ( return (
<IAISwitch <IAISwitch
label="Enable High Resolution Fix" label={t('hrf.enableHrf')}
isChecked={hrfEnabled} isChecked={hrfEnabled}
onChange={handleHrfEnabled} onChange={handleHrfEnabled}
tooltip={t('hrf.enableHrfTooltip')}
/> />
); );
} }

View File

@ -1,84 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import {
setHrfHeight,
setHrfWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
function findPrevMultipleOfEight(n: number): number {
return Math.floor((n - 1) / 8) * 8;
}
const selector = createSelector(
[stateSelector],
({ generation, hotkeys, config }) => {
const { min, fineStep, coarseStep } = config.sd.width;
const { model, width, hrfWidth, aspectRatio, hrfEnabled } = generation;
const step = hotkeys.shift ? fineStep : coarseStep;
return {
model,
width,
hrfWidth,
min,
step,
aspectRatio,
hrfEnabled,
};
},
defaultSelectorOptions
);
type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>;
const ParamHrfWidth = (props: ParamWidthProps) => {
const { width, hrfWidth, min, step, aspectRatio, hrfEnabled } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const maxHrfWidth = Math.max(findPrevMultipleOfEight(width), min);
const handleChange = useCallback(
(v: number) => {
dispatch(setHrfWidth(v));
if (aspectRatio) {
const newHeight = roundToMultiple(v / aspectRatio, 8);
dispatch(setHrfHeight(newHeight));
}
},
[dispatch, aspectRatio]
);
const handleReset = useCallback(() => {
dispatch(setHrfWidth(maxHrfWidth));
if (aspectRatio) {
const newHeight = roundToMultiple(maxHrfWidth / aspectRatio, 8);
dispatch(setHrfHeight(newHeight));
}
}, [dispatch, maxHrfWidth, aspectRatio]);
return (
<IAISlider
label="Initial Width"
value={hrfWidth}
min={min}
step={step}
max={maxHrfWidth}
onChange={handleChange}
handleReset={handleReset}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: maxHrfWidth }}
isDisabled={!hrfEnabled}
{...props}
/>
);
};
export default memo(ParamHrfWidth);

View File

@ -55,6 +55,9 @@ import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
setCfgScale, setCfgScale,
setHeight, setHeight,
setHrfEnabled,
setHrfMethod,
setHrfStrength,
setImg2imgStrength, setImg2imgStrength,
setNegativePrompt, setNegativePrompt,
setPositivePrompt, setPositivePrompt,
@ -67,6 +70,7 @@ import {
isValidCfgScale, isValidCfgScale,
isValidControlNetModel, isValidControlNetModel,
isValidHeight, isValidHeight,
isValidHrfMethod,
isValidIPAdapterModel, isValidIPAdapterModel,
isValidLoRAModel, isValidLoRAModel,
isValidMainModel, isValidMainModel,
@ -83,6 +87,7 @@ import {
isValidSteps, isValidSteps,
isValidStrength, isValidStrength,
isValidWidth, isValidWidth,
isValidBoolean,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
const selector = createSelector( const selector = createSelector(
@ -361,6 +366,51 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall high resolution enabled with toast
*/
const recallHrfEnabled = useCallback(
(hrfEnabled: unknown) => {
if (!isValidBoolean(hrfEnabled)) {
parameterNotSetToast();
return;
}
dispatch(setHrfEnabled(hrfEnabled));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution strength with toast
*/
const recallHrfStrength = useCallback(
(hrfStrength: unknown) => {
if (!isValidStrength(hrfStrength)) {
parameterNotSetToast();
return;
}
dispatch(setHrfStrength(hrfStrength));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution method with toast
*/
const recallHrfMethod = useCallback(
(hrfMethod: unknown) => {
if (!isValidHrfMethod(hrfMethod)) {
parameterNotSetToast();
return;
}
dispatch(setHrfMethod(hrfMethod));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall LoRA with toast * Recall LoRA with toast
*/ */
@ -711,6 +761,9 @@ export const useRecallParameters = () => {
steps, steps,
width, width,
strength, strength,
hrf_enabled,
hrf_strength,
hrf_method,
positive_style_prompt, positive_style_prompt,
negative_style_prompt, negative_style_prompt,
refiner_model, refiner_model,
@ -729,34 +782,55 @@ export const useRecallParameters = () => {
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale)); dispatch(setCfgScale(cfg_scale));
} }
if (isValidMainModel(model)) { if (isValidMainModel(model)) {
dispatch(modelSelected(model)); dispatch(modelSelected(model));
} }
if (isValidPositivePrompt(positive_prompt)) { if (isValidPositivePrompt(positive_prompt)) {
dispatch(setPositivePrompt(positive_prompt)); dispatch(setPositivePrompt(positive_prompt));
} }
if (isValidNegativePrompt(negative_prompt)) { if (isValidNegativePrompt(negative_prompt)) {
dispatch(setNegativePrompt(negative_prompt)); dispatch(setNegativePrompt(negative_prompt));
} }
if (isValidScheduler(scheduler)) { if (isValidScheduler(scheduler)) {
dispatch(setScheduler(scheduler)); dispatch(setScheduler(scheduler));
} }
if (isValidSeed(seed)) { if (isValidSeed(seed)) {
dispatch(setSeed(seed)); dispatch(setSeed(seed));
} }
if (isValidSteps(steps)) { if (isValidSteps(steps)) {
dispatch(setSteps(steps)); dispatch(setSteps(steps));
} }
if (isValidWidth(width)) { if (isValidWidth(width)) {
dispatch(setWidth(width)); dispatch(setWidth(width));
} }
if (isValidHeight(height)) { if (isValidHeight(height)) {
dispatch(setHeight(height)); dispatch(setHeight(height));
} }
if (isValidStrength(strength)) { if (isValidStrength(strength)) {
dispatch(setImg2imgStrength(strength)); dispatch(setImg2imgStrength(strength));
} }
if (isValidBoolean(hrf_enabled)) {
dispatch(setHrfEnabled(hrf_enabled));
}
if (isValidStrength(hrf_strength)) {
dispatch(setHrfStrength(hrf_strength));
}
if (isValidHrfMethod(hrf_method)) {
dispatch(setHrfMethod(hrf_method));
}
if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) { if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
} }
@ -862,6 +936,9 @@ export const useRecallParameters = () => {
recallWidth, recallWidth,
recallHeight, recallHeight,
recallStrength, recallStrength,
recallHrfEnabled,
recallHrfStrength,
recallHrfMethod,
recallLoRA, recallLoRA,
recallControlNet, recallControlNet,
recallIPAdapter, recallIPAdapter,

View File

@ -11,6 +11,7 @@ import {
CanvasCoherenceModeParam, CanvasCoherenceModeParam,
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
HrfMethodParam,
MainModelParam, MainModelParam,
MaskBlurMethodParam, MaskBlurMethodParam,
NegativePromptParam, NegativePromptParam,
@ -27,10 +28,9 @@ import {
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
export interface GenerationState { export interface GenerationState {
hrfHeight: HeightParam;
hrfWidth: WidthParam;
hrfEnabled: boolean; hrfEnabled: boolean;
hrfStrength: StrengthParam; hrfStrength: StrengthParam;
hrfMethod: HrfMethodParam;
cfgScale: CfgScaleParam; cfgScale: CfgScaleParam;
height: HeightParam; height: HeightParam;
img2imgStrength: StrengthParam; img2imgStrength: StrengthParam;
@ -73,10 +73,9 @@ export interface GenerationState {
} }
export const initialGenerationState: GenerationState = { export const initialGenerationState: GenerationState = {
hrfHeight: 64, hrfStrength: 0.45,
hrfWidth: 64,
hrfStrength: 0.75,
hrfEnabled: false, hrfEnabled: false,
hrfMethod: 'ESRGAN',
cfgScale: 7.5, cfgScale: 7.5,
height: 512, height: 512,
img2imgStrength: 0.75, img2imgStrength: 0.75,
@ -279,18 +278,15 @@ export const generationSlice = createSlice({
setClipSkip: (state, action: PayloadAction<number>) => { setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload; state.clipSkip = action.payload;
}, },
setHrfHeight: (state, action: PayloadAction<number>) => {
state.hrfHeight = action.payload;
},
setHrfWidth: (state, action: PayloadAction<number>) => {
state.hrfWidth = action.payload;
},
setHrfStrength: (state, action: PayloadAction<number>) => { setHrfStrength: (state, action: PayloadAction<number>) => {
state.hrfStrength = action.payload; state.hrfStrength = action.payload;
}, },
setHrfEnabled: (state, action: PayloadAction<boolean>) => { setHrfEnabled: (state, action: PayloadAction<boolean>) => {
state.hrfEnabled = action.payload; state.hrfEnabled = action.payload;
}, },
setHrfMethod: (state, action: PayloadAction<HrfMethodParam>) => {
state.hrfMethod = action.payload;
},
shouldUseCpuNoiseChanged: (state, action: PayloadAction<boolean>) => { shouldUseCpuNoiseChanged: (state, action: PayloadAction<boolean>) => {
state.shouldUseCpuNoise = action.payload; state.shouldUseCpuNoise = action.payload;
}, },
@ -375,10 +371,9 @@ export const {
setSeamlessXAxis, setSeamlessXAxis,
setSeamlessYAxis, setSeamlessYAxis,
setClipSkip, setClipSkip,
setHrfHeight,
setHrfWidth,
setHrfStrength,
setHrfEnabled, setHrfEnabled,
setHrfStrength,
setHrfMethod,
shouldUseCpuNoiseChanged, shouldUseCpuNoiseChanged,
setAspectRatio, setAspectRatio,
setShouldLockAspectRatio, setShouldLockAspectRatio,

View File

@ -400,6 +400,20 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
export const isValidPrecision = (val: unknown): val is PrecisionParam => export const isValidPrecision = (val: unknown): val is PrecisionParam =>
zPrecision.safeParse(val).success; zPrecision.safeParse(val).success;
/**
* Zod schema for a high resolution fix method parameter.
*/
export const zHrfMethod = z.enum(['ESRGAN', 'bilinear']);
/**
* Type alias for high resolution fix method parameter, inferred from its zod schema
*/
export type HrfMethodParam = z.infer<typeof zHrfMethod>;
/**
* Validates/type-guards a value as a high resolution fix method parameter
*/
export const isValidHrfMethod = (val: unknown): val is HrfMethodParam =>
zHrfMethod.safeParse(val).success;
/** /**
* Zod schema for SDXL refiner positive aesthetic score parameter * Zod schema for SDXL refiner positive aesthetic score parameter
*/ */
@ -482,6 +496,17 @@ export const isValidCoherenceModeParam = (
): val is CanvasCoherenceModeParam => ): val is CanvasCoherenceModeParam =>
zCanvasCoherenceMode.safeParse(val).success; zCanvasCoherenceMode.safeParse(val).success;
/**
* Zod schema for a boolean.
*/
export const zBoolean = z.boolean();
/**
* Validates/type-guards a value as a boolean parameter
*/
export const isValidBoolean = (val: unknown): val is boolean =>
zBoolean.safeParse(val).success && val !== null && val !== undefined;
// /** // /**
// * Zod schema for BaseModelType // * Zod schema for BaseModelType
// */ // */

View File

@ -70,7 +70,7 @@ export const initialConfigState: AppConfig = {
coarseStep: 0.05, coarseStep: 0.05,
}, },
hrfStrength: { hrfStrength: {
initial: 0.7, initial: 0.45,
min: 0, min: 0,
sliderMax: 1, sliderMax: 1,
inputMax: 1, inputMax: 1,

File diff suppressed because one or more lines are too long

View File

@ -127,7 +127,6 @@ export type CompelInvocation = s['CompelInvocation'];
export type DynamicPromptInvocation = s['DynamicPromptInvocation']; export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
export type NoiseInvocation = s['NoiseInvocation']; export type NoiseInvocation = s['NoiseInvocation'];
export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation']; export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation'];
export type ResizeLatentsInvocation = s['ResizeLatentsInvocation'];
export type ONNXTextToLatentsInvocation = s['ONNXTextToLatentsInvocation']; export type ONNXTextToLatentsInvocation = s['ONNXTextToLatentsInvocation'];
export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation']; export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation'];
export type ImageToLatentsInvocation = s['ImageToLatentsInvocation']; export type ImageToLatentsInvocation = s['ImageToLatentsInvocation'];