diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 3506bafac2..c757c8959f 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,6 +1,28 @@ -// TODO: use Enums? +import { SelectItem } from '@mantine/core'; -export const SCHEDULERS = [ +// TODO: use Enums? +export const SCHEDULERS: SelectItem[] = [ + { label: 'euler', value: 'euler', group: 'Standard' }, + { label: 'deis', value: 'deis', group: 'Standard' }, + { label: 'ddim', value: 'ddim', group: 'Standard' }, + { label: 'ddpm', value: 'ddpm', group: 'Standard' }, + { label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' }, + { label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' }, + { label: 'heun', value: 'heun', group: 'Standard' }, + { label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' }, + { label: 'lms', value: 'lms', group: 'Standard' }, + { label: 'pndm', value: 'pndm', group: 'Standard' }, + { label: 'unipc', value: 'unipc', group: 'Standard' }, + { label: 'euler_k', value: 'euler_k', group: 'Karras' }, + { label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' }, + { label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' }, + { label: 'heun_k', value: 'heun_k', group: 'Karras' }, + { label: 'lms_k', value: 'lms_k', group: 'Karras' }, + { label: 'euler_a', value: 'euler_a', group: 'Ancestral' }, + { label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' }, +]; + +export const SCHEDULER_ITEMS = [ 'ddim', 'lms', 'lms_k', @@ -19,9 +41,9 @@ export const SCHEDULERS = [ 'heun', 'heun_k', 'unipc', -] as const; +]; -export type Scheduler = (typeof SCHEDULERS)[number]; +export type Scheduler = typeof SCHEDULERS; // Valid upscaling levels export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx new file mode 100644 index 0000000000..2efb467869 --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -0,0 +1,79 @@ +import { Tooltip } from '@chakra-ui/react'; +import { MultiSelect, MultiSelectProps } from '@mantine/core'; +import { memo } from 'react'; + +type IAIMultiSelectProps = MultiSelectProps & { + tooltip?: string; +}; + +const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { + const { searchable = true, tooltip, ...rest } = props; + return ( + + ({ + label: { + color: 'var(--invokeai-colors-base-300)', + fontWeight: 'normal', + }, + input: { + backgroundColor: 'var(--invokeai-colors-base-900)', + borderWidth: '2px', + borderColor: 'var(--invokeai-colors-base-800)', + color: 'var(--invokeai-colors-base-100)', + padding: 10, + paddingRight: 24, + fontWeight: 600, + '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' }, + '&:focus': { + borderColor: 'var(--invokeai-colors-accent-600)', + }, + }, + value: { + backgroundColor: 'var(--invokeai-colors-base-800)', + color: 'var(--invokeai-colors-base-100)', + '&:hover': { + backgroundColor: 'var(--invokeai-colors-base-700)', + cursor: 'pointer', + }, + }, + dropdown: { + backgroundColor: 'var(--invokeai-colors-base-800)', + borderColor: 'var(--invokeai-colors-base-700)', + }, + item: { + backgroundColor: 'var(--invokeai-colors-base-800)', + color: 'var(--invokeai-colors-base-200)', + padding: 6, + '&[data-hovered]': { + color: 'var(--invokeai-colors-base-100)', + backgroundColor: 'var(--invokeai-colors-base-750)', + }, + '&[data-active]': { + backgroundColor: 'var(--invokeai-colors-base-750)', + '&:hover': { + color: 'var(--invokeai-colors-base-100)', + backgroundColor: 'var(--invokeai-colors-base-750)', + }, + }, + '&[data-selected]': { + color: 'var(--invokeai-colors-base-50)', + backgroundColor: 'var(--invokeai-colors-accent-650)', + fontWeight: 600, + '&:hover': { + backgroundColor: 'var(--invokeai-colors-accent-600)', + }, + }, + }, + rightSection: { + width: 24, + }, + })} + {...rest} + /> + + ); +}; + +export default memo(IAIMantineMultiSelect); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index cf29636ea3..d749b33eea 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,43 +1,44 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { Scheduler } from 'app/constants'; +import { SCHEDULER_ITEMS } from 'app/constants'; +import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIMantineSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { uiSelector } from 'features/ui/store/uiSelectors'; -import { memo, useCallback } from 'react'; +import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; +import { memo, useCallback, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createSelector( - [uiSelector, generationSelector], - (ui, generation) => { - const allSchedulers: string[] = ui.schedulers - .slice() - .sort((a, b) => a.localeCompare(b)); - - return { - scheduler: generation.scheduler, - allSchedulers, - }; - }, - defaultSelectorOptions -); - const ParamScheduler = () => { - const { allSchedulers, scheduler } = useAppSelector(selector); + const scheduler = useAppSelector( + (state: RootState) => state.generation.scheduler + ); + + const selectedSchedulers = useAppSelector( + (state: RootState) => state.ui.selectedSchedulers + ); + + const activeSchedulers = useAppSelector( + (state: RootState) => state.ui.activeSchedulers + ); const dispatch = useAppDispatch(); const { t } = useTranslation(); + useEffect(() => { + if (selectedSchedulers.length === 0) + dispatch(setSelectedSchedulers(SCHEDULER_ITEMS)); + + const schedulerFound = activeSchedulers.find( + (activeSchedulers) => activeSchedulers.label === scheduler + ); + if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value)); + }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]); + const handleChange = useCallback( (v: string | null) => { if (!v) { return; } - dispatch(setScheduler(v as Scheduler)); + dispatch(setScheduler(v)); }, [dispatch] ); @@ -46,7 +47,7 @@ const ParamScheduler = () => { ); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 4244b512fb..0916fc84c6 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,6 +1,5 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { Scheduler } from 'app/constants'; import { ModelLoaderTypes } from 'features/system/components/ModelSelect'; import { configChanged } from 'features/system/store/configSlice'; import { clamp, sortBy } from 'lodash-es'; @@ -136,7 +135,7 @@ export const generationSlice = createSlice({ setWidth: (state, action: PayloadAction) => { state.width = action.payload; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index b99e57bfbb..d17dcbce10 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -1,4 +1,4 @@ -import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants'; +import { NUMPY_RAND_MAX } from 'app/constants'; import { z } from 'zod'; /** @@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam => /** * Zod schema for scheduler parameter */ -export const zScheduler = z.enum(SCHEDULERS); +export const zScheduler = z.string(); /** * Type alias for scheduler parameter, inferred from its zod schema */ diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index e5f4a4cbf7..e8def6a7eb 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,47 +1,33 @@ -import { - Menu, - MenuButton, - MenuItemOption, - MenuList, - MenuOptionGroup, -} from '@chakra-ui/react'; import { SCHEDULERS } from 'app/constants'; - import { RootState } from 'app/store/store'; + import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIButton from 'common/components/IAIButton'; -import { setSchedulers } from 'features/ui/store/uiSlice'; -import { isArray } from 'lodash-es'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; import { useTranslation } from 'react-i18next'; export default function SettingsSchedulers() { - const schedulers = useAppSelector((state: RootState) => state.ui.schedulers); - const dispatch = useAppDispatch(); + + const selectedSchedulers = useAppSelector( + (state: RootState) => state.ui.selectedSchedulers + ); + const { t } = useTranslation(); - const schedulerSettingsHandler = (v: string | string[]) => { - if (isArray(v)) dispatch(setSchedulers(v.sort())); + const schedulerSettingsHandler = (v: string[]) => { + dispatch(setSelectedSchedulers(v)); }; return ( - - - {t('settings.availableSchedulers')} - - - - {SCHEDULERS.map((scheduler) => ( - - {scheduler} - - ))} - - - + ); } diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 65a48bc92c..907d5a5295 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,10 +1,11 @@ +import { SelectItem } from '@mantine/core'; import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; +import { SCHEDULERS } from 'app/constants'; +import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; -import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { SCHEDULERS } from 'app/constants'; export const initialUIState: UIState = { activeTab: 0, @@ -20,7 +21,8 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, - schedulers: SCHEDULERS, + activeSchedulers: [], + selectedSchedulers: [], }; export const uiSlice = createSlice({ @@ -94,9 +96,20 @@ export const uiSlice = createSlice({ setShouldShowProgressInViewer: (state, action: PayloadAction) => { state.shouldShowProgressInViewer = action.payload; }, - setSchedulers: (state, action: PayloadAction) => { - state.schedulers = []; - state.schedulers = action.payload; + setSelectedSchedulers: (state, action: PayloadAction) => { + const selectedSchedulerData: SelectItem[] = []; + + if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value]; + + action.payload.forEach((item) => { + const schedulerData = SCHEDULERS.find( + (scheduler) => scheduler.value === item + ); + if (schedulerData) selectedSchedulerData.push(schedulerData); + }); + + state.activeSchedulers = selectedSchedulerData; + state.selectedSchedulers = action.payload; }, }, extraReducers(builder) { @@ -124,7 +137,7 @@ export const { toggleParametersPanel, toggleGalleryPanel, setShouldShowProgressInViewer, - setSchedulers, + setSelectedSchedulers, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index 18a758cdd6..88ccb864e8 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,3 +1,5 @@ +import { SelectItem } from '@mantine/core'; + export type AddNewModelType = 'ckpt' | 'diffusers' | null; export type Coordinates = { @@ -26,5 +28,6 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; - schedulers: string[]; + activeSchedulers: SelectItem[]; + selectedSchedulers: string[]; }