Update Simple Upscale Button to work with spandrel models (#6649)

## Summary
Update Simple Upscale Button to work with spandrel models, add
UpscaleWarning when models aren't available, clean up ESRGAN logic
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
This commit is contained in:
chainchompa 2024-07-23 13:33:01 -04:00 committed by GitHub
commit 075e0405f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 83 additions and 180 deletions

View File

@ -10,7 +10,7 @@ import { heightChanged, widthChanged } from 'features/controlLayers/store/contro
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { simpleUpscaleModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
@ -186,21 +186,23 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
};
const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => {
const currentUpscaleModel = state.upscale.upscaleModel;
const { upscaleModel: currentUpscaleModel, simpleUpscaleModel: currentSimpleUpscaleModel } = state.upscale;
const upscaleModels = models.filter(isSpandrelImageToImageModelConfig);
const firstModel = upscaleModels[0] || null;
if (currentUpscaleModel) {
const isCurrentUpscaleModelAvailable = upscaleModels.some((m) => m.key === currentUpscaleModel.key);
if (isCurrentUpscaleModelAvailable) {
return;
}
}
const isCurrentUpscaleModelAvailable = currentUpscaleModel
? upscaleModels.some((m) => m.key === currentUpscaleModel.key)
: false;
const firstModel = upscaleModels[0];
if (firstModel) {
if (!isCurrentUpscaleModelAvailable) {
dispatch(upscaleModelChanged(firstModel));
return;
}
dispatch(upscaleModelChanged(null));
const isCurrentSimpleUpscaleModelAvailable = currentSimpleUpscaleModel
? upscaleModels.some((m) => m.key === currentSimpleUpscaleModel.key)
: false;
if (!isCurrentSimpleUpscaleModelAvailable) {
dispatch(simpleUpscaleModelChanged(firstModel));
}
};

View File

@ -18,7 +18,6 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
const log = logger('session');
const { imageDTO } = action.payload;
const { image_name } = imageDTO;
const state = getState();
const { isAllowedToUpscale, detailTKey } = createIsAllowedToUpscaleSelector(imageDTO)(state);
@ -40,8 +39,8 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: buildAdHocUpscaleGraph({
image_name,
graph: await buildAdHocUpscaleGraph({
image: imageDTO,
state,
}),
runs: 1,

View File

@ -25,7 +25,6 @@ import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/no
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
@ -53,7 +52,6 @@ const allReducers = {
[gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[postprocessingSlice.name]: postprocessingSlice.reducer,
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
@ -104,7 +102,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[generationPersistConfig.name]: generationPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[postprocessingPersistConfig.name]: postprocessingPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[workflowPersistConfig.name]: workflowPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,

View File

@ -1,38 +1,46 @@
import type { RootState } from 'app/store/store';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageDTO,
type Invocation,
isSpandrelImageToImageModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { assert } from 'tsafe';
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
import { ESRGAN } from './constants';
import { addCoreMetadataNode, getModelMetadataField, upsertMetadata } from './canvas/metadata';
import { SPANDREL } from './constants';
type Arg = {
image_name: string;
image: ImageDTO;
state: RootState;
};
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing;
export const buildAdHocUpscaleGraph = async ({ image, state }: Arg): Promise<NonNullableGraph> => {
const { simpleUpscaleModel } = state.upscale;
const realesrganNode: Invocation<'esrgan'> = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: false,
board: getBoardField(state),
assert(simpleUpscaleModel, 'No upscale model found in state');
const upscaleNode: Invocation<'spandrel_image_to_image'> = {
id: SPANDREL,
type: 'spandrel_image_to_image',
image_to_image_model: simpleUpscaleModel,
tile_size: 500,
image,
};
const graph: NonNullableGraph = {
id: `adhoc-esrgan-graph`,
id: `adhoc-upscale-graph`,
nodes: {
[ESRGAN]: realesrganNode,
[SPANDREL]: upscaleNode,
},
edges: [],
};
const modelConfig = await fetchModelConfigWithTypeGuard(simpleUpscaleModel.key, isSpandrelImageToImageModelConfig);
addCoreMetadataNode(graph, {}, ESRGAN);
addCoreMetadataNode(graph, {}, SPANDREL);
upsertMetadata(graph, {
esrgan_model: esrganModelName,
upscale_model: getModelMetadataField(modelConfig),
});
return graph;

View File

@ -36,7 +36,6 @@ export const CONTROL_NET_COLLECT = 'control_net_collect';
export const IP_ADAPTER_COLLECT = 'ip_adapter_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const METADATA = 'core_metadata';
export const ESRGAN = 'esrgan';
export const SPANDREL = 'spandrel';
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_DENOISE_LATENTS = 'sdxl_denoise_latents';

View File

@ -1,72 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { esrganModelNameChanged, isParamESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const options: GroupBase<ComboboxOption>[] = [
{
label: 'x2 Upscalers',
options: [
{
label: 'RealESRGAN x2 Plus',
value: 'RealESRGAN_x2plus.pth',
description: 'Attempts to retain sharpness, low smoothing',
},
],
},
{
label: 'x4 Upscalers',
options: [
{
label: 'RealESRGAN x4 Plus',
value: 'RealESRGAN_x4plus.pth',
description: 'Best for photos and highly detailed images, medium smoothing',
},
{
label: 'RealESRGAN x4 Plus (anime 6B)',
value: 'RealESRGAN_x4plus_anime_6B.pth',
description: 'Best for anime/manga, high smoothing',
},
{
label: 'ESRGAN SRx4',
value: 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
description: 'Retains sharpness, low smoothing',
},
],
},
];
const ParamESRGANModel = () => {
const { t } = useTranslation();
const esrganModelName = useAppSelector((s) => s.postprocessing.esrganModelName);
const dispatch = useAppDispatch();
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParamESRGANModelName(v?.value)) {
return;
}
dispatch(esrganModelNameChanged(v.value));
},
[dispatch]
);
const value = useMemo(
() => options.flatMap((o) => o.options).find((m) => m.value === esrganModelName),
[esrganModelName]
);
return (
<FormControl orientation="vertical">
<FormLabel>{t('models.esrganModel')}</FormLabel>
<Combobox value={value} onChange={onChange} options={options} />
</FormControl>
);
};
export default memo(ParamESRGANModel);

View File

@ -1,17 +1,21 @@
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { simpleUpscaleModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
const ParamSpandrelModel = () => {
interface Props {
isMultidiffusion: boolean;
}
const ParamSpandrelModel = ({ isMultidiffusion }: Props) => {
const { t } = useTranslation();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const model = useAppSelector((s) => s.upscale.upscaleModel);
const model = useAppSelector((s) => (isMultidiffusion ? s.upscale.upscaleModel : s.upscale.simpleUpscaleModel));
const dispatch = useAppDispatch();
const tooltipLabel = useMemo(() => {
@ -23,9 +27,13 @@ const ParamSpandrelModel = () => {
const _onChange = useCallback(
(v: SpandrelImageToImageModelConfig | null) => {
dispatch(upscaleModelChanged(v));
if (isMultidiffusion) {
dispatch(upscaleModelChanged(v));
} else {
dispatch(simpleUpscaleModelChanged(v));
}
},
[dispatch]
[isMultidiffusion, dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({

View File

@ -12,12 +12,13 @@ import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listen
import { useAppDispatch } from 'app/store/storeHooks';
import { useIsAllowedToUpscale } from 'features/parameters/hooks/useIsAllowedToUpscale';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { UpscaleWarning } from 'features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFrameCornersBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import ParamESRGANModel from './ParamRealESRGANModel';
import ParamSpandrelModel from './ParamSpandrelModel';
type Props = { imageDTO?: ImageDTO };
@ -48,9 +49,10 @@ const ParamUpscalePopover = (props: Props) => {
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody minW={96}>
<PopoverBody w={96}>
<Flex flexDirection="column" gap={4}>
<ParamESRGANModel />
<ParamSpandrelModel isMultidiffusion={false} />
<UpscaleWarning usesTile={false} />
<Button
tooltip={detail}
size="sm"

View File

@ -1,6 +1,6 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectPostprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -55,13 +55,16 @@ const getDetailTKey = (isAllowedToUpscale?: ReturnType<typeof getIsAllowedToUpsc
};
export const createIsAllowedToUpscaleSelector = (imageDTO?: ImageDTO) =>
createMemoizedSelector(selectPostprocessingSlice, selectConfigSlice, (postprocessing, config) => {
const { esrganModelName } = postprocessing;
createMemoizedSelector(selectUpscalelice, selectConfigSlice, (upscale, config) => {
const { simpleUpscaleModel } = upscale;
const { maxUpscalePixels } = config;
if (!simpleUpscaleModel) {
return { isAllowedToUpscale: false, detailTKey: undefined };
}
const upscaledPixels = getUpscaledPixels(imageDTO, maxUpscalePixels);
const isAllowedToUpscale = getIsAllowedToUpscale(upscaledPixels, maxUpscalePixels);
const scaleFactor = esrganModelName.includes('x2') ? 2 : 4;
const scaleFactor = simpleUpscaleModel.name.includes('x2') ? 2 : 4;
const detailTKey = getDetailTKey(isAllowedToUpscale, scaleFactor);
return {
isAllowedToUpscale: scaleFactor === 2 ? isAllowedToUpscale.x2 : isAllowedToUpscale.x4,

View File

@ -1,53 +0,0 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { z } from 'zod';
const zParamESRGANModelName = z.enum([
'RealESRGAN_x4plus.pth',
'RealESRGAN_x4plus_anime_6B.pth',
'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
'RealESRGAN_x2plus.pth',
]);
type ParamESRGANModelName = z.infer<typeof zParamESRGANModelName>;
export const isParamESRGANModelName = (v: unknown): v is ParamESRGANModelName =>
zParamESRGANModelName.safeParse(v).success;
interface PostprocessingState {
_version: 1;
esrganModelName: ParamESRGANModelName;
}
const initialPostprocessingState: PostprocessingState = {
_version: 1,
esrganModelName: 'RealESRGAN_x4plus.pth',
};
export const postprocessingSlice = createSlice({
name: 'postprocessing',
initialState: initialPostprocessingState,
reducers: {
esrganModelNameChanged: (state, action: PayloadAction<ParamESRGANModelName>) => {
state.esrganModelName = action.payload;
},
},
});
export const { esrganModelNameChanged } = postprocessingSlice.actions;
export const selectPostprocessingSlice = (state: RootState) => state.postprocessing;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migratePostprocessingState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const postprocessingPersistConfig: PersistConfig<PostprocessingState> = {
name: postprocessingSlice.name,
initialState: initialPostprocessingState,
migrate: migratePostprocessingState,
persistDenylist: [],
};

View File

@ -12,6 +12,7 @@ interface UpscaleState {
creativity: number;
tileControlnetModel: ControlNetModelConfig | null;
scale: number;
simpleUpscaleModel: ParameterSpandrelImageToImageModel | null;
}
const initialUpscaleState: UpscaleState = {
@ -22,6 +23,7 @@ const initialUpscaleState: UpscaleState = {
creativity: 0,
tileControlnetModel: null,
scale: 4,
simpleUpscaleModel: null,
};
export const upscaleSlice = createSlice({
@ -46,6 +48,9 @@ export const upscaleSlice = createSlice({
scaleChanged: (state, action: PayloadAction<number>) => {
state.scale = action.payload;
},
simpleUpscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.simpleUpscaleModel = action.payload;
},
},
});
@ -56,6 +61,7 @@ export const {
creativityChanged,
tileControlnetModelChanged,
scaleChanged,
simpleUpscaleModelChanged,
} = upscaleSlice.actions;
export const selectUpscalelice = (state: RootState) => state.upscale;

View File

@ -11,8 +11,8 @@ import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { MultidiffusionWarning } from './MultidiffusionWarning';
import { UpscaleInitialImage } from './UpscaleInitialImage';
import { UpscaleWarning } from './UpscaleWarning';
const selector = createMemoizedSelector([selectUpscalelice], (upscaleSlice) => {
const { upscaleModel, upscaleInitialImage, scale } = upscaleSlice;
@ -54,11 +54,11 @@ export const UpscaleSettingsAccordion = memo(() => {
<Flex gap={4}>
<UpscaleInitialImage />
<Flex direction="column" w="full" alignItems="center" gap={2}>
<ParamSpandrelModel />
<ParamSpandrelModel isMultidiffusion={true} />
<UpscaleScaleSlider />
</Flex>
</Flex>
<MultidiffusionWarning />
<UpscaleWarning usesTile={true} />
</Flex>
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} pb={4} flexDir="column">

View File

@ -7,7 +7,11 @@ import { useCallback, useEffect, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
export const MultidiffusionWarning = () => {
interface Props {
usesTile: boolean;
}
export const UpscaleWarning = ({ usesTile }: Props) => {
const { t } = useTranslation();
const model = useAppSelector((s) => s.generation.model);
const { tileControlnetModel, upscaleModel } = useAppSelector((s) => s.upscale);
@ -28,14 +32,14 @@ export const MultidiffusionWarning = () => {
if (!model) {
_warnings.push(t('upscaling.mainModelDesc'));
}
if (!tileControlnetModel) {
if (!tileControlnetModel && usesTile) {
_warnings.push(t('upscaling.tileControlNetModelDesc'));
}
if (!upscaleModel) {
_warnings.push(t('upscaling.upscaleModelDesc'));
}
return _warnings;
}, [model, upscaleModel, tileControlnetModel, t]);
}, [model, upscaleModel, tileControlnetModel, usesTile, t]);
const handleGoToModelManager = useCallback(() => {
dispatch(setActiveTab('models'));