diff --git a/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx new file mode 100644 index 0000000000..dcf266c368 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx @@ -0,0 +1,34 @@ +import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAICollapse from 'common/components/IAICollapse'; +import ParamRefinerModelSelect from './SDXLRefiner/ParamRefinerModelSelect'; +import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps'; +import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner'; + +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseSDXLRefiner } = state.sdxl; + return { activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerCollapse = () => { + const { activeLabel } = useAppSelector(selector); + + return ( + + + + + + + + ); +}; + +export default ParamSDXLRefinerCollapse; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx index 57e76fc120..a92e7a2fd0 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx @@ -7,6 +7,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces import ImageToImageTabCoreParameters from 'features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters'; import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning'; import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning'; +import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; const SDXLImageToImageTabParameters = () => { return ( @@ -17,6 +18,7 @@ const SDXLImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamRefinerModelSelect.tsx new file mode 100644 index 0000000000..4c7c32a069 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamRefinerModelSelect.tsx @@ -0,0 +1,106 @@ +import { Box, Flex } from '@chakra-ui/react'; +import { SelectItem } from '@mantine/core'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; +import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; +import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; + +const selector = createSelector( + stateSelector, + (state) => ({ model: state.sdxl.refinerModel }), + defaultSelectorOptions +); + +const ParamRefinerModelSelect = () => { + const dispatch = useAppDispatch(); + + const { model } = useAppSelector(selector); + + const { data: mainModels, isLoading } = useGetMainModelsQuery(); + + const data = useMemo(() => { + if (!mainModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(mainModels.entities, (model, id) => { + if (!model) { + return; + } + + if (['sdxl-refiner'].includes(model.base_model)) { + data.push({ + value: id, + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], + }); + } + }); + + return data; + }, [mainModels]); + + // grab the full model entity from the RTK Query cache + // TODO: maybe we should just store the full model entity in state? + const selectedModel = useMemo( + () => + mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ?? + null, + [mainModels?.entities, model] + ); + + const handleChangeModel = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + const newModel = modelIdToMainModelParam(v); + + if (!newModel) { + return; + } + + dispatch(refinerModelChanged(newModel)); + }, + [dispatch] + ); + + return isLoading ? ( + + ) : ( + + 0 ? 'Select a model' : 'No models available'} + data={data} + error={data.length === 0} + disabled={data.length === 0} + onChange={handleChangeModel} + w="100%" + /> + + + + + ); +}; + +export default memo(ParamRefinerModelSelect); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx new file mode 100644 index 0000000000..2861f2f847 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx @@ -0,0 +1,75 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAINumberInput from 'common/components/IAINumberInput'; + +import IAISlider from 'common/components/IAISlider'; +import { clampSymmetrySteps } from 'features/parameters/store/generationSlice'; +import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +const selector = createSelector( + [stateSelector], + ({ sdxl, ui }) => { + const { refinerSteps } = sdxl; + const { shouldUseSliders } = ui; + + return { + refinerSteps, + + shouldUseSliders, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerSteps = () => { + const { refinerSteps, shouldUseSliders } = useAppSelector(selector); + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const handleChange = useCallback( + (v: number) => { + dispatch(setRefinerSteps(v)); + }, + [dispatch] + ); + const handleReset = useCallback(() => { + dispatch(setRefinerSteps(20)); + }, [dispatch]); + + const handleBlur = useCallback(() => { + dispatch(clampSymmetrySteps()); + }, [dispatch]); + + return shouldUseSliders ? ( + + ) : ( + + ); +}; + +export default memo(ParamSDXLRefinerSteps); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx new file mode 100644 index 0000000000..9da8286910 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx @@ -0,0 +1,30 @@ +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice'; +import { ChangeEvent } from 'react'; + +export default function ParamUseSDXLRefiner() { + const shouldUseSDXLRefiner = useAppSelector( + (state: RootState) => state.sdxl.shouldUseSDXLRefiner + ); + + const isRefinerAvailable = useAppSelector( + (state: RootState) => state.sdxl.isRefinerAvailable + ); + + const dispatch = useAppDispatch(); + + const handleUseSDXLRefinerChange = (e: ChangeEvent) => { + dispatch(setShouldUseSDXLRefiner(e.target.checked)); + }; + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx index abebccc3f0..2175fcc9e3 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx @@ -6,6 +6,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters'; import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning'; import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning'; +import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; const SDXLTextToImageTabParameters = () => { return ( @@ -16,6 +17,7 @@ const SDXLTextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts index 95a0ac0de8..7250425839 100644 --- a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts +++ b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts @@ -1,19 +1,36 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { + MainModelParam, NegativeStylePromptSDXLParam, PositiveStylePromptSDXLParam, + SchedulerParam, } from 'features/parameters/types/parameterSchemas'; +import { MainModelField } from 'services/api/types'; type SDXLInitialState = { - shouldUseSDXLRefiner: boolean; positiveStylePrompt: PositiveStylePromptSDXLParam; negativeStylePrompt: NegativeStylePromptSDXLParam; + isRefinerAvailable: boolean; + shouldUseSDXLRefiner: boolean; + refinerModel: MainModelField | null; + refinerSteps: number; + refinerCFGScale: number; + refinerScheduler: SchedulerParam; + refinerAestheticScore: number; + refinerStart: number; }; const sdxlInitialState: SDXLInitialState = { - shouldUseSDXLRefiner: false, positiveStylePrompt: '', negativeStylePrompt: '', + isRefinerAvailable: false, + shouldUseSDXLRefiner: false, + refinerModel: null, + refinerSteps: 20, + refinerCFGScale: 7.5, + refinerScheduler: 'euler', + refinerAestheticScore: 6, + refinerStart: 0.7, }; const sdxlSlice = createSlice({ @@ -26,16 +43,47 @@ const sdxlSlice = createSlice({ setNegativeStylePromptSDXL: (state, action: PayloadAction) => { state.negativeStylePrompt = action.payload; }, + setIsRefinerAvailable: (state, action: PayloadAction) => { + state.isRefinerAvailable = action.payload; + }, setShouldUseSDXLRefiner: (state, action: PayloadAction) => { state.shouldUseSDXLRefiner = action.payload; }, + refinerModelChanged: ( + state, + action: PayloadAction + ) => { + state.refinerModel = action.payload; + }, + setRefinerSteps: (state, action: PayloadAction) => { + state.refinerSteps = action.payload; + }, + setRefinerCFGScale: (state, action: PayloadAction) => { + state.refinerCFGScale = action.payload; + }, + setRefinerScheduler: (state, action: PayloadAction) => { + state.refinerScheduler = action.payload; + }, + setRefinerAestheticScore: (state, action: PayloadAction) => { + state.refinerAestheticScore = action.payload; + }, + setRefinerStart: (state, action: PayloadAction) => { + state.refinerStart = action.payload; + }, }, }); export const { setPositiveStylePromptSDXL, setNegativeStylePromptSDXL, + setIsRefinerAvailable, setShouldUseSDXLRefiner, + refinerModelChanged, + setRefinerSteps, + setRefinerCFGScale, + setRefinerScheduler, + setRefinerAestheticScore, + setRefinerStart, } = sdxlSlice.actions; export default sdxlSlice.reducer;