updated simple upscale to use spandrel node and list of available spandrel models

This commit is contained in:
chainchompa 2024-07-23 10:15:31 -04:00
parent bc1d9748ce
commit c098edc6b2
7 changed files with 82 additions and 99 deletions

View File

@ -18,7 +18,6 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
const log = logger('session');
const { imageDTO } = action.payload;
const { image_name } = imageDTO;
const state = getState();
const { isAllowedToUpscale, detailTKey } = createIsAllowedToUpscaleSelector(imageDTO)(state);
@ -41,7 +40,7 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
prepend: true,
batch: {
graph: buildAdHocUpscaleGraph({
image_name,
image: imageDTO,
state,
}),
runs: 1,

View File

@ -1,39 +1,40 @@
import type { RootState } from 'app/store/store';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
import type { Graph, ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { assert } from 'tsafe';
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
import { ESRGAN } from './constants';
import { SPANDREL } from './constants';
type Arg = {
image_name: string;
image: ImageDTO;
state: RootState;
};
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing;
export const buildAdHocUpscaleGraph = ({ image, state }: Arg): Graph => {
const { simpleUpscaleModel } = state.upscale;
const realesrganNode: Invocation<'esrgan'> = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: false,
board: getBoardField(state),
assert(simpleUpscaleModel, 'No upscale model found in state');
const upscaleNode: Invocation<'spandrel_image_to_image'> = {
id: SPANDREL,
type: 'spandrel_image_to_image',
image_to_image_model: simpleUpscaleModel,
tile_size: 500,
image,
};
const graph: NonNullableGraph = {
id: `adhoc-esrgan-graph`,
id: `adhoc-upscale-graph`,
nodes: {
[ESRGAN]: realesrganNode,
[SPANDREL]: upscaleNode,
},
edges: [],
};
addCoreMetadataNode(graph, {}, ESRGAN);
addCoreMetadataNode(graph, {}, SPANDREL);
upsertMetadata(graph, {
esrgan_model: esrganModelName,
spandrel_model: simpleUpscaleModel,
});
return graph;
};
};

View File

@ -1,72 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { esrganModelNameChanged, isParamESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const options: GroupBase<ComboboxOption>[] = [
{
label: 'x2 Upscalers',
options: [
{
label: 'RealESRGAN x2 Plus',
value: 'RealESRGAN_x2plus.pth',
description: 'Attempts to retain sharpness, low smoothing',
},
],
},
{
label: 'x4 Upscalers',
options: [
{
label: 'RealESRGAN x4 Plus',
value: 'RealESRGAN_x4plus.pth',
description: 'Best for photos and highly detailed images, medium smoothing',
},
{
label: 'RealESRGAN x4 Plus (anime 6B)',
value: 'RealESRGAN_x4plus_anime_6B.pth',
description: 'Best for anime/manga, high smoothing',
},
{
label: 'ESRGAN SRx4',
value: 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
description: 'Retains sharpness, low smoothing',
},
],
},
];
const ParamESRGANModel = () => {
const { t } = useTranslation();
const esrganModelName = useAppSelector((s) => s.postprocessing.esrganModelName);
const dispatch = useAppDispatch();
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParamESRGANModelName(v?.value)) {
return;
}
dispatch(esrganModelNameChanged(v.value));
},
[dispatch]
);
const value = useMemo(
() => options.flatMap((o) => o.options).find((m) => m.value === esrganModelName),
[esrganModelName]
);
return (
<FormControl orientation="vertical">
<FormLabel>{t('models.esrganModel')}</FormLabel>
<Combobox value={value} onChange={onChange} options={options} />
</FormControl>
);
};
export default memo(ParamESRGANModel);

View File

@ -0,0 +1,46 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { simpleUpscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
const ParamSimpleUpscale = () => {
const { t } = useTranslation();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const model = useAppSelector((s) => s.upscale.simpleUpscaleModel);
const dispatch = useAppDispatch();
const _onChange = useCallback(
(v: SpandrelImageToImageModelConfig | null) => {
dispatch(simpleUpscaleModelChanged(v));
},
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: model,
isLoading,
});
return (
<FormControl orientation="vertical">
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
);
};
export default memo(ParamSimpleUpscale);

View File

@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiFrameCornersBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import ParamESRGANModel from './ParamRealESRGANModel';
import ParamSimpleUpscale from './ParamSimpleUpscale';
type Props = { imageDTO?: ImageDTO };
@ -49,9 +49,9 @@ const ParamUpscalePopover = (props: Props) => {
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody minW={96}>
<PopoverBody w={96}>
<Flex flexDirection="column" gap={4}>
<ParamESRGANModel />
<ParamSimpleUpscale />
<UpscaleWarning usesTile={false} />
<Button
tooltip={detail}

View File

@ -1,6 +1,6 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectPostprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -55,13 +55,16 @@ const getDetailTKey = (isAllowedToUpscale?: ReturnType<typeof getIsAllowedToUpsc
};
export const createIsAllowedToUpscaleSelector = (imageDTO?: ImageDTO) =>
createMemoizedSelector(selectPostprocessingSlice, selectConfigSlice, (postprocessing, config) => {
const { esrganModelName } = postprocessing;
createMemoizedSelector(selectUpscalelice, selectConfigSlice, (upscale, config) => {
const { simpleUpscaleModel } = upscale;
const { maxUpscalePixels } = config;
if (!simpleUpscaleModel) {
return { isAllowedToUpscale: false, detailTKey: undefined };
}
const upscaledPixels = getUpscaledPixels(imageDTO, maxUpscalePixels);
const isAllowedToUpscale = getIsAllowedToUpscale(upscaledPixels, maxUpscalePixels);
const scaleFactor = esrganModelName.includes('x2') ? 2 : 4;
const scaleFactor = simpleUpscaleModel.name.includes('x2') ? 2 : 4;
const detailTKey = getDetailTKey(isAllowedToUpscale, scaleFactor);
return {
isAllowedToUpscale: scaleFactor === 2 ? isAllowedToUpscale.x2 : isAllowedToUpscale.x4,

View File

@ -12,6 +12,7 @@ interface UpscaleState {
creativity: number;
tileControlnetModel: ControlNetModelConfig | null;
scale: number;
simpleUpscaleModel: ParameterSpandrelImageToImageModel | null;
}
const initialUpscaleState: UpscaleState = {
@ -22,6 +23,7 @@ const initialUpscaleState: UpscaleState = {
creativity: 0,
tileControlnetModel: null,
scale: 4,
simpleUpscaleModel: null,
};
export const upscaleSlice = createSlice({
@ -46,6 +48,9 @@ export const upscaleSlice = createSlice({
scaleChanged: (state, action: PayloadAction<number>) => {
state.scale = action.payload;
},
simpleUpscaleModelChanged: (state, action: PayloadAction<ParameterSpandrelImageToImageModel | null>) => {
state.simpleUpscaleModel = action.payload;
},
},
});
@ -56,6 +61,7 @@ export const {
creativityChanged,
tileControlnetModelChanged,
scaleChanged,
simpleUpscaleModelChanged,
} = upscaleSlice.actions;
export const selectUpscalelice = (state: RootState) => state.upscale;