mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: SDXL Refiner UI Data
This commit is contained in:
parent
b0ebd148fa
commit
3bdb059eb7
@ -0,0 +1,34 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ParamRefinerModelSelect from './SDXLRefiner/ParamRefinerModelSelect';
|
||||||
|
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
|
||||||
|
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||||
|
return { activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerCollapse = () => {
|
||||||
|
const { activeLabel } = useAppSelector(selector);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAICollapse label="Refiner" activeLabel={activeLabel}>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<ParamUseSDXLRefiner />
|
||||||
|
<ParamRefinerModelSelect />
|
||||||
|
<ParamSDXLRefinerSteps />
|
||||||
|
</Flex>
|
||||||
|
</IAICollapse>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamSDXLRefinerCollapse;
|
@ -7,6 +7,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
|
|||||||
import ImageToImageTabCoreParameters from 'features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters';
|
import ImageToImageTabCoreParameters from 'features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters';
|
||||||
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||||
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||||
|
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||||
|
|
||||||
const SDXLImageToImageTabParameters = () => {
|
const SDXLImageToImageTabParameters = () => {
|
||||||
return (
|
return (
|
||||||
@ -17,6 +18,7 @@ const SDXLImageToImageTabParameters = () => {
|
|||||||
<ParamSDXLNegativeStyleConditioning />
|
<ParamSDXLNegativeStyleConditioning />
|
||||||
<ProcessButtons />
|
<ProcessButtons />
|
||||||
<ImageToImageTabCoreParameters />
|
<ImageToImageTabCoreParameters />
|
||||||
|
<ParamSDXLRefinerCollapse />
|
||||||
<ParamDynamicPromptsCollapse />
|
<ParamDynamicPromptsCollapse />
|
||||||
<ParamNoiseCollapse />
|
<ParamNoiseCollapse />
|
||||||
</>
|
</>
|
||||||
|
@ -0,0 +1,106 @@
|
|||||||
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
|
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => ({ model: state.sdxl.refinerModel }),
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamRefinerModelSelect = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!mainModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(mainModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (['sdxl-refiner'].includes(model.base_model)) {
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [mainModels]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
||||||
|
null,
|
||||||
|
[mainModels?.entities, model]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeModel = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newModel = modelIdToMainModelParam(v);
|
||||||
|
|
||||||
|
if (!newModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(refinerModelChanged(newModel));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return isLoading ? (
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
label="Refiner Model"
|
||||||
|
placeholder="Loading..."
|
||||||
|
disabled={true}
|
||||||
|
data={[]}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Flex w="100%" alignItems="center" gap={2}>
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label="Refiner Model"
|
||||||
|
value={selectedModel?.id}
|
||||||
|
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||||
|
data={data}
|
||||||
|
error={data.length === 0}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
onChange={handleChangeModel}
|
||||||
|
w="100%"
|
||||||
|
/>
|
||||||
|
<Box mt={7}>
|
||||||
|
<SyncModelsButton iconMode />
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamRefinerModelSelect);
|
@ -0,0 +1,75 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||||
|
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ sdxl, ui }) => {
|
||||||
|
const { refinerSteps } = sdxl;
|
||||||
|
const { shouldUseSliders } = ui;
|
||||||
|
|
||||||
|
return {
|
||||||
|
refinerSteps,
|
||||||
|
|
||||||
|
shouldUseSliders,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerSteps = () => {
|
||||||
|
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(setRefinerSteps(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
const handleReset = useCallback(() => {
|
||||||
|
dispatch(setRefinerSteps(20));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
const handleBlur = useCallback(() => {
|
||||||
|
dispatch(clampSymmetrySteps());
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return shouldUseSliders ? (
|
||||||
|
<IAISlider
|
||||||
|
label={t('parameters.steps')}
|
||||||
|
min={1}
|
||||||
|
max={200}
|
||||||
|
step={1}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={refinerSteps}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
sliderNumberInputProps={{ max: 999 }}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAINumberInput
|
||||||
|
label={t('parameters.steps')}
|
||||||
|
min={1}
|
||||||
|
max={999}
|
||||||
|
step={1}
|
||||||
|
onChange={handleChange}
|
||||||
|
value={refinerSteps}
|
||||||
|
numberInputFieldProps={{ textAlign: 'center' }}
|
||||||
|
onBlur={handleBlur}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerSteps);
|
@ -0,0 +1,30 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
|
||||||
|
export default function ParamUseSDXLRefiner() {
|
||||||
|
const shouldUseSDXLRefiner = useAppSelector(
|
||||||
|
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
||||||
|
);
|
||||||
|
|
||||||
|
const isRefinerAvailable = useAppSelector(
|
||||||
|
(state: RootState) => state.sdxl.isRefinerAvailable
|
||||||
|
);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(setShouldUseSDXLRefiner(e.target.checked));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Use Refiner"
|
||||||
|
isChecked={shouldUseSDXLRefiner}
|
||||||
|
onChange={handleUseSDXLRefinerChange}
|
||||||
|
isDisabled={isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
@ -6,6 +6,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
|
|||||||
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
|
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
|
||||||
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||||
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||||
|
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||||
|
|
||||||
const SDXLTextToImageTabParameters = () => {
|
const SDXLTextToImageTabParameters = () => {
|
||||||
return (
|
return (
|
||||||
@ -16,6 +17,7 @@ const SDXLTextToImageTabParameters = () => {
|
|||||||
<ParamSDXLNegativeStyleConditioning />
|
<ParamSDXLNegativeStyleConditioning />
|
||||||
<ProcessButtons />
|
<ProcessButtons />
|
||||||
<TextToImageTabCoreParameters />
|
<TextToImageTabCoreParameters />
|
||||||
|
<ParamSDXLRefinerCollapse />
|
||||||
<ParamDynamicPromptsCollapse />
|
<ParamDynamicPromptsCollapse />
|
||||||
<ParamNoiseCollapse />
|
<ParamNoiseCollapse />
|
||||||
</>
|
</>
|
||||||
|
@ -1,19 +1,36 @@
|
|||||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import {
|
import {
|
||||||
|
MainModelParam,
|
||||||
NegativeStylePromptSDXLParam,
|
NegativeStylePromptSDXLParam,
|
||||||
PositiveStylePromptSDXLParam,
|
PositiveStylePromptSDXLParam,
|
||||||
|
SchedulerParam,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { MainModelField } from 'services/api/types';
|
||||||
|
|
||||||
type SDXLInitialState = {
|
type SDXLInitialState = {
|
||||||
shouldUseSDXLRefiner: boolean;
|
|
||||||
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
||||||
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
||||||
|
isRefinerAvailable: boolean;
|
||||||
|
shouldUseSDXLRefiner: boolean;
|
||||||
|
refinerModel: MainModelField | null;
|
||||||
|
refinerSteps: number;
|
||||||
|
refinerCFGScale: number;
|
||||||
|
refinerScheduler: SchedulerParam;
|
||||||
|
refinerAestheticScore: number;
|
||||||
|
refinerStart: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
const sdxlInitialState: SDXLInitialState = {
|
const sdxlInitialState: SDXLInitialState = {
|
||||||
shouldUseSDXLRefiner: false,
|
|
||||||
positiveStylePrompt: '',
|
positiveStylePrompt: '',
|
||||||
negativeStylePrompt: '',
|
negativeStylePrompt: '',
|
||||||
|
isRefinerAvailable: false,
|
||||||
|
shouldUseSDXLRefiner: false,
|
||||||
|
refinerModel: null,
|
||||||
|
refinerSteps: 20,
|
||||||
|
refinerCFGScale: 7.5,
|
||||||
|
refinerScheduler: 'euler',
|
||||||
|
refinerAestheticScore: 6,
|
||||||
|
refinerStart: 0.7,
|
||||||
};
|
};
|
||||||
|
|
||||||
const sdxlSlice = createSlice({
|
const sdxlSlice = createSlice({
|
||||||
@ -26,16 +43,47 @@ const sdxlSlice = createSlice({
|
|||||||
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||||
state.negativeStylePrompt = action.payload;
|
state.negativeStylePrompt = action.payload;
|
||||||
},
|
},
|
||||||
|
setIsRefinerAvailable: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isRefinerAvailable = action.payload;
|
||||||
|
},
|
||||||
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldUseSDXLRefiner = action.payload;
|
state.shouldUseSDXLRefiner = action.payload;
|
||||||
},
|
},
|
||||||
|
refinerModelChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<MainModelParam | null>
|
||||||
|
) => {
|
||||||
|
state.refinerModel = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerSteps: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerSteps = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerCFGScale = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||||
|
state.refinerScheduler = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerAestheticScore = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerStart: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerStart = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
setPositiveStylePromptSDXL,
|
setPositiveStylePromptSDXL,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
|
setIsRefinerAvailable,
|
||||||
setShouldUseSDXLRefiner,
|
setShouldUseSDXLRefiner,
|
||||||
|
refinerModelChanged,
|
||||||
|
setRefinerSteps,
|
||||||
|
setRefinerCFGScale,
|
||||||
|
setRefinerScheduler,
|
||||||
|
setRefinerAestheticScore,
|
||||||
|
setRefinerStart,
|
||||||
} = sdxlSlice.actions;
|
} = sdxlSlice.actions;
|
||||||
|
|
||||||
export default sdxlSlice.reducer;
|
export default sdxlSlice.reducer;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user