diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts
index 3506bafac2..c757c8959f 100644
--- a/invokeai/frontend/web/src/app/constants.ts
+++ b/invokeai/frontend/web/src/app/constants.ts
@@ -1,6 +1,28 @@
-// TODO: use Enums?
+import { SelectItem } from '@mantine/core';
-export const SCHEDULERS = [
+// TODO: use Enums?
+export const SCHEDULERS: SelectItem[] = [
+ { label: 'euler', value: 'euler', group: 'Standard' },
+ { label: 'deis', value: 'deis', group: 'Standard' },
+ { label: 'ddim', value: 'ddim', group: 'Standard' },
+ { label: 'ddpm', value: 'ddpm', group: 'Standard' },
+ { label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' },
+ { label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' },
+ { label: 'heun', value: 'heun', group: 'Standard' },
+ { label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' },
+ { label: 'lms', value: 'lms', group: 'Standard' },
+ { label: 'pndm', value: 'pndm', group: 'Standard' },
+ { label: 'unipc', value: 'unipc', group: 'Standard' },
+ { label: 'euler_k', value: 'euler_k', group: 'Karras' },
+ { label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' },
+ { label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' },
+ { label: 'heun_k', value: 'heun_k', group: 'Karras' },
+ { label: 'lms_k', value: 'lms_k', group: 'Karras' },
+ { label: 'euler_a', value: 'euler_a', group: 'Ancestral' },
+ { label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' },
+];
+
+export const SCHEDULER_ITEMS = [
'ddim',
'lms',
'lms_k',
@@ -19,9 +41,9 @@ export const SCHEDULERS = [
'heun',
'heun_k',
'unipc',
-] as const;
+];
-export type Scheduler = (typeof SCHEDULERS)[number];
+export type Scheduler = typeof SCHEDULERS;
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx
new file mode 100644
index 0000000000..2efb467869
--- /dev/null
+++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx
@@ -0,0 +1,79 @@
+import { Tooltip } from '@chakra-ui/react';
+import { MultiSelect, MultiSelectProps } from '@mantine/core';
+import { memo } from 'react';
+
+type IAIMultiSelectProps = MultiSelectProps & {
+ tooltip?: string;
+};
+
+const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
+ const { searchable = true, tooltip, ...rest } = props;
+ return (
+
+ ({
+ label: {
+ color: 'var(--invokeai-colors-base-300)',
+ fontWeight: 'normal',
+ },
+ input: {
+ backgroundColor: 'var(--invokeai-colors-base-900)',
+ borderWidth: '2px',
+ borderColor: 'var(--invokeai-colors-base-800)',
+ color: 'var(--invokeai-colors-base-100)',
+ padding: 10,
+ paddingRight: 24,
+ fontWeight: 600,
+ '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
+ '&:focus': {
+ borderColor: 'var(--invokeai-colors-accent-600)',
+ },
+ },
+ value: {
+ backgroundColor: 'var(--invokeai-colors-base-800)',
+ color: 'var(--invokeai-colors-base-100)',
+ '&:hover': {
+ backgroundColor: 'var(--invokeai-colors-base-700)',
+ cursor: 'pointer',
+ },
+ },
+ dropdown: {
+ backgroundColor: 'var(--invokeai-colors-base-800)',
+ borderColor: 'var(--invokeai-colors-base-700)',
+ },
+ item: {
+ backgroundColor: 'var(--invokeai-colors-base-800)',
+ color: 'var(--invokeai-colors-base-200)',
+ padding: 6,
+ '&[data-hovered]': {
+ color: 'var(--invokeai-colors-base-100)',
+ backgroundColor: 'var(--invokeai-colors-base-750)',
+ },
+ '&[data-active]': {
+ backgroundColor: 'var(--invokeai-colors-base-750)',
+ '&:hover': {
+ color: 'var(--invokeai-colors-base-100)',
+ backgroundColor: 'var(--invokeai-colors-base-750)',
+ },
+ },
+ '&[data-selected]': {
+ color: 'var(--invokeai-colors-base-50)',
+ backgroundColor: 'var(--invokeai-colors-accent-650)',
+ fontWeight: 600,
+ '&:hover': {
+ backgroundColor: 'var(--invokeai-colors-accent-600)',
+ },
+ },
+ },
+ rightSection: {
+ width: 24,
+ },
+ })}
+ {...rest}
+ />
+
+ );
+};
+
+export default memo(IAIMantineMultiSelect);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
index cf29636ea3..d749b33eea 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
@@ -1,43 +1,44 @@
-import { createSelector } from '@reduxjs/toolkit';
-import { Scheduler } from 'app/constants';
+import { SCHEDULER_ITEMS } from 'app/constants';
+import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
-import IAIMantineSelect, {
- IAISelectDataType,
-} from 'common/components/IAIMantineSelect';
-import { generationSelector } from 'features/parameters/store/generationSelectors';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { setScheduler } from 'features/parameters/store/generationSlice';
-import { uiSelector } from 'features/ui/store/uiSelectors';
-import { memo, useCallback } from 'react';
+import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
+import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
-const selector = createSelector(
- [uiSelector, generationSelector],
- (ui, generation) => {
- const allSchedulers: string[] = ui.schedulers
- .slice()
- .sort((a, b) => a.localeCompare(b));
-
- return {
- scheduler: generation.scheduler,
- allSchedulers,
- };
- },
- defaultSelectorOptions
-);
-
const ParamScheduler = () => {
- const { allSchedulers, scheduler } = useAppSelector(selector);
+ const scheduler = useAppSelector(
+ (state: RootState) => state.generation.scheduler
+ );
+
+ const selectedSchedulers = useAppSelector(
+ (state: RootState) => state.ui.selectedSchedulers
+ );
+
+ const activeSchedulers = useAppSelector(
+ (state: RootState) => state.ui.activeSchedulers
+ );
const dispatch = useAppDispatch();
const { t } = useTranslation();
+ useEffect(() => {
+ if (selectedSchedulers.length === 0)
+ dispatch(setSelectedSchedulers(SCHEDULER_ITEMS));
+
+ const schedulerFound = activeSchedulers.find(
+ (activeSchedulers) => activeSchedulers.label === scheduler
+ );
+ if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value));
+ }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
+
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
- dispatch(setScheduler(v as Scheduler));
+ dispatch(setScheduler(v));
},
[dispatch]
);
@@ -46,7 +47,7 @@ const ParamScheduler = () => {
);
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
index 4244b512fb..0916fc84c6 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
@@ -1,6 +1,5 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
-import { Scheduler } from 'app/constants';
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es';
@@ -136,7 +135,7 @@ export const generationSlice = createSlice({
setWidth: (state, action: PayloadAction) => {
state.width = action.payload;
},
- setScheduler: (state, action: PayloadAction) => {
+ setScheduler: (state, action: PayloadAction) => {
state.scheduler = action.payload;
},
setSeed: (state, action: PayloadAction) => {
diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
index b99e57bfbb..d17dcbce10 100644
--- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
@@ -1,4 +1,4 @@
-import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
+import { NUMPY_RAND_MAX } from 'app/constants';
import { z } from 'zod';
/**
@@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
/**
* Zod schema for scheduler parameter
*/
-export const zScheduler = z.enum(SCHEDULERS);
+export const zScheduler = z.string();
/**
* Type alias for scheduler parameter, inferred from its zod schema
*/
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
index e5f4a4cbf7..e8def6a7eb 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
@@ -1,47 +1,33 @@
-import {
- Menu,
- MenuButton,
- MenuItemOption,
- MenuList,
- MenuOptionGroup,
-} from '@chakra-ui/react';
import { SCHEDULERS } from 'app/constants';
-
import { RootState } from 'app/store/store';
+
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import IAIButton from 'common/components/IAIButton';
-import { setSchedulers } from 'features/ui/store/uiSlice';
-import { isArray } from 'lodash-es';
+import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
+import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
export default function SettingsSchedulers() {
- const schedulers = useAppSelector((state: RootState) => state.ui.schedulers);
-
const dispatch = useAppDispatch();
+
+ const selectedSchedulers = useAppSelector(
+ (state: RootState) => state.ui.selectedSchedulers
+ );
+
const { t } = useTranslation();
- const schedulerSettingsHandler = (v: string | string[]) => {
- if (isArray(v)) dispatch(setSchedulers(v.sort()));
+ const schedulerSettingsHandler = (v: string[]) => {
+ dispatch(setSelectedSchedulers(v));
};
return (
-
+
);
}
diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
index 65a48bc92c..907d5a5295 100644
--- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
+++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
@@ -1,10 +1,11 @@
+import { SelectItem } from '@mantine/core';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
+import { SCHEDULERS } from 'app/constants';
+import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes';
-import { initialImageChanged } from 'features/parameters/store/generationSlice';
-import { SCHEDULERS } from 'app/constants';
export const initialUIState: UIState = {
activeTab: 0,
@@ -20,7 +21,8 @@ export const initialUIState: UIState = {
shouldShowGallery: true,
shouldHidePreview: false,
shouldShowProgressInViewer: true,
- schedulers: SCHEDULERS,
+ activeSchedulers: [],
+ selectedSchedulers: [],
};
export const uiSlice = createSlice({
@@ -94,9 +96,20 @@ export const uiSlice = createSlice({
setShouldShowProgressInViewer: (state, action: PayloadAction) => {
state.shouldShowProgressInViewer = action.payload;
},
- setSchedulers: (state, action: PayloadAction) => {
- state.schedulers = [];
- state.schedulers = action.payload;
+ setSelectedSchedulers: (state, action: PayloadAction) => {
+ const selectedSchedulerData: SelectItem[] = [];
+
+ if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value];
+
+ action.payload.forEach((item) => {
+ const schedulerData = SCHEDULERS.find(
+ (scheduler) => scheduler.value === item
+ );
+ if (schedulerData) selectedSchedulerData.push(schedulerData);
+ });
+
+ state.activeSchedulers = selectedSchedulerData;
+ state.selectedSchedulers = action.payload;
},
},
extraReducers(builder) {
@@ -124,7 +137,7 @@ export const {
toggleParametersPanel,
toggleGalleryPanel,
setShouldShowProgressInViewer,
- setSchedulers,
+ setSelectedSchedulers,
} = uiSlice.actions;
export default uiSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts
index 18a758cdd6..88ccb864e8 100644
--- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts
+++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts
@@ -1,3 +1,5 @@
+import { SelectItem } from '@mantine/core';
+
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
export type Coordinates = {
@@ -26,5 +28,6 @@ export interface UIState {
shouldPinGallery: boolean;
shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean;
- schedulers: string[];
+ activeSchedulers: SelectItem[];
+ selectedSchedulers: string[];
}