mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): improve scheduler selection logic
- remove UI-specific state (the enabled schedulers) from redux, instead derive it in a selector - simplify logic by putting schedulers in an object instead of an array - rename `activeSchedulers` to `enabledSchedulers` - remove need for `useEffect()` when `enabledSchedulers` changes by adding a listener for the `enabledSchedulersChanged` action/event to `generationSlice` - increase type safety by making `enabledSchedulers` an array of `SchedulerParam`, which is created by the zod schema for scheduler
This commit is contained in:
parent
150059f704
commit
94cfcdc411
@ -1,27 +1,8 @@
|
|||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
export const SCHEDULERS: SelectItem[] = [
|
|
||||||
{ label: 'Euler', value: 'euler', group: 'Standard' },
|
|
||||||
{ label: 'DEIS', value: 'deis', group: 'Standard' },
|
|
||||||
{ label: 'DDIM', value: 'ddim', group: 'Standard' },
|
|
||||||
{ label: 'DDPM', value: 'ddpm', group: 'Standard' },
|
|
||||||
{ label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' },
|
|
||||||
{ label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' },
|
|
||||||
{ label: 'Heun', value: 'heun', group: 'Standard' },
|
|
||||||
{ label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' },
|
|
||||||
{ label: 'LMS', value: 'lms', group: 'Standard' },
|
|
||||||
{ label: 'PNDM', value: 'pndm', group: 'Standard' },
|
|
||||||
{ label: 'UniPC', value: 'unipc', group: 'Standard' },
|
|
||||||
{ label: 'Euler Karras', value: 'euler_k', group: 'Karras' },
|
|
||||||
{ label: 'DPM++ 2S Karras', value: 'dpmpp_2s_k', group: 'Karras' },
|
|
||||||
{ label: 'DPM++ 2M Karras', value: 'dpmpp_2m_k', group: 'Karras' },
|
|
||||||
{ label: 'Heun Karras', value: 'heun_k', group: 'Karras' },
|
|
||||||
{ label: 'LMS Karras', value: 'lms_k', group: 'Karras' },
|
|
||||||
{ label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' },
|
|
||||||
{ label: 'KDPM 2 Ancestral', value: 'kdpm_2_a', group: 'Ancestral' },
|
|
||||||
];
|
|
||||||
|
|
||||||
// zod needs the array to be `as const` to infer the type correctly
|
// zod needs the array to be `as const` to infer the type correctly
|
||||||
|
// this is the source of the `SchedulerParam` type, which is generated by zod
|
||||||
export const SCHEDULER_NAMES_AS_CONST = [
|
export const SCHEDULER_NAMES_AS_CONST = [
|
||||||
'euler',
|
'euler',
|
||||||
'deis',
|
'deis',
|
||||||
@ -43,7 +24,45 @@ export const SCHEDULER_NAMES_AS_CONST = [
|
|||||||
'kdpm_2_a',
|
'kdpm_2_a',
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
export const SCHEDULER_NAMES = [...SCHEDULER_NAMES_AS_CONST];
|
export const DEFAULT_SCHEDULER_NAME = 'euler';
|
||||||
|
|
||||||
|
export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];
|
||||||
|
|
||||||
|
export const SCHEDULER_SELECT_ITEMS: Record<
|
||||||
|
(typeof SCHEDULER_NAMES)[number],
|
||||||
|
SelectItem & { label: string; value: SchedulerParam; group: string }
|
||||||
|
> = {
|
||||||
|
euler: { label: 'Euler', value: 'euler', group: 'Standard' },
|
||||||
|
deis: { label: 'DEIS', value: 'deis', group: 'Standard' },
|
||||||
|
ddim: { label: 'DDIM', value: 'ddim', group: 'Standard' },
|
||||||
|
ddpm: { label: 'DDPM', value: 'ddpm', group: 'Standard' },
|
||||||
|
dpmpp_2s: { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' },
|
||||||
|
dpmpp_2m: { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' },
|
||||||
|
heun: { label: 'Heun', value: 'heun', group: 'Standard' },
|
||||||
|
kdpm_2: { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' },
|
||||||
|
lms: { label: 'LMS', value: 'lms', group: 'Standard' },
|
||||||
|
pndm: { label: 'PNDM', value: 'pndm', group: 'Standard' },
|
||||||
|
unipc: { label: 'UniPC', value: 'unipc', group: 'Standard' },
|
||||||
|
euler_k: { label: 'Euler Karras', value: 'euler_k', group: 'Karras' },
|
||||||
|
dpmpp_2s_k: {
|
||||||
|
label: 'DPM++ 2S Karras',
|
||||||
|
value: 'dpmpp_2s_k',
|
||||||
|
group: 'Karras',
|
||||||
|
},
|
||||||
|
dpmpp_2m_k: {
|
||||||
|
label: 'DPM++ 2M Karras',
|
||||||
|
value: 'dpmpp_2m_k',
|
||||||
|
group: 'Karras',
|
||||||
|
},
|
||||||
|
heun_k: { label: 'Heun Karras', value: 'heun_k', group: 'Karras' },
|
||||||
|
lms_k: { label: 'LMS Karras', value: 'lms_k', group: 'Karras' },
|
||||||
|
euler_a: { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' },
|
||||||
|
kdpm_2_a: {
|
||||||
|
label: 'KDPM 2 Ancestral',
|
||||||
|
value: 'kdpm_2_a',
|
||||||
|
group: 'Ancestral',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
|
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
|
||||||
|
|
||||||
|
@ -1,48 +1,47 @@
|
|||||||
import { SCHEDULER_NAMES, Scheduler } from 'app/constants';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { SCHEDULER_SELECT_ITEMS } from 'app/constants';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||||
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
import { memo, useCallback, useEffect } from 'react';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[uiSelector, generationSelector],
|
||||||
|
(ui, generation) => {
|
||||||
|
const { scheduler } = generation;
|
||||||
|
const { enabledSchedulers } = ui;
|
||||||
|
|
||||||
|
const data = enabledSchedulers
|
||||||
|
.map(
|
||||||
|
(schedulerName) =>
|
||||||
|
SCHEDULER_SELECT_ITEMS[schedulerName as SchedulerParam]
|
||||||
|
)
|
||||||
|
.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
|
return {
|
||||||
|
scheduler,
|
||||||
|
data,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamScheduler = () => {
|
const ParamScheduler = () => {
|
||||||
const scheduler = useAppSelector(
|
|
||||||
(state: RootState) => state.generation.scheduler
|
|
||||||
);
|
|
||||||
|
|
||||||
const selectedSchedulers = useAppSelector(
|
|
||||||
(state: RootState) => state.ui.selectedSchedulers
|
|
||||||
);
|
|
||||||
|
|
||||||
const activeSchedulers = useAppSelector(
|
|
||||||
(state: RootState) => state.ui.activeSchedulers
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { scheduler, data } = useAppSelector(selector);
|
||||||
useEffect(() => {
|
|
||||||
if (selectedSchedulers.length === 0) {
|
|
||||||
dispatch(setSelectedSchedulers(SCHEDULER_NAMES));
|
|
||||||
}
|
|
||||||
|
|
||||||
const schedulerFound = activeSchedulers.find(
|
|
||||||
(activeSchedulers) => activeSchedulers.value === scheduler
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!schedulerFound) {
|
|
||||||
dispatch(setScheduler(activeSchedulers[0].value as Scheduler));
|
|
||||||
}
|
|
||||||
}, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
|
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(setScheduler(v as Scheduler));
|
dispatch(setScheduler(v as SchedulerParam));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -51,7 +50,7 @@ const ParamScheduler = () => {
|
|||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
label={t('parameters.scheduler')}
|
label={t('parameters.scheduler')}
|
||||||
value={scheduler}
|
value={scheduler}
|
||||||
data={activeSchedulers}
|
data={data}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { Scheduler } from 'app/constants';
|
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp, sortBy } from 'lodash-es';
|
import { clamp, sortBy } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
@ -18,6 +17,8 @@ import {
|
|||||||
StrengthParam,
|
StrengthParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
|
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||||
|
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: CfgScaleParam;
|
cfgScale: CfgScaleParam;
|
||||||
@ -63,7 +64,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
perlin: 0,
|
perlin: 0,
|
||||||
positivePrompt: '',
|
positivePrompt: '',
|
||||||
negativePrompt: '',
|
negativePrompt: '',
|
||||||
scheduler: 'euler',
|
scheduler: DEFAULT_SCHEDULER_NAME,
|
||||||
seamBlur: 16,
|
seamBlur: 16,
|
||||||
seamSize: 96,
|
seamSize: 96,
|
||||||
seamSteps: 30,
|
seamSteps: 30,
|
||||||
@ -133,7 +134,7 @@ export const generationSlice = createSlice({
|
|||||||
setWidth: (state, action: PayloadAction<number>) => {
|
setWidth: (state, action: PayloadAction<number>) => {
|
||||||
state.width = action.payload;
|
state.width = action.payload;
|
||||||
},
|
},
|
||||||
setScheduler: (state, action: PayloadAction<Scheduler>) => {
|
setScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||||
state.scheduler = action.payload;
|
state.scheduler = action.payload;
|
||||||
},
|
},
|
||||||
setSeed: (state, action: PayloadAction<number>) => {
|
setSeed: (state, action: PayloadAction<number>) => {
|
||||||
@ -241,6 +242,21 @@ export const generationSlice = createSlice({
|
|||||||
state.initialImage.thumbnail_url = thumbnail_url;
|
state.initialImage.thumbnail_url = thumbnail_url;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
builder.addCase(enabledSchedulersChanged, (state, action) => {
|
||||||
|
const enabledSchedulers = action.payload;
|
||||||
|
|
||||||
|
if (!action.payload.length) {
|
||||||
|
// This means the user cleared the enabled schedulers multi-select. We need to set the scheduler to the default
|
||||||
|
state.scheduler = DEFAULT_SCHEDULER_NAME;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!enabledSchedulers.includes(state.scheduler)) {
|
||||||
|
// The current scheduler is now disabled, change it to the first enabled one
|
||||||
|
state.scheduler = action.payload[0];
|
||||||
|
}
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,30 +1,39 @@
|
|||||||
import { SCHEDULERS } from 'app/constants';
|
import { SCHEDULER_SELECT_ITEMS } from 'app/constants';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
|
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const data = map(SCHEDULER_SELECT_ITEMS).sort((a, b) =>
|
||||||
|
a.label.localeCompare(b.label)
|
||||||
|
);
|
||||||
|
|
||||||
export default function SettingsSchedulers() {
|
export default function SettingsSchedulers() {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selectedSchedulers = useAppSelector(
|
|
||||||
(state: RootState) => state.ui.selectedSchedulers
|
|
||||||
);
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const schedulerSettingsHandler = (v: string[]) => {
|
const enabledSchedulers = useAppSelector(
|
||||||
dispatch(setSelectedSchedulers(v));
|
(state: RootState) => state.ui.enabledSchedulers
|
||||||
};
|
);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: string[]) => {
|
||||||
|
dispatch(enabledSchedulersChanged(v as SchedulerParam[]));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineMultiSelect
|
<IAIMantineMultiSelect
|
||||||
label={t('settings.availableSchedulers')}
|
label={t('settings.availableSchedulers')}
|
||||||
value={selectedSchedulers}
|
value={enabledSchedulers}
|
||||||
data={SCHEDULERS}
|
data={data}
|
||||||
onChange={schedulerSettingsHandler}
|
onChange={handleChange}
|
||||||
clearable
|
clearable
|
||||||
searchable
|
searchable
|
||||||
maxSelectedValues={99}
|
maxSelectedValues={99}
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import { SelectItem } from '@mantine/core';
|
|
||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { SCHEDULERS } from 'app/constants';
|
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { setActiveTabReducer } from './extraReducers';
|
import { setActiveTabReducer } from './extraReducers';
|
||||||
import { InvokeTabName } from './tabMap';
|
import { InvokeTabName } from './tabMap';
|
||||||
import { AddNewModelType, UIState } from './uiTypes';
|
import { AddNewModelType, UIState } from './uiTypes';
|
||||||
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
|
import { DEFAULT_SCHEDULER_NAME, SCHEDULER_NAMES } from 'app/constants';
|
||||||
|
|
||||||
export const initialUIState: UIState = {
|
export const initialUIState: UIState = {
|
||||||
activeTab: 0,
|
activeTab: 0,
|
||||||
@ -21,8 +21,7 @@ export const initialUIState: UIState = {
|
|||||||
shouldShowGallery: true,
|
shouldShowGallery: true,
|
||||||
shouldHidePreview: false,
|
shouldHidePreview: false,
|
||||||
shouldShowProgressInViewer: true,
|
shouldShowProgressInViewer: true,
|
||||||
activeSchedulers: [],
|
enabledSchedulers: [],
|
||||||
selectedSchedulers: [],
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const uiSlice = createSlice({
|
export const uiSlice = createSlice({
|
||||||
@ -96,27 +95,16 @@ export const uiSlice = createSlice({
|
|||||||
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
|
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldShowProgressInViewer = action.payload;
|
state.shouldShowProgressInViewer = action.payload;
|
||||||
},
|
},
|
||||||
setSelectedSchedulers: (state, action: PayloadAction<string[]>) => {
|
enabledSchedulersChanged: (
|
||||||
const selectedSchedulerData: SelectItem[] = [];
|
state,
|
||||||
|
action: PayloadAction<SchedulerParam[]>
|
||||||
let selectedSchedulers = [...action.payload];
|
) => {
|
||||||
|
if (action.payload.length === 0) {
|
||||||
if (selectedSchedulers.length === 0) {
|
state.enabledSchedulers = [DEFAULT_SCHEDULER_NAME];
|
||||||
selectedSchedulers = [SCHEDULERS[0].value];
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
selectedSchedulers.forEach((item) => {
|
state.enabledSchedulers = action.payload;
|
||||||
const schedulerData = SCHEDULERS.find(
|
|
||||||
(scheduler) => scheduler.value === item
|
|
||||||
);
|
|
||||||
|
|
||||||
if (schedulerData) {
|
|
||||||
selectedSchedulerData.push(schedulerData);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
state.activeSchedulers = selectedSchedulerData;
|
|
||||||
state.selectedSchedulers = action.payload;
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers(builder) {
|
||||||
@ -144,7 +132,7 @@ export const {
|
|||||||
toggleParametersPanel,
|
toggleParametersPanel,
|
||||||
toggleGalleryPanel,
|
toggleGalleryPanel,
|
||||||
setShouldShowProgressInViewer,
|
setShouldShowProgressInViewer,
|
||||||
setSelectedSchedulers,
|
enabledSchedulersChanged,
|
||||||
} = uiSlice.actions;
|
} = uiSlice.actions;
|
||||||
|
|
||||||
export default uiSlice.reducer;
|
export default uiSlice.reducer;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { SelectItem } from '@mantine/core';
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
|
|
||||||
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
||||||
|
|
||||||
@ -28,6 +28,5 @@ export interface UIState {
|
|||||||
shouldPinGallery: boolean;
|
shouldPinGallery: boolean;
|
||||||
shouldShowGallery: boolean;
|
shouldShowGallery: boolean;
|
||||||
shouldShowProgressInViewer: boolean;
|
shouldShowProgressInViewer: boolean;
|
||||||
activeSchedulers: SelectItem[];
|
enabledSchedulers: SchedulerParam[];
|
||||||
selectedSchedulers: string[];
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user