mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'scheduler-select' of https://github.com/blessedcoolant/InvokeAI into scheduler-select
This commit is contained in:
commit
f4ca9d0e09
@ -547,7 +547,8 @@
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
"favoriteSchedulers": "Favorite Schedulers",
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
// zod needs the array to be `as const` to infer the type correctly
|
||||
@ -28,40 +27,25 @@ export const DEFAULT_SCHEDULER_NAME = 'euler';
|
||||
|
||||
export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];
|
||||
|
||||
export const SCHEDULER_SELECT_ITEMS: Record<
|
||||
(typeof SCHEDULER_NAMES)[number],
|
||||
SelectItem & { label: string; value: SchedulerParam; group: string }
|
||||
> = {
|
||||
euler: { label: 'Euler', value: 'euler', group: 'Standard' },
|
||||
deis: { label: 'DEIS', value: 'deis', group: 'Standard' },
|
||||
ddim: { label: 'DDIM', value: 'ddim', group: 'Standard' },
|
||||
ddpm: { label: 'DDPM', value: 'ddpm', group: 'Standard' },
|
||||
dpmpp_2s: { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' },
|
||||
dpmpp_2m: { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' },
|
||||
heun: { label: 'Heun', value: 'heun', group: 'Standard' },
|
||||
kdpm_2: { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' },
|
||||
lms: { label: 'LMS', value: 'lms', group: 'Standard' },
|
||||
pndm: { label: 'PNDM', value: 'pndm', group: 'Standard' },
|
||||
unipc: { label: 'UniPC', value: 'unipc', group: 'Standard' },
|
||||
euler_k: { label: 'Euler Karras', value: 'euler_k', group: 'Karras' },
|
||||
dpmpp_2s_k: {
|
||||
label: 'DPM++ 2S Karras',
|
||||
value: 'dpmpp_2s_k',
|
||||
group: 'Karras',
|
||||
},
|
||||
dpmpp_2m_k: {
|
||||
label: 'DPM++ 2M Karras',
|
||||
value: 'dpmpp_2m_k',
|
||||
group: 'Karras',
|
||||
},
|
||||
heun_k: { label: 'Heun Karras', value: 'heun_k', group: 'Karras' },
|
||||
lms_k: { label: 'LMS Karras', value: 'lms_k', group: 'Karras' },
|
||||
euler_a: { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' },
|
||||
kdpm_2_a: {
|
||||
label: 'KDPM 2 Ancestral',
|
||||
value: 'kdpm_2_a',
|
||||
group: 'Ancestral',
|
||||
},
|
||||
export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
|
||||
euler: 'Euler',
|
||||
deis: 'DEIS',
|
||||
ddim: 'DDIM',
|
||||
ddpm: 'DDPM',
|
||||
dpmpp_2s: 'DPM++ 2S',
|
||||
dpmpp_2m: 'DPM++ 2M',
|
||||
heun: 'Heun',
|
||||
kdpm_2: 'KDPM 2',
|
||||
lms: 'LMS',
|
||||
pndm: 'PNDM',
|
||||
unipc: 'UniPC',
|
||||
euler_k: 'Euler Karras',
|
||||
dpmpp_2s_k: 'DPM++ 2S Karras',
|
||||
dpmpp_2m_k: 'DPM++ 2M Karras',
|
||||
heun_k: 'Heun Karras',
|
||||
lms_k: 'LMS Karras',
|
||||
euler_a: 'Euler Ancestral',
|
||||
kdpm_2_a: 'KDPM 2 Ancestral',
|
||||
};
|
||||
|
||||
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { SCHEDULER_SELECT_ITEMS } from 'app/constants';
|
||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
@ -14,14 +14,15 @@ const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const { scheduler } = generation;
|
||||
const { enabledSchedulers } = ui;
|
||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||
|
||||
const data = enabledSchedulers
|
||||
.map(
|
||||
(schedulerName) =>
|
||||
SCHEDULER_SELECT_ITEMS[schedulerName as SchedulerParam]
|
||||
)
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
const data = SCHEDULER_NAMES.map((schedulerName) => ({
|
||||
value: schedulerName,
|
||||
label: SCHEDULER_LABEL_MAP[schedulerName as SchedulerParam],
|
||||
group: enabledSchedulers.includes(schedulerName)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
scheduler,
|
||||
|
@ -17,7 +17,6 @@ import {
|
||||
StrengthParam,
|
||||
WidthParam,
|
||||
} from './parameterZodSchemas';
|
||||
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||
|
||||
export interface GenerationState {
|
||||
@ -242,21 +241,6 @@ export const generationSlice = createSlice({
|
||||
state.initialImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(enabledSchedulersChanged, (state, action) => {
|
||||
const enabledSchedulers = action.payload;
|
||||
|
||||
if (!action.payload.length) {
|
||||
// This means the user cleared the enabled schedulers multi-select. We need to set the scheduler to the default
|
||||
state.scheduler = DEFAULT_SCHEDULER_NAME;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!enabledSchedulers.includes(state.scheduler)) {
|
||||
// The current scheduler is now disabled, change it to the first enabled one
|
||||
state.scheduler = action.payload[0];
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -1,42 +1,44 @@
|
||||
import { SCHEDULER_SELECT_ITEMS } from 'app/constants';
|
||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||
import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const data = map(SCHEDULER_SELECT_ITEMS).sort((a, b) =>
|
||||
a.label.localeCompare(b.label)
|
||||
);
|
||||
const data = map(SCHEDULER_NAMES, (s) => ({
|
||||
value: s,
|
||||
label: SCHEDULER_LABEL_MAP[s],
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
export default function SettingsSchedulers() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const enabledSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.enabledSchedulers
|
||||
(state: RootState) => state.ui.favoriteSchedulers
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
dispatch(enabledSchedulersChanged(v as SchedulerParam[]));
|
||||
dispatch(favoriteSchedulersChanged(v as SchedulerParam[]));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineMultiSelect
|
||||
label={t('settings.availableSchedulers')}
|
||||
label={t('settings.favoriteSchedulers')}
|
||||
value={enabledSchedulers}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
clearable
|
||||
searchable
|
||||
maxSelectedValues={99}
|
||||
placeholder={t('settings.favoriteSchedulersPlaceholder')}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { DEFAULT_SCHEDULER_NAME, SCHEDULER_NAMES } from 'app/constants';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
@ -21,7 +20,7 @@ export const initialUIState: UIState = {
|
||||
shouldShowGallery: true,
|
||||
shouldHidePreview: false,
|
||||
shouldShowProgressInViewer: true,
|
||||
enabledSchedulers: SCHEDULER_NAMES,
|
||||
favoriteSchedulers: [],
|
||||
};
|
||||
|
||||
export const uiSlice = createSlice({
|
||||
@ -95,16 +94,11 @@ export const uiSlice = createSlice({
|
||||
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowProgressInViewer = action.payload;
|
||||
},
|
||||
enabledSchedulersChanged: (
|
||||
favoriteSchedulersChanged: (
|
||||
state,
|
||||
action: PayloadAction<SchedulerParam[]>
|
||||
) => {
|
||||
if (action.payload.length === 0) {
|
||||
state.enabledSchedulers = [DEFAULT_SCHEDULER_NAME];
|
||||
return;
|
||||
}
|
||||
|
||||
state.enabledSchedulers = action.payload;
|
||||
state.favoriteSchedulers = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
@ -132,7 +126,7 @@ export const {
|
||||
toggleParametersPanel,
|
||||
toggleGalleryPanel,
|
||||
setShouldShowProgressInViewer,
|
||||
enabledSchedulersChanged,
|
||||
favoriteSchedulersChanged,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
@ -28,5 +28,5 @@ export interface UIState {
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
shouldShowProgressInViewer: boolean;
|
||||
enabledSchedulers: SchedulerParam[];
|
||||
favoriteSchedulers: SchedulerParam[];
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user