feat(ui): restore ad-hoc upscaling

- remove face restoration entirely
- add dropdown for ESRGAN model select
- add ad-hoc upscaling graph and workflow
This commit is contained in:
psychedelicious 2023-07-17 21:08:59 +10:00
parent be659364c2
commit afa84a149c
29 changed files with 229 additions and 728 deletions

View File

@ -59,15 +59,8 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
{ label: '2x', value: '2' },
{ label: '4x', value: '4' },
];
export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 2147483647;
export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;
export const NODE_MIN_WIDTH = 250;

View File

@ -90,6 +90,7 @@ import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
export const listenerMiddleware = createListenerMiddleware();
@ -228,3 +229,5 @@ addModelSelectedListener();
addAppStartedListener();
addModelsLoadedListener();
addAppConfigReceivedListener();
addUpscaleRequestedListener();

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCreated } from 'services/api/thunks/session';
import { serializeError } from 'serialize-error';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'session' });

View File

@ -0,0 +1,37 @@
import { createAction } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'upscale' });
export const upscaleRequested = createAction<{ image_name: string }>(
`upscale/upscaleRequested`
);
export const addUpscaleRequestedListener = () => {
startAppListening({
actionCreator: upscaleRequested,
effect: async (
action,
{ dispatch, getState, take, unsubscribe, subscribe }
) => {
const { image_name } = action.payload;
const { esrganModelName } = getState().postprocessing;
const graph = buildAdHocUpscaleGraph({
image_name,
esrganModelName,
});
// Create a session to run the graph & wait til it's ready to invoke
dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
},
});
};

View File

@ -11,7 +11,7 @@ interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
<Tooltip label={tooltip} placement="top" hasArrow>
<Tooltip label={tooltip} placement="top" hasArrow openDelay={500}>
<Box ref={ref} {...others}>
<Box>
<Text>{label}</Text>

View File

@ -5,34 +5,26 @@ import {
ButtonGroup,
Flex,
FlexProps,
Link,
Menu,
MenuButton,
MenuItem,
MenuList,
} from '@chakra-ui/react';
// import { runESRGAN, runFacetool } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
@ -42,38 +34,25 @@ import { useTranslation } from 'react-i18next';
import {
FaAsterisk,
FaCode,
FaCopy,
FaDownload,
FaExpandArrowsAlt,
FaGrinStars,
FaHourglassHalf,
FaQuoteRight,
FaSeedling,
FaShare,
FaShareAlt,
} from 'react-icons/fa';
import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
} from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ gallery, system, postprocessing, ui }, activeTabName) => {
const {
isProcessing,
isConnected,
isGFPGANAvailable,
isESRGANAvailable,
shouldConfirmOnDelete,
progressImage,
} = system;
const { upscalingLevel, facetoolStrength } = postprocessing;
({ gallery, system, ui }, activeTabName) => {
const { isProcessing, isConnected, shouldConfirmOnDelete, progressImage } =
system;
const {
shouldShowImageDetails,
@ -88,10 +67,6 @@ const currentImageButtonsSelector = createSelector(
shouldConfirmOnDelete,
isProcessing,
isConnected,
isGFPGANAvailable,
isESRGANAvailable,
upscalingLevel,
facetoolStrength,
shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
shouldShowImageDetails,
activeTabName,
@ -114,27 +89,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const {
isProcessing,
isConnected,
isGFPGANAvailable,
isESRGANAvailable,
upscalingLevel,
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
activeTabName,
lastSelectedImage,
shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector);
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
const toaster = useAppToaster();
const { t } = useTranslation();
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
@ -155,42 +120,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const metadata = metadataData?.metadata;
const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => {
if (!imageDTO) {
return;
}
if (imageDTO.image_url.startsWith('http')) {
return imageDTO.image_url;
}
return window.location.toString() + imageDTO.image_url;
};
const url = getImageUrl();
if (!url) {
toaster({
title: t('toast.problemCopyingImageLink'),
status: 'error',
duration: 2500,
isClosable: true,
});
return;
}
navigator.clipboard.writeText(url).then(() => {
toaster({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
}, [toaster, t, imageDTO]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
@ -223,8 +152,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('shift+i', handleSendToImageToImage, [imageDTO]);
const handleClickUpscale = useCallback(() => {
// selectedImage && dispatch(runESRGAN(selectedImage));
}, []);
if (!imageDTO) {
return;
}
dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
}, [dispatch, imageDTO]);
const handleDelete = useCallback(() => {
if (!imageDTO) {
@ -242,53 +174,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
enabled: () =>
Boolean(
isUpscalingEnabled &&
isESRGANAvailable &&
!shouldDisableToolbarButtons &&
isConnected &&
!isProcessing &&
upscalingLevel
!isProcessing
),
},
[
isUpscalingEnabled,
imageDTO,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
isProcessing,
upscalingLevel,
]
);
const handleClickFixFaces = useCallback(() => {
// selectedImage && dispatch(runFacetool(selectedImage));
}, []);
useHotkeys(
'Shift+R',
() => {
handleClickFixFaces();
},
{
enabled: () =>
Boolean(
isFaceRestoreEnabled &&
isGFPGANAvailable &&
!shouldDisableToolbarButtons &&
isConnected &&
!isProcessing &&
facetoolStrength
),
},
[
isFaceRestoreEnabled,
imageDTO,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
isProcessing,
facetoolStrength,
]
);
@ -297,25 +193,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
[dispatch, shouldShowImageDetails]
);
const handleSendToCanvas = useCallback(() => {
if (!imageDTO) return;
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(imageDTO));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
dispatch(setActiveTab('unifiedCanvas'));
}
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
}, [imageDTO, dispatch, activeTabName, toaster, t]);
useHotkeys(
'i',
() => {
@ -337,13 +214,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]);
const handleCopyImage = useCallback(() => {
if (!imageDTO) {
return;
}
copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO]);
return (
<>
<Flex
@ -396,72 +266,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
{(isUpscalingEnabled || isFaceRestoreEnabled) && (
{isUpscalingEnabled && (
<ButtonGroup
isAttached={true}
isDisabled={shouldDisableToolbarButtons}
>
{isFaceRestoreEnabled && (
<IAIPopover
triggerComponent={
<IAIIconButton
icon={<FaGrinStars />}
aria-label={t('parameters.restoreFaces')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
rowGap: 4,
}}
>
<FaceRestoreSettings />
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!imageDTO ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
onClick={handleClickFixFaces}
>
{t('parameters.restoreFaces')}
</IAIButton>
</Flex>
</IAIPopover>
)}
{isUpscalingEnabled && (
<IAIPopover
triggerComponent={
<IAIIconButton
icon={<FaExpandArrowsAlt />}
aria-label={t('parameters.upscale')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
}}
>
<UpscaleSettings />
<IAIButton
isDisabled={
!isESRGANAvailable ||
!imageDTO ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
onClick={handleClickUpscale}
>
{t('parameters.upscaleImage')}
</IAIButton>
</Flex>
</IAIPopover>
)}
{isUpscalingEnabled && <ParamUpscalePopover imageDTO={imageDTO} />}
</ButtonGroup>
)}

View File

@ -0,0 +1,32 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import { Graph, ESRGANInvocation } from 'services/api/types';
import { REALESRGAN as ESRGAN } from './constants';
type Arg = {
image_name: string;
esrganModelName: ESRGANModelName;
};
export const buildAdHocUpscaleGraph = ({
image_name,
esrganModelName,
}: Arg): Graph => {
const realesrganNode: ESRGANInvocation = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: false,
};
const graph: NonNullableGraph = {
id: `adhoc-esrgan-graph`,
nodes: {
[ESRGAN]: realesrganNode,
},
edges: [],
};
return graph;
};

View File

@ -20,6 +20,9 @@ export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection';
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
export const REALESRGAN = 'esrgan';
export const DIVIDE = 'divide';
export const SCALE = 'scale_image';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -1,34 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setCodeformerFidelity } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function CodeformerFidelity() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const codeformerFidelity = useAppSelector(
(state: RootState) => state.postprocessing.codeformerFidelity
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isDisabled={!isGFPGANAvailable}
label={t('parameters.codeformerFidelity')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setCodeformerFidelity(v))}
handleReset={() => dispatch(setCodeformerFidelity(1))}
value={codeformerFidelity}
withReset
withSliderMarks
withInput
/>
);
}

View File

@ -1,25 +0,0 @@
import { VStack } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import type { RootState } from 'app/store/store';
import FaceRestoreType from './FaceRestoreType';
import FaceRestoreStrength from './FaceRestoreStrength';
import CodeformerFidelity from './CodeformerFidelity';
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const FaceRestoreSettings = () => {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
return (
<VStack gap={2} alignItems="stretch">
<FaceRestoreType />
<FaceRestoreStrength />
{facetoolType === 'codeformer' && <CodeformerFidelity />}
</VStack>
);
};
export default FaceRestoreSettings;

View File

@ -1,34 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setFacetoolStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreStrength() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const facetoolStrength = useAppSelector(
(state: RootState) => state.postprocessing.facetoolStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isDisabled={!isGFPGANAvailable}
label={t('parameters.strength')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setFacetoolStrength(v))}
handleReset={() => dispatch(setFacetoolStrength(0.75))}
value={facetoolStrength}
withReset
withSliderMarks
withInput
/>
);
}

View File

@ -1,28 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldRunFacetool } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
export default function FaceRestoreToggle() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const shouldRunFacetool = useAppSelector(
(state: RootState) => state.postprocessing.shouldRunFacetool
);
const dispatch = useAppDispatch();
const handleChangeShouldRunFacetool = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunFacetool(e.target.checked));
return (
<IAISwitch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunFacetool}
onChange={handleChangeShouldRunFacetool}
/>
);
}

View File

@ -1,30 +0,0 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
FacetoolType,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreType() {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangeFacetoolType = (v: string) =>
dispatch(setFacetoolType(v as FacetoolType));
return (
<IAIMantineSearchableSelect
label={t('parameters.type')}
data={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>
);
}

View File

@ -1,43 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ParamHiresStrength } from './ParamHiresStrength';
import { ParamHiresToggle } from './ParamHiresToggle';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamHiresCollapse = () => {
const { t } = useTranslation();
const { activeLabel } = useAppSelector(selector);
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
if (!isHiresEnabled) {
return null;
}
return (
<IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamHiresToggle />
<ParamHiresStrength />
</Flex>
</IAICollapse>
);
};
export default memo(ParamHiresCollapse);

View File

@ -1,51 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { setHiresStrength } from 'features/parameters/store/postprocessingSlice';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
const hiresStrengthSelector = createSelector(
[postprocessingSelector],
({ hiresFix, hiresStrength }) => ({ hiresFix, hiresStrength }),
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ParamHiresStrength = () => {
const { hiresFix, hiresStrength } = useAppSelector(hiresStrengthSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleHiresStrength = (v: number) => {
dispatch(setHiresStrength(v));
};
const handleHiResStrengthReset = () => {
dispatch(setHiresStrength(0.75));
};
return (
<IAISlider
label={t('parameters.hiresStrength')}
step={0.01}
min={0.01}
max={0.99}
onChange={handleHiresStrength}
value={hiresStrength}
isInteger={false}
withInput
withSliderMarks
// inputWidth={22}
withReset
handleReset={handleHiResStrengthReset}
isDisabled={!hiresFix}
/>
);
};

View File

@ -1,30 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
/**
* Hires Fix Toggle
*/
export const ParamHiresToggle = () => {
const dispatch = useAppDispatch();
const hiresFix = useAppSelector(
(state: RootState) => state.postprocessing.hiresFix
);
const { t } = useTranslation();
const handleChangeHiresFix = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setHiresFix(e.target.checked));
return (
<IAISwitch
label={t('parameters.hiresOptim')}
isChecked={hiresFix}
onChange={handleChangeHiresFix}
/>
);
};

View File

@ -0,0 +1,58 @@
import { SelectItem } from '@mantine/core';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import {
ESRGANModelName,
esrganModelNameChanged,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export const ESRGAN_MODEL_NAMES: SelectItem[] = [
{
label: 'RealESRGAN x2 Plus',
value: 'RealESRGAN_x2plus.pth',
tooltip: 'Attempts to retain sharpness, low smoothing',
group: 'x2 Upscalers',
},
{
label: 'RealESRGAN x4 Plus',
value: 'RealESRGAN_x4plus.pth',
tooltip: 'Best for photos and highly detailed images, medium smoothing',
group: 'x4 Upscalers',
},
{
label: 'RealESRGAN x4 Plus (anime 6B)',
value: 'RealESRGAN_x4plus_anime_6B.pth',
tooltip: 'Best for anime/manga, high smoothing',
group: 'x4 Upscalers',
},
{
label: 'ESRGAN SRx4',
value: 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
tooltip: 'Retains sharpness, low smoothing',
group: 'x4 Upscalers',
},
];
export default function ParamESRGANModel() {
const esrganModelName = useAppSelector(
(state: RootState) => state.postprocessing.esrganModelName
);
const dispatch = useAppDispatch();
const handleChange = (v: string) =>
dispatch(esrganModelNameChanged(v as ESRGANModelName));
return (
<IAIMantineSelect
label="ESRGAN Model"
value={esrganModelName}
itemComponent={IAIMantineSelectItemWithTooltip}
onChange={handleChange}
data={ESRGAN_MODEL_NAMES}
/>
);
}

View File

@ -0,0 +1,62 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExpandArrowsAlt } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import ParamESRGANModel from './ParamRealESRGANModel';
type Props = { imageDTO?: ImageDTO };
const ParamUpscalePopover = (props: Props) => {
const { imageDTO } = props;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const { isOpen, onOpen, onClose } = useDisclosure();
const handleClickUpscale = useCallback(() => {
onClose();
if (!imageDTO) {
return;
}
dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
}, [dispatch, imageDTO, onClose]);
return (
<IAIPopover
isOpen={isOpen}
onClose={onClose}
triggerComponent={
<IAIIconButton
onClick={onOpen}
icon={<FaExpandArrowsAlt />}
aria-label={t('parameters.upscale')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
}}
>
<ParamESRGANModel />
<IAIButton
size="sm"
isDisabled={!imageDTO || isBusy}
onClick={handleClickUpscale}
>
{t('parameters.upscaleImage')}
</IAIButton>
</Flex>
</IAIPopover>
);
};
export default ParamUpscalePopover;

View File

@ -1,36 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingDenoising } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleDenoisingStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingDenoising = useAppSelector(
(state: RootState) => state.postprocessing.upscalingDenoising
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={t('parameters.denoisingStrength')}
value={upscalingDenoising}
min={0}
max={1}
step={0.01}
onChange={(v) => {
dispatch(setUpscalingDenoising(v));
}}
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
withSliderMarks
withInput
withReset
isDisabled={!isESRGANAvailable}
/>
);
}

View File

@ -1,35 +0,0 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
UpscalingLevel,
setUpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleScale() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingLevel = useAppSelector(
(state: RootState) => state.postprocessing.upscalingLevel
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangeLevel = (v: string) =>
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return (
<IAIMantineSearchableSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={String(upscalingLevel)}
onChange={handleChangeLevel}
data={UPSCALING_LEVELS}
/>
);
}

View File

@ -1,19 +0,0 @@
import { VStack } from '@chakra-ui/react';
import UpscaleDenoisingStrength from './UpscaleDenoisingStrength';
import UpscaleStrength from './UpscaleStrength';
import UpscaleScale from './UpscaleScale';
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const UpscaleSettings = () => {
return (
<VStack gap={2} alignItems="stretch">
<UpscaleScale />
<UpscaleDenoisingStrength />
<UpscaleStrength />
</VStack>
);
};
export default UpscaleSettings;

View File

@ -1,33 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingStrength = useAppSelector(
(state: RootState) => state.postprocessing.upscalingStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={`${t('parameters.upscale')} ${t('parameters.strength')}`}
value={upscalingStrength}
min={0}
max={1}
step={0.05}
onChange={(v) => dispatch(setUpscalingStrength(v))}
handleReset={() => dispatch(setUpscalingStrength(0.75))}
withSliderMarks
withInput
withReset
isDisabled={!isESRGANAvailable}
/>
);
}

View File

@ -1,26 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldRunESRGAN } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
export default function UpscaleToggle() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const shouldRunESRGAN = useAppSelector(
(state: RootState) => state.postprocessing.shouldRunESRGAN
);
const dispatch = useAppDispatch();
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunESRGAN(e.target.checked));
return (
<IAISwitch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={handleChangeShouldRunESRGAN}
/>
);
}

View File

@ -1,98 +1,27 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { FACETOOL_TYPES } from 'app/constants';
import { ESRGANInvocation } from 'services/api/types';
export type UpscalingLevel = 2 | 4;
export type FacetoolType = (typeof FACETOOL_TYPES)[number];
export type ESRGANModelName = NonNullable<ESRGANInvocation['model_name']>;
export interface PostprocessingState {
codeformerFidelity: number;
facetoolStrength: number;
facetoolType: FacetoolType;
hiresFix: boolean;
hiresStrength: number;
shouldLoopback: boolean;
shouldRunESRGAN: boolean;
shouldRunFacetool: boolean;
upscalingLevel: UpscalingLevel;
upscalingDenoising: number;
upscalingStrength: number;
esrganModelName: ESRGANModelName;
}
export const initialPostprocessingState: PostprocessingState = {
codeformerFidelity: 0.75,
facetoolStrength: 0.75,
facetoolType: 'gfpgan',
hiresFix: false,
hiresStrength: 0.75,
shouldLoopback: false,
shouldRunESRGAN: false,
shouldRunFacetool: false,
upscalingLevel: 4,
upscalingDenoising: 0.75,
upscalingStrength: 0.75,
esrganModelName: 'RealESRGAN_x4plus.pth',
};
export const postprocessingSlice = createSlice({
name: 'postprocessing',
initialState: initialPostprocessingState,
reducers: {
setFacetoolStrength: (state, action: PayloadAction<number>) => {
state.facetoolStrength = action.payload;
},
setCodeformerFidelity: (state, action: PayloadAction<number>) => {
state.codeformerFidelity = action.payload;
},
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
state.upscalingLevel = action.payload;
},
setUpscalingDenoising: (state, action: PayloadAction<number>) => {
state.upscalingDenoising = action.payload;
},
setUpscalingStrength: (state, action: PayloadAction<number>) => {
state.upscalingStrength = action.payload;
},
setHiresFix: (state, action: PayloadAction<boolean>) => {
state.hiresFix = action.payload;
},
setHiresStrength: (state, action: PayloadAction<number>) => {
state.hiresStrength = action.payload;
},
resetPostprocessingState: (state) => {
return {
...state,
...initialPostprocessingState,
};
},
setShouldRunFacetool: (state, action: PayloadAction<boolean>) => {
state.shouldRunFacetool = action.payload;
},
setFacetoolType: (state, action: PayloadAction<FacetoolType>) => {
state.facetoolType = action.payload;
},
setShouldRunESRGAN: (state, action: PayloadAction<boolean>) => {
state.shouldRunESRGAN = action.payload;
},
setShouldLoopback: (state, action: PayloadAction<boolean>) => {
state.shouldLoopback = action.payload;
esrganModelNameChanged: (state, action: PayloadAction<ESRGANModelName>) => {
state.esrganModelName = action.payload;
},
},
});
export const {
resetPostprocessingState,
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setHiresStrength,
setShouldLoopback,
setShouldRunESRGAN,
setShouldRunFacetool,
setUpscalingLevel,
setUpscalingDenoising,
setUpscalingStrength,
} = postprocessingSlice.actions;
export const { esrganModelNameChanged } = postprocessingSlice.actions;
export default postprocessingSlice.reducer;

View File

@ -4,7 +4,6 @@ import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Adv
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
@ -26,7 +25,6 @@ const TextToImageTabParameters = () => {
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />
<ParamHiresCollapse />
<ParamSeamlessCollapse />
<ParamAdvancedCollapse />
</>

View File

@ -90,6 +90,9 @@ export type InpaintInvocation = TypeReq<
export type ImageResizeInvocation = TypeReq<
components['schemas']['ImageResizeInvocation']
>;
export type ImageScaleInvocation = TypeReq<
components['schemas']['ImageScaleInvocation']
>;
export type RandomIntInvocation = TypeReq<
components['schemas']['RandomIntInvocation']
>;
@ -124,6 +127,12 @@ export type LoraLoaderInvocation = TypeReq<
export type MetadataAccumulatorInvocation = TypeReq<
components['schemas']['MetadataAccumulatorInvocation']
>;
export type ESRGANInvocation = TypeReq<
components['schemas']['ESRGANInvocation']
>;
export type DivideInvocation = TypeReq<
components['schemas']['DivideInvocation']
>;
// ControlNet Nodes
export type ControlNetInvocation = TypeReq<