diff --git a/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx
new file mode 100644
index 0000000000..dcf266c368
--- /dev/null
+++ b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx
@@ -0,0 +1,34 @@
+import { Flex } from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAICollapse from 'common/components/IAICollapse';
+import ParamRefinerModelSelect from './SDXLRefiner/ParamRefinerModelSelect';
+import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
+import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
+
+const selector = createSelector(
+ stateSelector,
+ (state) => {
+ const { shouldUseSDXLRefiner } = state.sdxl;
+ return { activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined };
+ },
+ defaultSelectorOptions
+);
+
+const ParamSDXLRefinerCollapse = () => {
+ const { activeLabel } = useAppSelector(selector);
+
+ return (
+
+
+
+
+
+
+
+ );
+};
+
+export default ParamSDXLRefinerCollapse;
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx
index 57e76fc120..a92e7a2fd0 100644
--- a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx
@@ -7,6 +7,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import ImageToImageTabCoreParameters from 'features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
const SDXLImageToImageTabParameters = () => {
return (
@@ -17,6 +18,7 @@ const SDXLImageToImageTabParameters = () => {
+
>
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamRefinerModelSelect.tsx
new file mode 100644
index 0000000000..4c7c32a069
--- /dev/null
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamRefinerModelSelect.tsx
@@ -0,0 +1,106 @@
+import { Box, Flex } from '@chakra-ui/react';
+import { SelectItem } from '@mantine/core';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
+import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
+import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
+import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
+import { forEach } from 'lodash-es';
+import { memo, useCallback, useMemo } from 'react';
+import { useGetMainModelsQuery } from 'services/api/endpoints/models';
+
+const selector = createSelector(
+ stateSelector,
+ (state) => ({ model: state.sdxl.refinerModel }),
+ defaultSelectorOptions
+);
+
+const ParamRefinerModelSelect = () => {
+ const dispatch = useAppDispatch();
+
+ const { model } = useAppSelector(selector);
+
+ const { data: mainModels, isLoading } = useGetMainModelsQuery();
+
+ const data = useMemo(() => {
+ if (!mainModels) {
+ return [];
+ }
+
+ const data: SelectItem[] = [];
+
+ forEach(mainModels.entities, (model, id) => {
+ if (!model) {
+ return;
+ }
+
+ if (['sdxl-refiner'].includes(model.base_model)) {
+ data.push({
+ value: id,
+ label: model.model_name,
+ group: MODEL_TYPE_MAP[model.base_model],
+ });
+ }
+ });
+
+ return data;
+ }, [mainModels]);
+
+ // grab the full model entity from the RTK Query cache
+ // TODO: maybe we should just store the full model entity in state?
+ const selectedModel = useMemo(
+ () =>
+ mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
+ null,
+ [mainModels?.entities, model]
+ );
+
+ const handleChangeModel = useCallback(
+ (v: string | null) => {
+ if (!v) {
+ return;
+ }
+
+ const newModel = modelIdToMainModelParam(v);
+
+ if (!newModel) {
+ return;
+ }
+
+ dispatch(refinerModelChanged(newModel));
+ },
+ [dispatch]
+ );
+
+ return isLoading ? (
+
+ ) : (
+
+ 0 ? 'Select a model' : 'No models available'}
+ data={data}
+ error={data.length === 0}
+ disabled={data.length === 0}
+ onChange={handleChangeModel}
+ w="100%"
+ />
+
+
+
+
+ );
+};
+
+export default memo(ParamRefinerModelSelect);
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx
new file mode 100644
index 0000000000..2861f2f847
--- /dev/null
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx
@@ -0,0 +1,75 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAINumberInput from 'common/components/IAINumberInput';
+
+import IAISlider from 'common/components/IAISlider';
+import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
+import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const selector = createSelector(
+ [stateSelector],
+ ({ sdxl, ui }) => {
+ const { refinerSteps } = sdxl;
+ const { shouldUseSliders } = ui;
+
+ return {
+ refinerSteps,
+
+ shouldUseSliders,
+ };
+ },
+ defaultSelectorOptions
+);
+
+const ParamSDXLRefinerSteps = () => {
+ const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
+ const dispatch = useAppDispatch();
+ const { t } = useTranslation();
+
+ const handleChange = useCallback(
+ (v: number) => {
+ dispatch(setRefinerSteps(v));
+ },
+ [dispatch]
+ );
+ const handleReset = useCallback(() => {
+ dispatch(setRefinerSteps(20));
+ }, [dispatch]);
+
+ const handleBlur = useCallback(() => {
+ dispatch(clampSymmetrySteps());
+ }, [dispatch]);
+
+ return shouldUseSliders ? (
+
+ ) : (
+
+ );
+};
+
+export default memo(ParamSDXLRefinerSteps);
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx
new file mode 100644
index 0000000000..9da8286910
--- /dev/null
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx
@@ -0,0 +1,30 @@
+import { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAISwitch from 'common/components/IAISwitch';
+import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
+import { ChangeEvent } from 'react';
+
+export default function ParamUseSDXLRefiner() {
+ const shouldUseSDXLRefiner = useAppSelector(
+ (state: RootState) => state.sdxl.shouldUseSDXLRefiner
+ );
+
+ const isRefinerAvailable = useAppSelector(
+ (state: RootState) => state.sdxl.isRefinerAvailable
+ );
+
+ const dispatch = useAppDispatch();
+
+ const handleUseSDXLRefinerChange = (e: ChangeEvent) => {
+ dispatch(setShouldUseSDXLRefiner(e.target.checked));
+ };
+
+ return (
+
+ );
+}
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx
index abebccc3f0..2175fcc9e3 100644
--- a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx
@@ -6,6 +6,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
const SDXLTextToImageTabParameters = () => {
return (
@@ -16,6 +17,7 @@ const SDXLTextToImageTabParameters = () => {
+
>
diff --git a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
index 95a0ac0de8..7250425839 100644
--- a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
+++ b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
@@ -1,19 +1,36 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
+ MainModelParam,
NegativeStylePromptSDXLParam,
PositiveStylePromptSDXLParam,
+ SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
+import { MainModelField } from 'services/api/types';
type SDXLInitialState = {
- shouldUseSDXLRefiner: boolean;
positiveStylePrompt: PositiveStylePromptSDXLParam;
negativeStylePrompt: NegativeStylePromptSDXLParam;
+ isRefinerAvailable: boolean;
+ shouldUseSDXLRefiner: boolean;
+ refinerModel: MainModelField | null;
+ refinerSteps: number;
+ refinerCFGScale: number;
+ refinerScheduler: SchedulerParam;
+ refinerAestheticScore: number;
+ refinerStart: number;
};
const sdxlInitialState: SDXLInitialState = {
- shouldUseSDXLRefiner: false,
positiveStylePrompt: '',
negativeStylePrompt: '',
+ isRefinerAvailable: false,
+ shouldUseSDXLRefiner: false,
+ refinerModel: null,
+ refinerSteps: 20,
+ refinerCFGScale: 7.5,
+ refinerScheduler: 'euler',
+ refinerAestheticScore: 6,
+ refinerStart: 0.7,
};
const sdxlSlice = createSlice({
@@ -26,16 +43,47 @@ const sdxlSlice = createSlice({
setNegativeStylePromptSDXL: (state, action: PayloadAction) => {
state.negativeStylePrompt = action.payload;
},
+ setIsRefinerAvailable: (state, action: PayloadAction) => {
+ state.isRefinerAvailable = action.payload;
+ },
setShouldUseSDXLRefiner: (state, action: PayloadAction) => {
state.shouldUseSDXLRefiner = action.payload;
},
+ refinerModelChanged: (
+ state,
+ action: PayloadAction
+ ) => {
+ state.refinerModel = action.payload;
+ },
+ setRefinerSteps: (state, action: PayloadAction) => {
+ state.refinerSteps = action.payload;
+ },
+ setRefinerCFGScale: (state, action: PayloadAction) => {
+ state.refinerCFGScale = action.payload;
+ },
+ setRefinerScheduler: (state, action: PayloadAction) => {
+ state.refinerScheduler = action.payload;
+ },
+ setRefinerAestheticScore: (state, action: PayloadAction) => {
+ state.refinerAestheticScore = action.payload;
+ },
+ setRefinerStart: (state, action: PayloadAction) => {
+ state.refinerStart = action.payload;
+ },
},
});
export const {
setPositiveStylePromptSDXL,
setNegativeStylePromptSDXL,
+ setIsRefinerAvailable,
setShouldUseSDXLRefiner,
+ refinerModelChanged,
+ setRefinerSteps,
+ setRefinerCFGScale,
+ setRefinerScheduler,
+ setRefinerAestheticScore,
+ setRefinerStart,
} = sdxlSlice.actions;
export default sdxlSlice.reducer;