From 59b5dfc3e01366ddaa3a99a63a6bd683c31fbaf1 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 19:31:53 +1200 Subject: [PATCH 01/11] feat: Port Schedulers to Mantine --- invokeai/frontend/web/src/app/constants.ts | 30 ++++++- .../components/IAIMantineMultiSelect.tsx | 79 +++++++++++++++++++ .../Parameters/Core/ParamScheduler.tsx | 55 ++++++------- .../parameters/store/generationSlice.ts | 11 ++- .../parameters/store/parameterZodSchemas.ts | 4 +- .../SettingsModal/SettingsSchedulers.tsx | 52 +++++------- .../web/src/features/ui/store/uiSlice.ts | 27 +++++-- .../web/src/features/ui/store/uiTypes.ts | 5 +- 8 files changed, 183 insertions(+), 80 deletions(-) create mode 100644 invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 3506bafac2..c757c8959f 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,6 +1,28 @@ -// TODO: use Enums? +import { SelectItem } from '@mantine/core'; -export const SCHEDULERS = [ +// TODO: use Enums? +export const SCHEDULERS: SelectItem[] = [ + { label: 'euler', value: 'euler', group: 'Standard' }, + { label: 'deis', value: 'deis', group: 'Standard' }, + { label: 'ddim', value: 'ddim', group: 'Standard' }, + { label: 'ddpm', value: 'ddpm', group: 'Standard' }, + { label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' }, + { label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' }, + { label: 'heun', value: 'heun', group: 'Standard' }, + { label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' }, + { label: 'lms', value: 'lms', group: 'Standard' }, + { label: 'pndm', value: 'pndm', group: 'Standard' }, + { label: 'unipc', value: 'unipc', group: 'Standard' }, + { label: 'euler_k', value: 'euler_k', group: 'Karras' }, + { label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' }, + { label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' }, + { label: 'heun_k', value: 'heun_k', group: 'Karras' }, + { label: 'lms_k', value: 'lms_k', group: 'Karras' }, + { label: 'euler_a', value: 'euler_a', group: 'Ancestral' }, + { label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' }, +]; + +export const SCHEDULER_ITEMS = [ 'ddim', 'lms', 'lms_k', @@ -19,9 +41,9 @@ export const SCHEDULERS = [ 'heun', 'heun_k', 'unipc', -] as const; +]; -export type Scheduler = (typeof SCHEDULERS)[number]; +export type Scheduler = typeof SCHEDULERS; // Valid upscaling levels export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx new file mode 100644 index 0000000000..2efb467869 --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -0,0 +1,79 @@ +import { Tooltip } from '@chakra-ui/react'; +import { MultiSelect, MultiSelectProps } from '@mantine/core'; +import { memo } from 'react'; + +type IAIMultiSelectProps = MultiSelectProps & { + tooltip?: string; +}; + +const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { + const { searchable = true, tooltip, ...rest } = props; + return ( + + ({ + label: { + color: 'var(--invokeai-colors-base-300)', + fontWeight: 'normal', + }, + input: { + backgroundColor: 'var(--invokeai-colors-base-900)', + borderWidth: '2px', + borderColor: 'var(--invokeai-colors-base-800)', + color: 'var(--invokeai-colors-base-100)', + padding: 10, + paddingRight: 24, + fontWeight: 600, + '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' }, + '&:focus': { + borderColor: 'var(--invokeai-colors-accent-600)', + }, + }, + value: { + backgroundColor: 'var(--invokeai-colors-base-800)', + color: 'var(--invokeai-colors-base-100)', + '&:hover': { + backgroundColor: 'var(--invokeai-colors-base-700)', + cursor: 'pointer', + }, + }, + dropdown: { + backgroundColor: 'var(--invokeai-colors-base-800)', + borderColor: 'var(--invokeai-colors-base-700)', + }, + item: { + backgroundColor: 'var(--invokeai-colors-base-800)', + color: 'var(--invokeai-colors-base-200)', + padding: 6, + '&[data-hovered]': { + color: 'var(--invokeai-colors-base-100)', + backgroundColor: 'var(--invokeai-colors-base-750)', + }, + '&[data-active]': { + backgroundColor: 'var(--invokeai-colors-base-750)', + '&:hover': { + color: 'var(--invokeai-colors-base-100)', + backgroundColor: 'var(--invokeai-colors-base-750)', + }, + }, + '&[data-selected]': { + color: 'var(--invokeai-colors-base-50)', + backgroundColor: 'var(--invokeai-colors-accent-650)', + fontWeight: 600, + '&:hover': { + backgroundColor: 'var(--invokeai-colors-accent-600)', + }, + }, + }, + rightSection: { + width: 24, + }, + })} + {...rest} + /> + + ); +}; + +export default memo(IAIMantineMultiSelect); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index cf29636ea3..d749b33eea 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,43 +1,44 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { Scheduler } from 'app/constants'; +import { SCHEDULER_ITEMS } from 'app/constants'; +import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIMantineSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { uiSelector } from 'features/ui/store/uiSelectors'; -import { memo, useCallback } from 'react'; +import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; +import { memo, useCallback, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createSelector( - [uiSelector, generationSelector], - (ui, generation) => { - const allSchedulers: string[] = ui.schedulers - .slice() - .sort((a, b) => a.localeCompare(b)); - - return { - scheduler: generation.scheduler, - allSchedulers, - }; - }, - defaultSelectorOptions -); - const ParamScheduler = () => { - const { allSchedulers, scheduler } = useAppSelector(selector); + const scheduler = useAppSelector( + (state: RootState) => state.generation.scheduler + ); + + const selectedSchedulers = useAppSelector( + (state: RootState) => state.ui.selectedSchedulers + ); + + const activeSchedulers = useAppSelector( + (state: RootState) => state.ui.activeSchedulers + ); const dispatch = useAppDispatch(); const { t } = useTranslation(); + useEffect(() => { + if (selectedSchedulers.length === 0) + dispatch(setSelectedSchedulers(SCHEDULER_ITEMS)); + + const schedulerFound = activeSchedulers.find( + (activeSchedulers) => activeSchedulers.label === scheduler + ); + if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value)); + }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]); + const handleChange = useCallback( (v: string | null) => { if (!v) { return; } - dispatch(setScheduler(v as Scheduler)); + dispatch(setScheduler(v)); }, [dispatch] ); @@ -46,7 +47,7 @@ const ParamScheduler = () => { ); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index f516229efe..5f13bf59ea 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,10 +1,10 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { clamp, sortBy } from 'lodash-es'; -import { receivedModels } from 'services/thunks/model'; -import { Scheduler } from 'app/constants'; -import { ImageDTO } from 'services/api'; import { configChanged } from 'features/system/store/configSlice'; +import { clamp, sortBy } from 'lodash-es'; +import { ImageDTO } from 'services/api'; +import { imageUrlsReceived } from 'services/thunks/image'; +import { receivedModels } from 'services/thunks/model'; import { CfgScaleParam, HeightParam, @@ -17,7 +17,6 @@ import { StrengthParam, WidthParam, } from './parameterZodSchemas'; -import { imageUrlsReceived } from 'services/thunks/image'; export interface GenerationState { cfgScale: CfgScaleParam; @@ -133,7 +132,7 @@ export const generationSlice = createSlice({ setWidth: (state, action: PayloadAction) => { state.width = action.payload; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index b99e57bfbb..d17dcbce10 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -1,4 +1,4 @@ -import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants'; +import { NUMPY_RAND_MAX } from 'app/constants'; import { z } from 'zod'; /** @@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam => /** * Zod schema for scheduler parameter */ -export const zScheduler = z.enum(SCHEDULERS); +export const zScheduler = z.string(); /** * Type alias for scheduler parameter, inferred from its zod schema */ diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index e5f4a4cbf7..e8def6a7eb 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,47 +1,33 @@ -import { - Menu, - MenuButton, - MenuItemOption, - MenuList, - MenuOptionGroup, -} from '@chakra-ui/react'; import { SCHEDULERS } from 'app/constants'; - import { RootState } from 'app/store/store'; + import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIButton from 'common/components/IAIButton'; -import { setSchedulers } from 'features/ui/store/uiSlice'; -import { isArray } from 'lodash-es'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; import { useTranslation } from 'react-i18next'; export default function SettingsSchedulers() { - const schedulers = useAppSelector((state: RootState) => state.ui.schedulers); - const dispatch = useAppDispatch(); + + const selectedSchedulers = useAppSelector( + (state: RootState) => state.ui.selectedSchedulers + ); + const { t } = useTranslation(); - const schedulerSettingsHandler = (v: string | string[]) => { - if (isArray(v)) dispatch(setSchedulers(v.sort())); + const schedulerSettingsHandler = (v: string[]) => { + dispatch(setSelectedSchedulers(v)); }; return ( - - - {t('settings.availableSchedulers')} - - - - {SCHEDULERS.map((scheduler) => ( - - {scheduler} - - ))} - - - + ); } diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 65a48bc92c..907d5a5295 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,10 +1,11 @@ +import { SelectItem } from '@mantine/core'; import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; +import { SCHEDULERS } from 'app/constants'; +import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; -import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { SCHEDULERS } from 'app/constants'; export const initialUIState: UIState = { activeTab: 0, @@ -20,7 +21,8 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, - schedulers: SCHEDULERS, + activeSchedulers: [], + selectedSchedulers: [], }; export const uiSlice = createSlice({ @@ -94,9 +96,20 @@ export const uiSlice = createSlice({ setShouldShowProgressInViewer: (state, action: PayloadAction) => { state.shouldShowProgressInViewer = action.payload; }, - setSchedulers: (state, action: PayloadAction) => { - state.schedulers = []; - state.schedulers = action.payload; + setSelectedSchedulers: (state, action: PayloadAction) => { + const selectedSchedulerData: SelectItem[] = []; + + if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value]; + + action.payload.forEach((item) => { + const schedulerData = SCHEDULERS.find( + (scheduler) => scheduler.value === item + ); + if (schedulerData) selectedSchedulerData.push(schedulerData); + }); + + state.activeSchedulers = selectedSchedulerData; + state.selectedSchedulers = action.payload; }, }, extraReducers(builder) { @@ -124,7 +137,7 @@ export const { toggleParametersPanel, toggleGalleryPanel, setShouldShowProgressInViewer, - setSchedulers, + setSelectedSchedulers, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index 18a758cdd6..88ccb864e8 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,3 +1,5 @@ +import { SelectItem } from '@mantine/core'; + export type AddNewModelType = 'ckpt' | 'diffusers' | null; export type Coordinates = { @@ -26,5 +28,6 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; - schedulers: string[]; + activeSchedulers: SelectItem[]; + selectedSchedulers: string[]; } From 06428fac672d43a064584c56830aeb7ad2baecd1 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 20:02:36 +1200 Subject: [PATCH 02/11] fix: Revert scheduler back to zod validation --- invokeai/frontend/web/src/app/constants.ts | 4 ++-- .../components/Parameters/Core/ParamScheduler.tsx | 9 +++++---- .../web/src/features/parameters/store/generationSlice.ts | 3 ++- .../src/features/parameters/store/parameterZodSchemas.ts | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index c757c8959f..ae362ec85b 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -41,9 +41,9 @@ export const SCHEDULER_ITEMS = [ 'heun', 'heun_k', 'unipc', -]; +] as const; -export type Scheduler = typeof SCHEDULERS; +export type Scheduler = (typeof SCHEDULER_ITEMS)[number]; // Valid upscaling levels export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index d749b33eea..50295b206e 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,4 +1,4 @@ -import { SCHEDULER_ITEMS } from 'app/constants'; +import { SCHEDULER_ITEMS, Scheduler } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; @@ -25,12 +25,13 @@ const ParamScheduler = () => { useEffect(() => { if (selectedSchedulers.length === 0) - dispatch(setSelectedSchedulers(SCHEDULER_ITEMS)); + dispatch(setSelectedSchedulers([...SCHEDULER_ITEMS])); const schedulerFound = activeSchedulers.find( (activeSchedulers) => activeSchedulers.label === scheduler ); - if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value)); + if (!schedulerFound) + dispatch(setScheduler(activeSchedulers[0].value as Scheduler)); }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]); const handleChange = useCallback( @@ -38,7 +39,7 @@ const ParamScheduler = () => { if (!v) { return; } - dispatch(setScheduler(v)); + dispatch(setScheduler(v as Scheduler)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 5f13bf59ea..59801b8c46 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,5 +1,6 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; +import { Scheduler } from 'app/constants'; import { configChanged } from 'features/system/store/configSlice'; import { clamp, sortBy } from 'lodash-es'; import { ImageDTO } from 'services/api'; @@ -132,7 +133,7 @@ export const generationSlice = createSlice({ setWidth: (state, action: PayloadAction) => { state.width = action.payload; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index d17dcbce10..b865faf121 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -1,4 +1,4 @@ -import { NUMPY_RAND_MAX } from 'app/constants'; +import { NUMPY_RAND_MAX, SCHEDULER_ITEMS } from 'app/constants'; import { z } from 'zod'; /** @@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam => /** * Zod schema for scheduler parameter */ -export const zScheduler = z.string(); +export const zScheduler = z.enum(SCHEDULER_ITEMS); /** * Type alias for scheduler parameter, inferred from its zod schema */ From dae5b9b259d507645ac3243d57ed662aa3e6dc98 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 20:06:56 +1200 Subject: [PATCH 03/11] fix: Minor styling fix to the IAIMantineMultiSelect component --- .../frontend/web/src/common/components/IAIMantineMultiSelect.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 2efb467869..5bf37c1aa9 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -68,6 +68,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { }, rightSection: { width: 24, + paddingRight: 20, }, })} {...rest} From be8c0bb95241db44fe0071131eac1f88bcc5e7a7 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 20:17:51 +1200 Subject: [PATCH 04/11] feat: Use Labels for Schedulers --- invokeai/frontend/web/src/app/constants.ts | 36 +++++++++---------- .../Parameters/Core/ParamScheduler.tsx | 2 +- .../Core/ParamSchedulerAndModel.tsx | 4 +-- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index ae362ec85b..af125c7157 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -2,24 +2,24 @@ import { SelectItem } from '@mantine/core'; // TODO: use Enums? export const SCHEDULERS: SelectItem[] = [ - { label: 'euler', value: 'euler', group: 'Standard' }, - { label: 'deis', value: 'deis', group: 'Standard' }, - { label: 'ddim', value: 'ddim', group: 'Standard' }, - { label: 'ddpm', value: 'ddpm', group: 'Standard' }, - { label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' }, - { label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' }, - { label: 'heun', value: 'heun', group: 'Standard' }, - { label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' }, - { label: 'lms', value: 'lms', group: 'Standard' }, - { label: 'pndm', value: 'pndm', group: 'Standard' }, - { label: 'unipc', value: 'unipc', group: 'Standard' }, - { label: 'euler_k', value: 'euler_k', group: 'Karras' }, - { label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' }, - { label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' }, - { label: 'heun_k', value: 'heun_k', group: 'Karras' }, - { label: 'lms_k', value: 'lms_k', group: 'Karras' }, - { label: 'euler_a', value: 'euler_a', group: 'Ancestral' }, - { label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' }, + { label: 'Euler', value: 'euler', group: 'Standard' }, + { label: 'DEIS', value: 'deis', group: 'Standard' }, + { label: 'DDIM', value: 'ddim', group: 'Standard' }, + { label: 'DDPM', value: 'ddpm', group: 'Standard' }, + { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' }, + { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' }, + { label: 'Heun', value: 'heun', group: 'Standard' }, + { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' }, + { label: 'LMS', value: 'lms', group: 'Standard' }, + { label: 'PNDM', value: 'pndm', group: 'Standard' }, + { label: 'UniPC', value: 'unipc', group: 'Standard' }, + { label: 'Euler Karras', value: 'euler_k', group: 'Karras' }, + { label: 'DPM++ 2S Karras', value: 'dpmpp_2s_k', group: 'Karras' }, + { label: 'DPM++ 2M Karras', value: 'dpmpp_2m_k', group: 'Karras' }, + { label: 'Heun Karras', value: 'heun_k', group: 'Karras' }, + { label: 'LMS Karras', value: 'lms_k', group: 'Karras' }, + { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' }, + { label: 'KDPM 2 Ancestral', value: 'kdpm_2_a', group: 'Ancestral' }, ]; export const SCHEDULER_ITEMS = [ diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index 50295b206e..0a343d0742 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -28,7 +28,7 @@ const ParamScheduler = () => { dispatch(setSelectedSchedulers([...SCHEDULER_ITEMS])); const schedulerFound = activeSchedulers.find( - (activeSchedulers) => activeSchedulers.label === scheduler + (activeSchedulers) => activeSchedulers.value === scheduler ); if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value as Scheduler)); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx index 3b53f5005c..65da89b94d 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx @@ -1,12 +1,12 @@ import { Box, Flex } from '@chakra-ui/react'; -import { memo } from 'react'; import ModelSelect from 'features/system/components/ModelSelect'; +import { memo } from 'react'; import ParamScheduler from './ParamScheduler'; const ParamSchedulerAndModel = () => { return ( - + From f1a8b9daee81657ee75d5caae84d1c82faf75220 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Jun 2023 18:47:59 +1000 Subject: [PATCH 05/11] fix(ui): clarify scheduler logic - use full conditional syntax with `{}` - do not mutate `action.payload` in a reducer --- .../components/Parameters/Core/ParamScheduler.tsx | 7 +++++-- .../frontend/web/src/features/ui/store/uiSlice.ts | 13 ++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index 0a343d0742..1fed60fa74 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -24,14 +24,17 @@ const ParamScheduler = () => { const { t } = useTranslation(); useEffect(() => { - if (selectedSchedulers.length === 0) + if (selectedSchedulers.length === 0) { dispatch(setSelectedSchedulers([...SCHEDULER_ITEMS])); + } const schedulerFound = activeSchedulers.find( (activeSchedulers) => activeSchedulers.value === scheduler ); - if (!schedulerFound) + + if (!schedulerFound) { dispatch(setScheduler(activeSchedulers[0].value as Scheduler)); + } }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]); const handleChange = useCallback( diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 907d5a5295..4de6109b20 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -99,13 +99,20 @@ export const uiSlice = createSlice({ setSelectedSchedulers: (state, action: PayloadAction) => { const selectedSchedulerData: SelectItem[] = []; - if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value]; + let selectedSchedulers = [...action.payload]; - action.payload.forEach((item) => { + if (selectedSchedulers.length === 0) { + selectedSchedulers = [SCHEDULERS[0].value]; + } + + selectedSchedulers.forEach((item) => { const schedulerData = SCHEDULERS.find( (scheduler) => scheduler.value === item ); - if (schedulerData) selectedSchedulerData.push(schedulerData); + + if (schedulerData) { + selectedSchedulerData.push(schedulerData); + } }); state.activeSchedulers = selectedSchedulerData; From 150059f7049dfc85fe42944a7a3eea5c67ef6b64 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Jun 2023 18:49:10 +1000 Subject: [PATCH 06/11] fix(ui): create all scheduler constants up-front --- invokeai/frontend/web/src/app/constants.ts | 34 ++++++++++--------- .../Parameters/Core/ParamScheduler.tsx | 4 +-- .../parameters/store/parameterZodSchemas.ts | 4 +-- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index af125c7157..3f737b669b 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,6 +1,5 @@ import { SelectItem } from '@mantine/core'; -// TODO: use Enums? export const SCHEDULERS: SelectItem[] = [ { label: 'Euler', value: 'euler', group: 'Standard' }, { label: 'DEIS', value: 'deis', group: 'Standard' }, @@ -22,28 +21,31 @@ export const SCHEDULERS: SelectItem[] = [ { label: 'KDPM 2 Ancestral', value: 'kdpm_2_a', group: 'Ancestral' }, ]; -export const SCHEDULER_ITEMS = [ - 'ddim', - 'lms', - 'lms_k', +// zod needs the array to be `as const` to infer the type correctly +export const SCHEDULER_NAMES_AS_CONST = [ 'euler', - 'euler_k', - 'euler_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'kdpm_2', - 'kdpm_2_a', 'deis', + 'ddim', 'ddpm', - 'pndm', + 'dpmpp_2s', + 'dpmpp_2m', 'heun', - 'heun_k', + 'kdpm_2', + 'lms', + 'pndm', 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', ] as const; -export type Scheduler = (typeof SCHEDULER_ITEMS)[number]; +export const SCHEDULER_NAMES = [...SCHEDULER_NAMES_AS_CONST]; + +export type Scheduler = (typeof SCHEDULER_NAMES)[number]; // Valid upscaling levels export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index 1fed60fa74..321f57ca6f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,4 +1,4 @@ -import { SCHEDULER_ITEMS, Scheduler } from 'app/constants'; +import { SCHEDULER_NAMES, Scheduler } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; @@ -25,7 +25,7 @@ const ParamScheduler = () => { useEffect(() => { if (selectedSchedulers.length === 0) { - dispatch(setSelectedSchedulers([...SCHEDULER_ITEMS])); + dispatch(setSelectedSchedulers(SCHEDULER_NAMES)); } const schedulerFound = activeSchedulers.find( diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index b865faf121..61567d3fb8 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -1,4 +1,4 @@ -import { NUMPY_RAND_MAX, SCHEDULER_ITEMS } from 'app/constants'; +import { NUMPY_RAND_MAX, SCHEDULER_NAMES_AS_CONST } from 'app/constants'; import { z } from 'zod'; /** @@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam => /** * Zod schema for scheduler parameter */ -export const zScheduler = z.enum(SCHEDULER_ITEMS); +export const zScheduler = z.enum(SCHEDULER_NAMES_AS_CONST); /** * Type alias for scheduler parameter, inferred from its zod schema */ From 94cfcdc411604ce25c0c393c5715e7a5fc63c810 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Jun 2023 19:34:37 +1000 Subject: [PATCH 07/11] feat(ui): improve scheduler selection logic - remove UI-specific state (the enabled schedulers) from redux, instead derive it in a selector - simplify logic by putting schedulers in an object instead of an array - rename `activeSchedulers` to `enabledSchedulers` - remove need for `useEffect()` when `enabledSchedulers` changes by adding a listener for the `enabledSchedulersChanged` action/event to `generationSlice` - increase type safety by making `enabledSchedulers` an array of `SchedulerParam`, which is created by the zod schema for scheduler --- invokeai/frontend/web/src/app/constants.ts | 63 ++++++++++++------- .../Parameters/Core/ParamScheduler.tsx | 63 +++++++++---------- .../parameters/store/generationSlice.ts | 22 ++++++- .../SettingsModal/SettingsSchedulers.tsx | 35 +++++++---- .../web/src/features/ui/store/uiSlice.ts | 36 ++++------- .../web/src/features/ui/store/uiTypes.ts | 5 +- 6 files changed, 127 insertions(+), 97 deletions(-) diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 3f737b669b..d62526f4b3 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,27 +1,8 @@ import { SelectItem } from '@mantine/core'; - -export const SCHEDULERS: SelectItem[] = [ - { label: 'Euler', value: 'euler', group: 'Standard' }, - { label: 'DEIS', value: 'deis', group: 'Standard' }, - { label: 'DDIM', value: 'ddim', group: 'Standard' }, - { label: 'DDPM', value: 'ddpm', group: 'Standard' }, - { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' }, - { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' }, - { label: 'Heun', value: 'heun', group: 'Standard' }, - { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' }, - { label: 'LMS', value: 'lms', group: 'Standard' }, - { label: 'PNDM', value: 'pndm', group: 'Standard' }, - { label: 'UniPC', value: 'unipc', group: 'Standard' }, - { label: 'Euler Karras', value: 'euler_k', group: 'Karras' }, - { label: 'DPM++ 2S Karras', value: 'dpmpp_2s_k', group: 'Karras' }, - { label: 'DPM++ 2M Karras', value: 'dpmpp_2m_k', group: 'Karras' }, - { label: 'Heun Karras', value: 'heun_k', group: 'Karras' }, - { label: 'LMS Karras', value: 'lms_k', group: 'Karras' }, - { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' }, - { label: 'KDPM 2 Ancestral', value: 'kdpm_2_a', group: 'Ancestral' }, -]; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; // zod needs the array to be `as const` to infer the type correctly +// this is the source of the `SchedulerParam` type, which is generated by zod export const SCHEDULER_NAMES_AS_CONST = [ 'euler', 'deis', @@ -43,7 +24,45 @@ export const SCHEDULER_NAMES_AS_CONST = [ 'kdpm_2_a', ] as const; -export const SCHEDULER_NAMES = [...SCHEDULER_NAMES_AS_CONST]; +export const DEFAULT_SCHEDULER_NAME = 'euler'; + +export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST]; + +export const SCHEDULER_SELECT_ITEMS: Record< + (typeof SCHEDULER_NAMES)[number], + SelectItem & { label: string; value: SchedulerParam; group: string } +> = { + euler: { label: 'Euler', value: 'euler', group: 'Standard' }, + deis: { label: 'DEIS', value: 'deis', group: 'Standard' }, + ddim: { label: 'DDIM', value: 'ddim', group: 'Standard' }, + ddpm: { label: 'DDPM', value: 'ddpm', group: 'Standard' }, + dpmpp_2s: { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' }, + dpmpp_2m: { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' }, + heun: { label: 'Heun', value: 'heun', group: 'Standard' }, + kdpm_2: { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' }, + lms: { label: 'LMS', value: 'lms', group: 'Standard' }, + pndm: { label: 'PNDM', value: 'pndm', group: 'Standard' }, + unipc: { label: 'UniPC', value: 'unipc', group: 'Standard' }, + euler_k: { label: 'Euler Karras', value: 'euler_k', group: 'Karras' }, + dpmpp_2s_k: { + label: 'DPM++ 2S Karras', + value: 'dpmpp_2s_k', + group: 'Karras', + }, + dpmpp_2m_k: { + label: 'DPM++ 2M Karras', + value: 'dpmpp_2m_k', + group: 'Karras', + }, + heun_k: { label: 'Heun Karras', value: 'heun_k', group: 'Karras' }, + lms_k: { label: 'LMS Karras', value: 'lms_k', group: 'Karras' }, + euler_a: { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' }, + kdpm_2_a: { + label: 'KDPM 2 Ancestral', + value: 'kdpm_2_a', + group: 'Ancestral', + }, +}; export type Scheduler = (typeof SCHEDULER_NAMES)[number]; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index 321f57ca6f..e0f6f55a33 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,48 +1,47 @@ -import { SCHEDULER_NAMES, Scheduler } from 'app/constants'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { SCHEDULER_SELECT_ITEMS } from 'app/constants'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; -import { memo, useCallback, useEffect } from 'react'; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; +import { uiSelector } from 'features/ui/store/uiSelectors'; +import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + [uiSelector, generationSelector], + (ui, generation) => { + const { scheduler } = generation; + const { enabledSchedulers } = ui; + + const data = enabledSchedulers + .map( + (schedulerName) => + SCHEDULER_SELECT_ITEMS[schedulerName as SchedulerParam] + ) + .sort((a, b) => a.label.localeCompare(b.label)); + + return { + scheduler, + data, + }; + }, + defaultSelectorOptions +); + const ParamScheduler = () => { - const scheduler = useAppSelector( - (state: RootState) => state.generation.scheduler - ); - - const selectedSchedulers = useAppSelector( - (state: RootState) => state.ui.selectedSchedulers - ); - - const activeSchedulers = useAppSelector( - (state: RootState) => state.ui.activeSchedulers - ); - const dispatch = useAppDispatch(); const { t } = useTranslation(); - - useEffect(() => { - if (selectedSchedulers.length === 0) { - dispatch(setSelectedSchedulers(SCHEDULER_NAMES)); - } - - const schedulerFound = activeSchedulers.find( - (activeSchedulers) => activeSchedulers.value === scheduler - ); - - if (!schedulerFound) { - dispatch(setScheduler(activeSchedulers[0].value as Scheduler)); - } - }, [dispatch, selectedSchedulers, scheduler, activeSchedulers]); + const { scheduler, data } = useAppSelector(selector); const handleChange = useCallback( (v: string | null) => { if (!v) { return; } - dispatch(setScheduler(v as Scheduler)); + dispatch(setScheduler(v as SchedulerParam)); }, [dispatch] ); @@ -51,7 +50,7 @@ const ParamScheduler = () => { ); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 59801b8c46..63132defc0 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,6 +1,5 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { Scheduler } from 'app/constants'; import { configChanged } from 'features/system/store/configSlice'; import { clamp, sortBy } from 'lodash-es'; import { ImageDTO } from 'services/api'; @@ -18,6 +17,8 @@ import { StrengthParam, WidthParam, } from './parameterZodSchemas'; +import { enabledSchedulersChanged } from 'features/ui/store/uiSlice'; +import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; export interface GenerationState { cfgScale: CfgScaleParam; @@ -63,7 +64,7 @@ export const initialGenerationState: GenerationState = { perlin: 0, positivePrompt: '', negativePrompt: '', - scheduler: 'euler', + scheduler: DEFAULT_SCHEDULER_NAME, seamBlur: 16, seamSize: 96, seamSteps: 30, @@ -133,7 +134,7 @@ export const generationSlice = createSlice({ setWidth: (state, action: PayloadAction) => { state.width = action.payload; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { @@ -241,6 +242,21 @@ export const generationSlice = createSlice({ state.initialImage.thumbnail_url = thumbnail_url; } }); + + builder.addCase(enabledSchedulersChanged, (state, action) => { + const enabledSchedulers = action.payload; + + if (!action.payload.length) { + // This means the user cleared the enabled schedulers multi-select. We need to set the scheduler to the default + state.scheduler = DEFAULT_SCHEDULER_NAME; + return; + } + + if (!enabledSchedulers.includes(state.scheduler)) { + // The current scheduler is now disabled, change it to the first enabled one + state.scheduler = action.payload[0]; + } + }); }, }); diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index e8def6a7eb..68b75842cb 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,30 +1,39 @@ -import { SCHEDULERS } from 'app/constants'; +import { SCHEDULER_SELECT_ITEMS } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; -import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; +import { enabledSchedulersChanged } from 'features/ui/store/uiSlice'; +import { map } from 'lodash-es'; +import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +const data = map(SCHEDULER_SELECT_ITEMS).sort((a, b) => + a.label.localeCompare(b.label) +); + export default function SettingsSchedulers() { const dispatch = useAppDispatch(); - - const selectedSchedulers = useAppSelector( - (state: RootState) => state.ui.selectedSchedulers - ); - const { t } = useTranslation(); - const schedulerSettingsHandler = (v: string[]) => { - dispatch(setSelectedSchedulers(v)); - }; + const enabledSchedulers = useAppSelector( + (state: RootState) => state.ui.enabledSchedulers + ); + + const handleChange = useCallback( + (v: string[]) => { + dispatch(enabledSchedulersChanged(v as SchedulerParam[])); + }, + [dispatch] + ); return ( ) => { state.shouldShowProgressInViewer = action.payload; }, - setSelectedSchedulers: (state, action: PayloadAction) => { - const selectedSchedulerData: SelectItem[] = []; - - let selectedSchedulers = [...action.payload]; - - if (selectedSchedulers.length === 0) { - selectedSchedulers = [SCHEDULERS[0].value]; + enabledSchedulersChanged: ( + state, + action: PayloadAction + ) => { + if (action.payload.length === 0) { + state.enabledSchedulers = [DEFAULT_SCHEDULER_NAME]; + return; } - selectedSchedulers.forEach((item) => { - const schedulerData = SCHEDULERS.find( - (scheduler) => scheduler.value === item - ); - - if (schedulerData) { - selectedSchedulerData.push(schedulerData); - } - }); - - state.activeSchedulers = selectedSchedulerData; - state.selectedSchedulers = action.payload; + state.enabledSchedulers = action.payload; }, }, extraReducers(builder) { @@ -144,7 +132,7 @@ export const { toggleParametersPanel, toggleGalleryPanel, setShouldShowProgressInViewer, - setSelectedSchedulers, + enabledSchedulersChanged, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index 88ccb864e8..c73964c042 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,4 +1,4 @@ -import { SelectItem } from '@mantine/core'; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; export type AddNewModelType = 'ckpt' | 'diffusers' | null; @@ -28,6 +28,5 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; - activeSchedulers: SelectItem[]; - selectedSchedulers: string[]; + enabledSchedulers: SchedulerParam[]; } From 450641c414b8db7df16946e1c9bceeae1ab86d4b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Jun 2023 19:39:31 +1000 Subject: [PATCH 08/11] fix(ui): enable all schedulers by default --- invokeai/frontend/web/src/features/ui/store/uiSlice.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 6f535900fe..d334fb2341 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -21,7 +21,7 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, - enabledSchedulers: [], + enabledSchedulers: SCHEDULER_NAMES, }; export const uiSlice = createSlice({ From b96b95bc957513a6aca9b6188e502072d9a2b446 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Jun 2023 20:01:05 +1000 Subject: [PATCH 09/11] feat(ui): `enabledSchedulers` -> `favoriteSchedulers` --- invokeai/frontend/web/public/locales/en.json | 3 +- invokeai/frontend/web/src/app/constants.ts | 54 +++++++------------ .../Parameters/Core/ParamScheduler.tsx | 17 +++--- .../parameters/store/generationSlice.ts | 16 ------ .../SettingsModal/SettingsSchedulers.tsx | 18 ++++--- .../web/src/features/ui/store/uiSlice.ts | 14 ++--- .../web/src/features/ui/store/uiTypes.ts | 2 +- 7 files changed, 45 insertions(+), 79 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7a73bae411..eae0c07eff 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -547,7 +547,8 @@ "general": "General", "generation": "Generation", "ui": "User Interface", - "availableSchedulers": "Available Schedulers" + "favoriteSchedulers": "Favorite Schedulers", + "favoriteSchedulersPlaceholder": "No schedulers favorited" }, "toast": { "serverError": "Server Error", diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index d62526f4b3..db5fea4a66 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,4 +1,3 @@ -import { SelectItem } from '@mantine/core'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; // zod needs the array to be `as const` to infer the type correctly @@ -28,40 +27,25 @@ export const DEFAULT_SCHEDULER_NAME = 'euler'; export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST]; -export const SCHEDULER_SELECT_ITEMS: Record< - (typeof SCHEDULER_NAMES)[number], - SelectItem & { label: string; value: SchedulerParam; group: string } -> = { - euler: { label: 'Euler', value: 'euler', group: 'Standard' }, - deis: { label: 'DEIS', value: 'deis', group: 'Standard' }, - ddim: { label: 'DDIM', value: 'ddim', group: 'Standard' }, - ddpm: { label: 'DDPM', value: 'ddpm', group: 'Standard' }, - dpmpp_2s: { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' }, - dpmpp_2m: { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' }, - heun: { label: 'Heun', value: 'heun', group: 'Standard' }, - kdpm_2: { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' }, - lms: { label: 'LMS', value: 'lms', group: 'Standard' }, - pndm: { label: 'PNDM', value: 'pndm', group: 'Standard' }, - unipc: { label: 'UniPC', value: 'unipc', group: 'Standard' }, - euler_k: { label: 'Euler Karras', value: 'euler_k', group: 'Karras' }, - dpmpp_2s_k: { - label: 'DPM++ 2S Karras', - value: 'dpmpp_2s_k', - group: 'Karras', - }, - dpmpp_2m_k: { - label: 'DPM++ 2M Karras', - value: 'dpmpp_2m_k', - group: 'Karras', - }, - heun_k: { label: 'Heun Karras', value: 'heun_k', group: 'Karras' }, - lms_k: { label: 'LMS Karras', value: 'lms_k', group: 'Karras' }, - euler_a: { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' }, - kdpm_2_a: { - label: 'KDPM 2 Ancestral', - value: 'kdpm_2_a', - group: 'Ancestral', - }, +export const SCHEDULER_LABEL_MAP: Record = { + euler: 'Euler', + deis: 'DEIS', + ddim: 'DDIM', + ddpm: 'DDPM', + dpmpp_2s: 'DPM++ 2S', + dpmpp_2m: 'DPM++ 2M', + heun: 'Heun', + kdpm_2: 'KDPM 2', + lms: 'LMS', + pndm: 'PNDM', + unipc: 'UniPC', + euler_k: 'Euler Karras', + dpmpp_2s_k: 'DPM++ 2S Karras', + dpmpp_2m_k: 'DPM++ 2M Karras', + heun_k: 'Heun Karras', + lms_k: 'LMS Karras', + euler_a: 'Euler Ancestral', + kdpm_2_a: 'KDPM 2 Ancestral', }; export type Scheduler = (typeof SCHEDULER_NAMES)[number]; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index e0f6f55a33..8818dcba9b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,5 +1,5 @@ import { createSelector } from '@reduxjs/toolkit'; -import { SCHEDULER_SELECT_ITEMS } from 'app/constants'; +import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; @@ -14,14 +14,15 @@ const selector = createSelector( [uiSelector, generationSelector], (ui, generation) => { const { scheduler } = generation; - const { enabledSchedulers } = ui; + const { favoriteSchedulers: enabledSchedulers } = ui; - const data = enabledSchedulers - .map( - (schedulerName) => - SCHEDULER_SELECT_ITEMS[schedulerName as SchedulerParam] - ) - .sort((a, b) => a.label.localeCompare(b.label)); + const data = SCHEDULER_NAMES.map((schedulerName) => ({ + value: schedulerName, + label: SCHEDULER_LABEL_MAP[schedulerName as SchedulerParam], + group: enabledSchedulers.includes(schedulerName) + ? 'Favorites' + : undefined, + })).sort((a, b) => a.label.localeCompare(b.label)); return { scheduler, diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 63132defc0..961ea1b8af 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -17,7 +17,6 @@ import { StrengthParam, WidthParam, } from './parameterZodSchemas'; -import { enabledSchedulersChanged } from 'features/ui/store/uiSlice'; import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; export interface GenerationState { @@ -242,21 +241,6 @@ export const generationSlice = createSlice({ state.initialImage.thumbnail_url = thumbnail_url; } }); - - builder.addCase(enabledSchedulersChanged, (state, action) => { - const enabledSchedulers = action.payload; - - if (!action.payload.length) { - // This means the user cleared the enabled schedulers multi-select. We need to set the scheduler to the default - state.scheduler = DEFAULT_SCHEDULER_NAME; - return; - } - - if (!enabledSchedulers.includes(state.scheduler)) { - // The current scheduler is now disabled, change it to the first enabled one - state.scheduler = action.payload[0]; - } - }); }, }); diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 68b75842cb..2e0b3234c7 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,42 +1,44 @@ -import { SCHEDULER_SELECT_ITEMS } from 'app/constants'; +import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; -import { enabledSchedulersChanged } from 'features/ui/store/uiSlice'; +import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice'; import { map } from 'lodash-es'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -const data = map(SCHEDULER_SELECT_ITEMS).sort((a, b) => - a.label.localeCompare(b.label) -); +const data = map(SCHEDULER_NAMES, (s) => ({ + value: s, + label: SCHEDULER_LABEL_MAP[s], +})).sort((a, b) => a.label.localeCompare(b.label)); export default function SettingsSchedulers() { const dispatch = useAppDispatch(); const { t } = useTranslation(); const enabledSchedulers = useAppSelector( - (state: RootState) => state.ui.enabledSchedulers + (state: RootState) => state.ui.favoriteSchedulers ); const handleChange = useCallback( (v: string[]) => { - dispatch(enabledSchedulersChanged(v as SchedulerParam[])); + dispatch(favoriteSchedulersChanged(v as SchedulerParam[])); }, [dispatch] ); return ( ); } diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index d334fb2341..36c514e995 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -5,7 +5,6 @@ import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; -import { DEFAULT_SCHEDULER_NAME, SCHEDULER_NAMES } from 'app/constants'; export const initialUIState: UIState = { activeTab: 0, @@ -21,7 +20,7 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, - enabledSchedulers: SCHEDULER_NAMES, + favoriteSchedulers: [], }; export const uiSlice = createSlice({ @@ -95,16 +94,11 @@ export const uiSlice = createSlice({ setShouldShowProgressInViewer: (state, action: PayloadAction) => { state.shouldShowProgressInViewer = action.payload; }, - enabledSchedulersChanged: ( + favoriteSchedulersChanged: ( state, action: PayloadAction ) => { - if (action.payload.length === 0) { - state.enabledSchedulers = [DEFAULT_SCHEDULER_NAME]; - return; - } - - state.enabledSchedulers = action.payload; + state.favoriteSchedulers = action.payload; }, }, extraReducers(builder) { @@ -132,7 +126,7 @@ export const { toggleParametersPanel, toggleGalleryPanel, setShouldShowProgressInViewer, - enabledSchedulersChanged, + favoriteSchedulersChanged, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index c73964c042..2a9a82fbe8 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -28,5 +28,5 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; - enabledSchedulers: SchedulerParam[]; + favoriteSchedulers: SchedulerParam[]; } From a960fa009d63083de5d10558d4f12eb4a29e97e0 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 22:04:12 +1200 Subject: [PATCH 10/11] fix: Fix some styling issues with IAIMantineMultiSelect --- .../src/common/components/IAIMantineMultiSelect.tsx | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 5bf37c1aa9..646c0288fe 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -29,10 +29,16 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { '&:focus': { borderColor: 'var(--invokeai-colors-accent-600)', }, + '&:focus-within': { + borderColor: 'var(--invokeai-colors-accent-600)', + }, }, value: { backgroundColor: 'var(--invokeai-colors-base-800)', color: 'var(--invokeai-colors-base-100)', + button: { + color: 'var(--invokeai-colors-base-100)', + }, '&:hover': { backgroundColor: 'var(--invokeai-colors-base-700)', cursor: 'pointer', @@ -68,7 +74,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { }, rightSection: { width: 24, - paddingRight: 20, + padding: 20, + button: { + color: 'var(--invokeai-colors-base-100)', + }, }, })} {...rest} From 80a8d3ef28cb29f32304ab6e3afbb310244ea9e5 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 22:17:09 +1200 Subject: [PATCH 11/11] style: Theme placeholder style for IAIMantineMultiSelect --- .../web/src/common/components/IAIMantineMultiSelect.tsx | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 646c0288fe..c7ce1de4c1 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -17,6 +17,11 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { color: 'var(--invokeai-colors-base-300)', fontWeight: 'normal', }, + searchInput: { + '::placeholder': { + color: 'var(--invokeai-colors-base-700)', + }, + }, input: { backgroundColor: 'var(--invokeai-colors-base-900)', borderWidth: '2px',