Merge branch 'main' into onnx-testing

This commit is contained in:
Brandon Rising
2023-07-18 22:56:41 -04:00
361 changed files with 13813 additions and 10110 deletions

View File

@ -14,6 +14,14 @@ export const clipSkipMap = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sdxl-refiner': {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
};
export default function ParamClipSkip() {

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import {
@ -35,7 +35,7 @@ const ParamScaleBeforeProcessing = () => {
};
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.scaleBeforeProcessing')}
data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale}

View File

@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import IAIIconButton from 'common/components/IAIIconButton';
import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import {
controlNetAdded,
controlNetModelChanged,
controlNetSelector,
} from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector(
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch();
const { firstModel } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => {
const firstModel = result.data?.entities[result.data?.ids[0]];
return {
firstModel,
};
},
});
const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
}, [dispatch]);
if (!firstModel) {
return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) {
return null;
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
return (
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle />
<Flex gap={2} alignItems="center">
<Flex
sx={{
flexDirection: 'column',
w: '100%',
gap: 2,
px: 4,
py: 2,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<ParamControlNetFeatureToggle />
</Flex>
<IAIIconButton
tooltip="Add ControlNet"
aria-label="Add ControlNet"
icon={<FaPlus />}
isDisabled={!firstModel}
flexGrow={1}
size="md"
onClick={handleClickedAddControlNet}
/>
</Flex>
{controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}>
{i > 0 && <Divider />}
<ControlNet controlNet={c} />
<ControlNet controlNetId={c.controlNetId} />
</Fragment>
))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex>
</IAICollapse>
);

View File

@ -1,8 +1,8 @@
import { Box, Flex } from '@chakra-ui/react';
import ModelSelect from 'features/system/components/ModelSelect';
import VAESelect from 'features/system/components/VAESelect';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
import ParamScheduler from './ParamScheduler';
const ParamModelandVAEandScheduler = () => {
@ -11,12 +11,12 @@ const ParamModelandVAEandScheduler = () => {
return (
<Flex gap={3} w="full" flexWrap={isVaeEnabled ? 'wrap' : 'nowrap'}>
<Box w="full">
<ModelSelect />
<ParamMainModelSelect />
</Box>
<Flex gap={3} w="full">
{isVaeEnabled && (
<Box w="full">
<VAESelect />
<ParamVAEModelSelect />
</Box>
)}
<Box w="full">

View File

@ -2,10 +2,10 @@ import { createSelector } from '@reduxjs/toolkit';
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -48,7 +48,7 @@ const ParamScheduler = () => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.scheduler')}
value={scheduler}
data={data}

View File

@ -1,34 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setCodeformerFidelity } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function CodeformerFidelity() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const codeformerFidelity = useAppSelector(
(state: RootState) => state.postprocessing.codeformerFidelity
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isDisabled={!isGFPGANAvailable}
label={t('parameters.codeformerFidelity')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setCodeformerFidelity(v))}
handleReset={() => dispatch(setCodeformerFidelity(1))}
value={codeformerFidelity}
withReset
withSliderMarks
withInput
/>
);
}

View File

@ -1,25 +0,0 @@
import { VStack } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import type { RootState } from 'app/store/store';
import FaceRestoreType from './FaceRestoreType';
import FaceRestoreStrength from './FaceRestoreStrength';
import CodeformerFidelity from './CodeformerFidelity';
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const FaceRestoreSettings = () => {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
return (
<VStack gap={2} alignItems="stretch">
<FaceRestoreType />
<FaceRestoreStrength />
{facetoolType === 'codeformer' && <CodeformerFidelity />}
</VStack>
);
};
export default FaceRestoreSettings;

View File

@ -1,34 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setFacetoolStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreStrength() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const facetoolStrength = useAppSelector(
(state: RootState) => state.postprocessing.facetoolStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isDisabled={!isGFPGANAvailable}
label={t('parameters.strength')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setFacetoolStrength(v))}
handleReset={() => dispatch(setFacetoolStrength(0.75))}
value={facetoolStrength}
withReset
withSliderMarks
withInput
/>
);
}

View File

@ -1,28 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldRunFacetool } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
export default function FaceRestoreToggle() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const shouldRunFacetool = useAppSelector(
(state: RootState) => state.postprocessing.shouldRunFacetool
);
const dispatch = useAppDispatch();
const handleChangeShouldRunFacetool = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunFacetool(e.target.checked));
return (
<IAISwitch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunFacetool}
onChange={handleChangeShouldRunFacetool}
/>
);
}

View File

@ -1,30 +0,0 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
FacetoolType,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreType() {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangeFacetoolType = (v: string) =>
dispatch(setFacetoolType(v as FacetoolType));
return (
<IAIMantineSelect
label={t('parameters.type')}
data={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>
);
}

View File

@ -1,43 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ParamHiresStrength } from './ParamHiresStrength';
import { ParamHiresToggle } from './ParamHiresToggle';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamHiresCollapse = () => {
const { t } = useTranslation();
const { activeLabel } = useAppSelector(selector);
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
if (!isHiresEnabled) {
return null;
}
return (
<IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamHiresToggle />
<ParamHiresStrength />
</Flex>
</IAICollapse>
);
};
export default memo(ParamHiresCollapse);

View File

@ -1,51 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { setHiresStrength } from 'features/parameters/store/postprocessingSlice';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
const hiresStrengthSelector = createSelector(
[postprocessingSelector],
({ hiresFix, hiresStrength }) => ({ hiresFix, hiresStrength }),
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ParamHiresStrength = () => {
const { hiresFix, hiresStrength } = useAppSelector(hiresStrengthSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleHiresStrength = (v: number) => {
dispatch(setHiresStrength(v));
};
const handleHiResStrengthReset = () => {
dispatch(setHiresStrength(0.75));
};
return (
<IAISlider
label={t('parameters.hiresStrength')}
step={0.01}
min={0.01}
max={0.99}
onChange={handleHiresStrength}
value={hiresStrength}
isInteger={false}
withInput
withSliderMarks
// inputWidth={22}
withReset
handleReset={handleHiResStrengthReset}
isDisabled={!hiresFix}
/>
);
};

View File

@ -1,30 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
/**
* Hires Fix Toggle
*/
export const ParamHiresToggle = () => {
const dispatch = useAppDispatch();
const hiresFix = useAppSelector(
(state: RootState) => state.postprocessing.hiresFix
);
const { t } = useTranslation();
const handleChangeHiresFix = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setHiresFix(e.target.checked));
return (
<IAISwitch
label={t('parameters.hiresOptim')}
isChecked={hiresFix}
onChange={handleChangeHiresFix}
/>
);
};

View File

@ -1,30 +1,24 @@
import { Flex, Icon, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useMemo } from 'react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { FaImage } from 'react-icons/fa';
import { stateSelector } from 'app/store/store';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const selector = createSelector(
[stateSelector],
(state) => {
const { initialImage } = state.generation;
const { asInitialImage: useBatchAsInitialImage, imageNames } = state.batch;
return {
initialImage,
useBatchAsInitialImage,
isResetButtonDisabled: useBatchAsInitialImage
? imageNames.length === 0
: !initialImage,
isResetButtonDisabled: !initialImage,
};
},
defaultSelectorOptions

View File

@ -1,22 +1,14 @@
import { Flex, Spacer, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { useCallback, useMemo } from 'react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaLayerGroup, FaUndo, FaUpload } from 'react-icons/fa';
import useImageUploader from 'common/hooks/useImageUploader';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import IAIButton from 'common/components/IAIButton';
import { stateSelector } from 'app/store/store';
import {
asInitialImageToggled,
batchReset,
} from 'features/batch/store/batchSlice';
import BatchImageContainer from 'features/batch/components/BatchImageContainer';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import useImageUploader from 'common/hooks/useImageUploader';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { useCallback } from 'react';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { PostUploadAction } from 'services/api/thunks/image';
import InitialImage from './InitialImage';
@ -24,59 +16,34 @@ const selector = createSelector(
[stateSelector],
(state) => {
const { initialImage } = state.generation;
const { asInitialImage: useBatchAsInitialImage, imageNames } = state.batch;
return {
initialImage,
useBatchAsInitialImage,
isResetButtonDisabled: useBatchAsInitialImage
? imageNames.length === 0
: !initialImage,
isResetButtonDisabled: !initialImage,
};
},
defaultSelectorOptions
);
const postUploadAction: PostUploadAction = {
type: 'SET_INITIAL_IMAGE',
};
const InitialImageDisplay = () => {
const { initialImage, useBatchAsInitialImage, isResetButtonDisabled } =
useAppSelector(selector);
const { isResetButtonDisabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { openUploader } = useImageUploader();
const {
currentData: imageDTO,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
const postUploadAction = useMemo<PostUploadAction>(
() =>
useBatchAsInitialImage
? { type: 'ADD_TO_BATCH' }
: { type: 'SET_INITIAL_IMAGE' },
[useBatchAsInitialImage]
);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
});
const handleReset = useCallback(() => {
if (useBatchAsInitialImage) {
dispatch(batchReset());
} else {
dispatch(clearInitialImage());
}
}, [dispatch, useBatchAsInitialImage]);
dispatch(clearInitialImage());
}, [dispatch]);
const handleUpload = useCallback(() => {
openUploader();
}, [openUploader]);
const handleClickUseBatch = useCallback(() => {
dispatch(asInitialImageToggled());
}, [dispatch]);
return (
<Flex
layerStyle={'first'}
@ -114,40 +81,22 @@ const InitialImageDisplay = () => {
Initial Image
</Text>
<Spacer />
{/* <IAIButton
tooltip={useBatchAsInitialImage ? 'Disable Batch' : 'Enable Batch'}
aria-label={useBatchAsInitialImage ? 'Disable Batch' : 'Enable Batch'}
leftIcon={<FaLayerGroup />}
isChecked={useBatchAsInitialImage}
onClick={handleClickUseBatch}
>
{useBatchAsInitialImage ? 'Batch' : 'Single'}
</IAIButton> */}
<IAIIconButton
tooltip={
useBatchAsInitialImage ? 'Upload to Batch' : 'Upload Initial Image'
}
aria-label={
useBatchAsInitialImage ? 'Upload to Batch' : 'Upload Initial Image'
}
tooltip={'Upload Initial Image'}
aria-label={'Upload Initial Image'}
icon={<FaUpload />}
onClick={handleUpload}
{...getUploadButtonProps()}
/>
<IAIIconButton
tooltip={
useBatchAsInitialImage ? 'Reset Batch' : 'Reset Initial Image'
}
aria-label={
useBatchAsInitialImage ? 'Reset Batch' : 'Reset Initial Image'
}
tooltip={'Reset Initial Image'}
aria-label={'Reset Initial Image'}
icon={<FaUndo />}
onClick={handleReset}
isDisabled={isResetButtonDisabled}
/>
</Flex>
<InitialImage />
{/* {useBatchAsInitialImage ? <BatchImageContainer /> : <InitialImage />} */}
<input {...getUploadInputProps()} />
</Flex>
);

View File

@ -0,0 +1,123 @@
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { modelIdToOnnxModelField } from 'features/nodes/util/modelIdToOnnxModelField';
const selector = createSelector(
stateSelector,
(state) => ({ model: state.generation.model }),
defaultSelectorOptions
);
const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery();
const data = useMemo(() => {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
forEach(onnxModels?.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [mainModels, onnxModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
(mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ||
onnxModels?.entities[
`${model?.base_model}/onnx/${model?.model_name}`
]) ??
null,
[mainModels?.entities, model, onnxModels?.entities]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
let newModel = modelIdToMainModelParam(v);
if (v.includes('onnx')) {
newModel = modelIdToOnnxModelField(v);
}
if (!newModel) {
return;
}
dispatch(modelSelected(newModel));
},
[dispatch]
);
return isLoading || onnxLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
);
};
export default memo(ParamMainModelSelect);

View File

@ -0,0 +1,58 @@
import { SelectItem } from '@mantine/core';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import {
ESRGANModelName,
esrganModelNameChanged,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export const ESRGAN_MODEL_NAMES: SelectItem[] = [
{
label: 'RealESRGAN x2 Plus',
value: 'RealESRGAN_x2plus.pth',
tooltip: 'Attempts to retain sharpness, low smoothing',
group: 'x2 Upscalers',
},
{
label: 'RealESRGAN x4 Plus',
value: 'RealESRGAN_x4plus.pth',
tooltip: 'Best for photos and highly detailed images, medium smoothing',
group: 'x4 Upscalers',
},
{
label: 'RealESRGAN x4 Plus (anime 6B)',
value: 'RealESRGAN_x4plus_anime_6B.pth',
tooltip: 'Best for anime/manga, high smoothing',
group: 'x4 Upscalers',
},
{
label: 'ESRGAN SRx4',
value: 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth',
tooltip: 'Retains sharpness, low smoothing',
group: 'x4 Upscalers',
},
];
export default function ParamESRGANModel() {
const esrganModelName = useAppSelector(
(state: RootState) => state.postprocessing.esrganModelName
);
const dispatch = useAppDispatch();
const handleChange = (v: string) =>
dispatch(esrganModelNameChanged(v as ESRGANModelName));
return (
<IAIMantineSelect
label="ESRGAN Model"
value={esrganModelName}
itemComponent={IAIMantineSelectItemWithTooltip}
onChange={handleChange}
data={ESRGAN_MODEL_NAMES}
/>
);
}

View File

@ -0,0 +1,62 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExpandArrowsAlt } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import ParamESRGANModel from './ParamRealESRGANModel';
type Props = { imageDTO?: ImageDTO };
const ParamUpscalePopover = (props: Props) => {
const { imageDTO } = props;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const { isOpen, onOpen, onClose } = useDisclosure();
const handleClickUpscale = useCallback(() => {
onClose();
if (!imageDTO) {
return;
}
dispatch(upscaleRequested({ image_name: imageDTO.image_name }));
}, [dispatch, imageDTO, onClose]);
return (
<IAIPopover
isOpen={isOpen}
onClose={onClose}
triggerComponent={
<IAIIconButton
onClick={onOpen}
icon={<FaExpandArrowsAlt />}
aria-label={t('parameters.upscale')}
/>
}
>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
}}
>
<ParamESRGANModel />
<IAIButton
size="sm"
isDisabled={!imageDTO || isBusy}
onClick={handleClickUpscale}
>
{t('parameters.upscaleImage')}
</IAIButton>
</Flex>
</IAIPopover>
);
};
export default ParamUpscalePopover;

View File

@ -1,36 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingDenoising } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleDenoisingStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingDenoising = useAppSelector(
(state: RootState) => state.postprocessing.upscalingDenoising
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={t('parameters.denoisingStrength')}
value={upscalingDenoising}
min={0}
max={1}
step={0.01}
onChange={(v) => {
dispatch(setUpscalingDenoising(v));
}}
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
withSliderMarks
withInput
withReset
isDisabled={!isESRGANAvailable}
/>
);
}

View File

@ -1,35 +0,0 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
UpscalingLevel,
setUpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleScale() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingLevel = useAppSelector(
(state: RootState) => state.postprocessing.upscalingLevel
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangeLevel = (v: string) =>
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return (
<IAIMantineSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={String(upscalingLevel)}
onChange={handleChangeLevel}
data={UPSCALING_LEVELS}
/>
);
}

View File

@ -1,19 +0,0 @@
import { VStack } from '@chakra-ui/react';
import UpscaleDenoisingStrength from './UpscaleDenoisingStrength';
import UpscaleStrength from './UpscaleStrength';
import UpscaleScale from './UpscaleScale';
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const UpscaleSettings = () => {
return (
<VStack gap={2} alignItems="stretch">
<UpscaleScale />
<UpscaleDenoisingStrength />
<UpscaleStrength />
</VStack>
);
};
export default UpscaleSettings;

View File

@ -1,33 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingStrength = useAppSelector(
(state: RootState) => state.postprocessing.upscalingStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={`${t('parameters.upscale')} ${t('parameters.strength')}`}
value={upscalingStrength}
min={0}
max={1}
step={0.05}
onChange={(v) => dispatch(setUpscalingStrength(v))}
handleReset={() => dispatch(setUpscalingStrength(0.75))}
withSliderMarks
withInput
withReset
isDisabled={!isESRGANAvailable}
/>
);
}

View File

@ -1,26 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldRunESRGAN } from 'features/parameters/store/postprocessingSlice';
import { ChangeEvent } from 'react';
export default function UpscaleToggle() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const shouldRunESRGAN = useAppSelector(
(state: RootState) => state.postprocessing.shouldRunESRGAN
);
const dispatch = useAppDispatch();
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunESRGAN(e.target.checked));
return (
<IAISwitch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={handleChangeShouldRunESRGAN}
/>
);
}

View File

@ -0,0 +1,110 @@
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { vaeSelected } from 'features/parameters/store/generationSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { model, vae } = generation;
return { model, vae };
},
defaultSelectorOptions
);
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { model, vae } = useAppSelector(selector);
const { data: vaeModels } = useGetVaeModelsQuery();
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [
{
value: 'default',
label: 'Default',
group: 'Default',
},
];
forEach(vaeModels.entities, (vae, id) => {
if (!vae) {
return;
}
const disabled = model?.base_model !== vae.base_model;
data.push({
value: id,
label: vae.model_name,
group: MODEL_TYPE_MAP[vae.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${vae.base_model}`
: undefined,
});
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels, model?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedVaeModel = useMemo(
() =>
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
[vaeModels?.entities, vae]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v || v === 'default') {
dispatch(vaeSelected(null));
return;
}
const newVaeModel = modelIdToVAEModelParam(v);
if (!newVaeModel) {
return;
}
dispatch(vaeSelected(newVaeModel));
},
[dispatch]
);
return (
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')}
value={selectedVaeModel?.id ?? 'default'}
placeholder="Default"
data={data}
onChange={handleChangeModel}
disabled={data.length === 0}
clearable
/>
);
};
export default memo(ParamVAEModelSelect);

View File

@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
_disabled: {
bg: 'none',
color: 'base.600',
cursor: 'not-allowed',
_hover: {
color: 'base.600',
bg: 'none',
},
},

View File

@ -28,7 +28,7 @@ import {
isValidSteps,
isValidStrength,
isValidWidth,
} from '../store/parameterZodSchemas';
} from '../types/parameterSchemas';
export const useRecallParameters = () => {
const dispatch = useAppDispatch();

View File

@ -13,6 +13,7 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
import {
CfgScaleParam,
HeightParam,
MainModelParam,
NegativePromptParam,
PositivePromptParam,
SchedulerParam,
@ -22,7 +23,7 @@ import {
VaeModelParam,
WidthParam,
zMainModel,
} from './parameterZodSchemas';
} from '../types/parameterSchemas';
export interface GenerationState {
cfgScale: CfgScaleParam;
@ -228,19 +229,20 @@ export const generationSlice = createSlice({
},
modelChanged: (
state,
action: PayloadAction<MainModelField | OnnxModelField | null>
action: PayloadAction<MainModelParam | OnnxModelField | null>
) => {
if (!action.payload) {
state.model = null;
}
state.model = action.payload;
state.model = zMainModel.parse(action.payload);
if (state.model === null) {
return;
}
// Clamp ClipSkip Based On Selected Model
const { maxClip } = clipSkipMap[state.model.base_model];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
},
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
// null is a valid VAE!
state.vae = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
@ -256,12 +258,16 @@ export const generationSlice = createSlice({
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
state.model = zMainModel.parse({
id: defaultModel,
name: model_name,
const result = zMainModel.safeParse({
model_name,
base_model,
model_type,
});
if (result.success) {
state.model = result.data;
}
}
});
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {

View File

@ -1,98 +1,27 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { FACETOOL_TYPES } from 'app/constants';
import { ESRGANInvocation } from 'services/api/types';
export type UpscalingLevel = 2 | 4;
export type FacetoolType = (typeof FACETOOL_TYPES)[number];
export type ESRGANModelName = NonNullable<ESRGANInvocation['model_name']>;
export interface PostprocessingState {
codeformerFidelity: number;
facetoolStrength: number;
facetoolType: FacetoolType;
hiresFix: boolean;
hiresStrength: number;
shouldLoopback: boolean;
shouldRunESRGAN: boolean;
shouldRunFacetool: boolean;
upscalingLevel: UpscalingLevel;
upscalingDenoising: number;
upscalingStrength: number;
esrganModelName: ESRGANModelName;
}
export const initialPostprocessingState: PostprocessingState = {
codeformerFidelity: 0.75,
facetoolStrength: 0.75,
facetoolType: 'gfpgan',
hiresFix: false,
hiresStrength: 0.75,
shouldLoopback: false,
shouldRunESRGAN: false,
shouldRunFacetool: false,
upscalingLevel: 4,
upscalingDenoising: 0.75,
upscalingStrength: 0.75,
esrganModelName: 'RealESRGAN_x4plus.pth',
};
export const postprocessingSlice = createSlice({
name: 'postprocessing',
initialState: initialPostprocessingState,
reducers: {
setFacetoolStrength: (state, action: PayloadAction<number>) => {
state.facetoolStrength = action.payload;
},
setCodeformerFidelity: (state, action: PayloadAction<number>) => {
state.codeformerFidelity = action.payload;
},
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
state.upscalingLevel = action.payload;
},
setUpscalingDenoising: (state, action: PayloadAction<number>) => {
state.upscalingDenoising = action.payload;
},
setUpscalingStrength: (state, action: PayloadAction<number>) => {
state.upscalingStrength = action.payload;
},
setHiresFix: (state, action: PayloadAction<boolean>) => {
state.hiresFix = action.payload;
},
setHiresStrength: (state, action: PayloadAction<number>) => {
state.hiresStrength = action.payload;
},
resetPostprocessingState: (state) => {
return {
...state,
...initialPostprocessingState,
};
},
setShouldRunFacetool: (state, action: PayloadAction<boolean>) => {
state.shouldRunFacetool = action.payload;
},
setFacetoolType: (state, action: PayloadAction<FacetoolType>) => {
state.facetoolType = action.payload;
},
setShouldRunESRGAN: (state, action: PayloadAction<boolean>) => {
state.shouldRunESRGAN = action.payload;
},
setShouldLoopback: (state, action: PayloadAction<boolean>) => {
state.shouldLoopback = action.payload;
esrganModelNameChanged: (state, action: PayloadAction<ESRGANModelName>) => {
state.esrganModelName = action.payload;
},
},
});
export const {
resetPostprocessingState,
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setHiresStrength,
setShouldLoopback,
setShouldRunESRGAN,
setShouldRunFacetool,
setUpscalingLevel,
setUpscalingDenoising,
setUpscalingStrength,
} = postprocessingSlice.actions;
export const { esrganModelNameChanged } = postprocessingSlice.actions;
export default postprocessingSlice.reducer;

View File

@ -0,0 +1,6 @@
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
};

View File

@ -126,7 +126,6 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
const zModelType = z.enum([
'vae',
'lora',
@ -135,6 +134,7 @@ const zModelType = z.enum([
'controlnet',
'embedding',
]);
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export type BaseModelParam = z.infer<typeof zBaseModel>;
@ -143,7 +143,7 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
* TODO: Make this a dynamically generated enum?
*/
export const zMainModel = z.object({
model_name: z.string(),
model_name: z.string().min(1),
base_model: zBaseModel,
model_type: zModelType,
});
@ -161,8 +161,7 @@ export const isValidMainModel = (val: unknown): val is MainModelParam =>
* Zod schema for VAE parameter
*/
export const zVaeModel = z.object({
id: z.string(),
name: z.string(),
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
@ -178,8 +177,7 @@ export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
* Zod schema for LoRA
*/
export const zLoRAModel = z.object({
id: z.string(),
model_name: z.string(),
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
@ -191,6 +189,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/**
* Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -0,0 +1,28 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = (
loraModelId: string
): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return;
}
return result.data;
};

View File

@ -0,0 +1,31 @@
import {
MainModelParam,
zMainModel,
} from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = (
mainModelId: string
): MainModelParam | undefined => {
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return;
}
return result.data;
};

View File

@ -0,0 +1,28 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = (
vaeModelId: string
): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return;
}
return result.data;
};