feat(ui): add scale param to upscaling tab

This commit is contained in:
psychedelicious 2024-07-23 10:01:07 +10:00
parent 7cee4e42a7
commit f9d3966ea2
6 changed files with 74 additions and 72 deletions

View File

@ -1649,11 +1649,13 @@
"structure": "Structure", "structure": "Structure",
"toInstall": "to install", "toInstall": "to install",
"upscaleModel": "Upcale Model", "upscaleModel": "Upcale Model",
"scale": "Scale",
"visit": "Visit", "visit": "Visit",
"warningNoMainModel": "a model", "warningNoMainModel": "a model",
"warningNoTile": "a {{base_model}} tile controlnet required by this feature", "warningNoTile": "a {{base_model}} tile controlnet required by this feature",
"warningNoTileOrUpscaleModel": "an upscaler model and {{base_model}} tile controlnet required by this feature", "warningNoTileOrUpscaleModel": "an upscaler model and {{base_model}} tile controlnet required by this feature",
"warningNoUpscaleModel": "an upscaler model required by this feature", "warningNoUpscaleModel": "an upscaler model required by this feature",
"upscalingFromTo": "Upscaling from {{from}} to {{to}}"
}, },
"ui": { "ui": {
"tabs": { "tabs": {

View File

@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph'; import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ImageDTO } from 'services/api/types';
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -15,7 +14,6 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RESIZE,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SPANDREL, SPANDREL,
TILED_MULTI_DIFFUSION_DENOISE_LATENTS, TILED_MULTI_DIFFUSION_DENOISE_LATENTS,
@ -26,27 +24,17 @@ import { addLoRAs } from './generation/addLoRAs';
import { addSDXLLoRas } from './generation/addSDXLLoRAs'; import { addSDXLLoRas } from './generation/addSDXLLoRAs';
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils'; import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
const UPSCALE_SCALE = 2;
export const getOutputImageSize = (initialImage: ImageDTO) => {
return {
width: ((initialImage.width * UPSCALE_SCALE) / 8) * 8,
height: ((initialImage.height * UPSCALE_SCALE) / 8) * 8,
};
};
export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => { export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => {
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.generation; const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present; const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { upscaleModel, upscaleInitialImage, sharpness, structure, creativity, tileControlnetModel } = state.upscale; const { upscaleModel, upscaleInitialImage, sharpness, structure, creativity, tileControlnetModel, scale } =
state.upscale;
assert(model, 'No model found in state'); assert(model, 'No model found in state');
assert(upscaleModel, 'No upscale model found in state'); assert(upscaleModel, 'No upscale model found in state');
assert(upscaleInitialImage, 'No initial image found in state'); assert(upscaleInitialImage, 'No initial image found in state');
assert(tileControlnetModel, 'Tile controlnet is required'); assert(tileControlnetModel, 'Tile controlnet is required');
const { width: outputWidth, height: outputHeight } = getOutputImageSize(upscaleInitialImage);
const g = new Graph(); const g = new Graph();
const unsharpMaskNode1 = g.addNode({ const unsharpMaskNode1 = g.addNode({
@ -61,7 +49,8 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
id: SPANDREL, id: SPANDREL,
type: 'spandrel_image_to_image', type: 'spandrel_image_to_image',
image_to_image_model: upscaleModel, image_to_image_model: upscaleModel,
tile_size: 500, fit_to_multiple_of_8: true,
scale,
}); });
g.addEdge(unsharpMaskNode1, 'image', upscaleNode, 'image'); g.addEdge(unsharpMaskNode1, 'image', upscaleNode, 'image');
@ -75,24 +64,14 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image'); g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image');
const resizeNode = g.addNode({
id: RESIZE,
type: 'img_resize',
width: outputWidth,
height: outputHeight,
resample_mode: 'lanczos',
});
g.addEdge(unsharpMaskNode2, 'image', resizeNode, 'image');
const noiseNode = g.addNode({ const noiseNode = g.addNode({
id: NOISE, id: NOISE,
type: 'noise', type: 'noise',
seed, seed,
}); });
g.addEdge(resizeNode, 'width', noiseNode, 'width'); g.addEdge(unsharpMaskNode2, 'width', noiseNode, 'width');
g.addEdge(resizeNode, 'height', noiseNode, 'height'); g.addEdge(unsharpMaskNode2, 'height', noiseNode, 'height');
const i2lNode = g.addNode({ const i2lNode = g.addNode({
id: IMAGE_TO_LATENTS, id: IMAGE_TO_LATENTS,
@ -101,7 +80,7 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
tiled: true, tiled: true,
}); });
g.addEdge(resizeNode, 'image', i2lNode, 'image'); g.addEdge(unsharpMaskNode2, 'image', i2lNode, 'image');
const l2iNode = g.addNode({ const l2iNode = g.addNode({
type: 'l2i', type: 'l2i',
@ -160,8 +139,6 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
g.upsertMetadata({ g.upsertMetadata({
cfg_scale, cfg_scale,
height: outputHeight,
width: outputWidth,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt, positive_style_prompt: positiveStylePrompt,
@ -204,8 +181,6 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
g.upsertMetadata({ g.upsertMetadata({
cfg_scale, cfg_scale,
height: outputHeight,
width: outputWidth,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig), model: Graph.getModelMetadataField(modelConfig),
@ -221,6 +196,8 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
} }
g.setMetadataReceivingNode(l2iNode); g.setMetadataReceivingNode(l2iNode);
g.addEdgeToMetadata(upscaleNode, 'width', 'width');
g.addEdgeToMetadata(upscaleNode, 'height', 'height');
let vaeNode; let vaeNode;
if (vae) { if (vae) {
@ -252,7 +229,7 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
end_step_percent: (structure + 10) * 0.025 + 0.3, end_step_percent: (structure + 10) * 0.025 + 0.3,
}); });
g.addEdge(resizeNode, 'image', controlnetNode1, 'image'); g.addEdge(unsharpMaskNode2, 'image', controlnetNode1, 'image');
const controlnetNode2 = g.addNode({ const controlnetNode2 = g.addNode({
id: 'controlnet_2', id: 'controlnet_2',
@ -265,7 +242,7 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
end_step_percent: 0.85, end_step_percent: 0.85,
}); });
g.addEdge(resizeNode, 'image', controlnetNode2, 'image'); g.addEdge(unsharpMaskNode2, 'image', controlnetNode2, 'image');
const collectNode = g.addNode({ const collectNode = g.addNode({
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,

View File

@ -12,6 +12,7 @@ interface UpscaleState {
structure: number; structure: number;
creativity: number; creativity: number;
tileControlnetModel: ControlNetModelConfig | null; tileControlnetModel: ControlNetModelConfig | null;
scale: number;
} }
const initialUpscaleState: UpscaleState = { const initialUpscaleState: UpscaleState = {
@ -22,6 +23,7 @@ const initialUpscaleState: UpscaleState = {
structure: 0, structure: 0,
creativity: 0, creativity: 0,
tileControlnetModel: null, tileControlnetModel: null,
scale: 4,
}; };
export const upscaleSlice = createSlice({ export const upscaleSlice = createSlice({
@ -46,6 +48,9 @@ export const upscaleSlice = createSlice({
tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => { tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => {
state.tileControlnetModel = action.payload; state.tileControlnetModel = action.payload;
}, },
scaleChanged: (state, action: PayloadAction<number>) => {
state.scale = action.payload;
},
}, },
}); });
@ -56,6 +61,7 @@ export const {
creativityChanged, creativityChanged,
sharpnessChanged, sharpnessChanged,
tileControlnetModelChanged, tileControlnetModelChanged,
scaleChanged,
} = upscaleSlice.actions; } = upscaleSlice.actions;
export const selectUpscalelice = (state: RootState) => state.upscale; export const selectUpscalelice = (state: RootState) => state.upscale;

View File

@ -0,0 +1,39 @@
import { CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { scaleChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const marks = [2, 4, 8, 16];
const formatValue = (val: number) => `${val}x`;
export const UpscaleScaleSlider = memo(() => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const scale = useAppSelector((s) => s.upscale.scale);
const onChange = useCallback(
(val: number) => {
dispatch(scaleChanged(val));
},
[dispatch]
);
return (
<FormControl orientation="vertical" gap={0}>
<FormLabel m={0}>{t('upscaling.scale')}</FormLabel>
<CompositeSlider
min={2}
max={16}
value={scale}
onChange={onChange}
marks={marks}
formatValue={formatValue}
withThumbTooltip
/>
</FormControl>
);
});
UpscaleScaleSlider.displayName = 'UpscaleScaleSlider';

View File

@ -6,6 +6,7 @@ import ParamSharpness from 'features/parameters/components/Upscale/ParamSharpnes
import ParamSpandrelModel from 'features/parameters/components/Upscale/ParamSpandrelModel'; import ParamSpandrelModel from 'features/parameters/components/Upscale/ParamSpandrelModel';
import ParamStructure from 'features/parameters/components/Upscale/ParamStructure'; import ParamStructure from 'features/parameters/components/Upscale/ParamStructure';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice'; import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
import { UpscaleScaleSlider } from 'features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleScaleSlider';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { memo } from 'react'; import { memo } from 'react';
@ -13,13 +14,22 @@ import { useTranslation } from 'react-i18next';
import { MultidiffusionWarning } from './MultidiffusionWarning'; import { MultidiffusionWarning } from './MultidiffusionWarning';
import { UpscaleInitialImage } from './UpscaleInitialImage'; import { UpscaleInitialImage } from './UpscaleInitialImage';
import { UpscaleSizeDetails } from './UpscaleSizeDetails';
const selector = createMemoizedSelector([selectUpscalelice], (upscale) => { const selector = createMemoizedSelector([selectUpscalelice], (upscaleSlice) => {
const { upscaleModel, upscaleInitialImage, scale } = upscaleSlice;
const badges: string[] = []; const badges: string[] = [];
if (upscale.upscaleModel) { if (upscaleModel) {
badges.push(upscale.upscaleModel.name); badges.push(upscaleModel.name);
}
if (upscaleInitialImage) {
// Output height and width are scaled and rounded down to the nearest multiple of 8
const outputWidth = Math.floor((upscaleInitialImage.width * scale) / 8) * 8;
const outputHeight = Math.floor((upscaleInitialImage.height * scale) / 8) * 8;
badges.push(`${outputWidth}×${outputHeight}`);
} }
return { badges }; return { badges };
@ -43,9 +53,9 @@ export const UpscaleSettingsAccordion = memo(() => {
<Flex pt={4} px={4} w="full" h="full" flexDir="column" data-testid="image-settings-accordion"> <Flex pt={4} px={4} w="full" h="full" flexDir="column" data-testid="image-settings-accordion">
<Flex gap={4}> <Flex gap={4}>
<UpscaleInitialImage /> <UpscaleInitialImage />
<Flex direction="column" w="full" alignItems="center" gap={4}> <Flex direction="column" w="full" alignItems="center" gap={2}>
<ParamSpandrelModel /> <ParamSpandrelModel />
<UpscaleSizeDetails /> <UpscaleScaleSlider />
<MultidiffusionWarning /> <MultidiffusionWarning />
</Flex> </Flex>
</Flex> </Flex>

View File

@ -1,32 +0,0 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { getOutputImageSize } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const UpscaleSizeDetails = () => {
const { t } = useTranslation();
const { upscaleInitialImage } = useAppSelector((s) => s.upscale);
const outputSizeText = useMemo(() => {
if (upscaleInitialImage) {
const { width, height } = getOutputImageSize(upscaleInitialImage);
return `${t('upscaling.outputImageSize')}: ${width}×${height}`;
}
}, [upscaleInitialImage, t]);
if (!outputSizeText || !upscaleInitialImage) {
return <></>;
}
return (
<Flex direction="column">
<Text variant="subtext" fontWeight="bold">
{t('upscaling.currentImageSize')}: {upscaleInitialImage.width}×{upscaleInitialImage.height}
</Text>
<Text variant="subtext" fontWeight="bold">
{outputSizeText}
</Text>
</Flex>
);
};