feat(ui): enabledSchedulers -> favoriteSchedulers

This commit is contained in:
psychedelicious 2023-06-18 20:01:05 +10:00
parent 450641c414
commit b96b95bc95
7 changed files with 45 additions and 79 deletions

View File

@ -547,7 +547,8 @@
"general": "General", "general": "General",
"generation": "Generation", "generation": "Generation",
"ui": "User Interface", "ui": "User Interface",
"availableSchedulers": "Available Schedulers" "favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
}, },
"toast": { "toast": {
"serverError": "Server Error", "serverError": "Server Error",

View File

@ -1,4 +1,3 @@
import { SelectItem } from '@mantine/core';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
// zod needs the array to be `as const` to infer the type correctly // zod needs the array to be `as const` to infer the type correctly
@ -28,40 +27,25 @@ export const DEFAULT_SCHEDULER_NAME = 'euler';
export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST]; export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];
export const SCHEDULER_SELECT_ITEMS: Record< export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
(typeof SCHEDULER_NAMES)[number], euler: 'Euler',
SelectItem & { label: string; value: SchedulerParam; group: string } deis: 'DEIS',
> = { ddim: 'DDIM',
euler: { label: 'Euler', value: 'euler', group: 'Standard' }, ddpm: 'DDPM',
deis: { label: 'DEIS', value: 'deis', group: 'Standard' }, dpmpp_2s: 'DPM++ 2S',
ddim: { label: 'DDIM', value: 'ddim', group: 'Standard' }, dpmpp_2m: 'DPM++ 2M',
ddpm: { label: 'DDPM', value: 'ddpm', group: 'Standard' }, heun: 'Heun',
dpmpp_2s: { label: 'DPM++ 2S', value: 'dpmpp_2s', group: 'Standard' }, kdpm_2: 'KDPM 2',
dpmpp_2m: { label: 'DPM++ 2M', value: 'dpmpp_2m', group: 'Standard' }, lms: 'LMS',
heun: { label: 'Heun', value: 'heun', group: 'Standard' }, pndm: 'PNDM',
kdpm_2: { label: 'KDPM 2', value: 'kdpm_2', group: 'Standard' }, unipc: 'UniPC',
lms: { label: 'LMS', value: 'lms', group: 'Standard' }, euler_k: 'Euler Karras',
pndm: { label: 'PNDM', value: 'pndm', group: 'Standard' }, dpmpp_2s_k: 'DPM++ 2S Karras',
unipc: { label: 'UniPC', value: 'unipc', group: 'Standard' }, dpmpp_2m_k: 'DPM++ 2M Karras',
euler_k: { label: 'Euler Karras', value: 'euler_k', group: 'Karras' }, heun_k: 'Heun Karras',
dpmpp_2s_k: { lms_k: 'LMS Karras',
label: 'DPM++ 2S Karras', euler_a: 'Euler Ancestral',
value: 'dpmpp_2s_k', kdpm_2_a: 'KDPM 2 Ancestral',
group: 'Karras',
},
dpmpp_2m_k: {
label: 'DPM++ 2M Karras',
value: 'dpmpp_2m_k',
group: 'Karras',
},
heun_k: { label: 'Heun Karras', value: 'heun_k', group: 'Karras' },
lms_k: { label: 'LMS Karras', value: 'lms_k', group: 'Karras' },
euler_a: { label: 'Euler Ancestral', value: 'euler_a', group: 'Ancestral' },
kdpm_2_a: {
label: 'KDPM 2 Ancestral',
value: 'kdpm_2_a',
group: 'Ancestral',
},
}; };
export type Scheduler = (typeof SCHEDULER_NAMES)[number]; export type Scheduler = (typeof SCHEDULER_NAMES)[number];

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { SCHEDULER_SELECT_ITEMS } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
@ -14,14 +14,15 @@ const selector = createSelector(
[uiSelector, generationSelector], [uiSelector, generationSelector],
(ui, generation) => { (ui, generation) => {
const { scheduler } = generation; const { scheduler } = generation;
const { enabledSchedulers } = ui; const { favoriteSchedulers: enabledSchedulers } = ui;
const data = enabledSchedulers const data = SCHEDULER_NAMES.map((schedulerName) => ({
.map( value: schedulerName,
(schedulerName) => label: SCHEDULER_LABEL_MAP[schedulerName as SchedulerParam],
SCHEDULER_SELECT_ITEMS[schedulerName as SchedulerParam] group: enabledSchedulers.includes(schedulerName)
) ? 'Favorites'
.sort((a, b) => a.label.localeCompare(b.label)); : undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return { return {
scheduler, scheduler,

View File

@ -17,7 +17,6 @@ import {
StrengthParam, StrengthParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState { export interface GenerationState {
@ -242,21 +241,6 @@ export const generationSlice = createSlice({
state.initialImage.thumbnail_url = thumbnail_url; state.initialImage.thumbnail_url = thumbnail_url;
} }
}); });
builder.addCase(enabledSchedulersChanged, (state, action) => {
const enabledSchedulers = action.payload;
if (!action.payload.length) {
// This means the user cleared the enabled schedulers multi-select. We need to set the scheduler to the default
state.scheduler = DEFAULT_SCHEDULER_NAME;
return;
}
if (!enabledSchedulers.includes(state.scheduler)) {
// The current scheduler is now disabled, change it to the first enabled one
state.scheduler = action.payload[0];
}
});
}, },
}); });

View File

@ -1,42 +1,44 @@
import { SCHEDULER_SELECT_ITEMS } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice'; import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const data = map(SCHEDULER_SELECT_ITEMS).sort((a, b) => const data = map(SCHEDULER_NAMES, (s) => ({
a.label.localeCompare(b.label) value: s,
); label: SCHEDULER_LABEL_MAP[s],
})).sort((a, b) => a.label.localeCompare(b.label));
export default function SettingsSchedulers() { export default function SettingsSchedulers() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const enabledSchedulers = useAppSelector( const enabledSchedulers = useAppSelector(
(state: RootState) => state.ui.enabledSchedulers (state: RootState) => state.ui.favoriteSchedulers
); );
const handleChange = useCallback( const handleChange = useCallback(
(v: string[]) => { (v: string[]) => {
dispatch(enabledSchedulersChanged(v as SchedulerParam[])); dispatch(favoriteSchedulersChanged(v as SchedulerParam[]));
}, },
[dispatch] [dispatch]
); );
return ( return (
<IAIMantineMultiSelect <IAIMantineMultiSelect
label={t('settings.availableSchedulers')} label={t('settings.favoriteSchedulers')}
value={enabledSchedulers} value={enabledSchedulers}
data={data} data={data}
onChange={handleChange} onChange={handleChange}
clearable clearable
searchable searchable
maxSelectedValues={99} maxSelectedValues={99}
placeholder={t('settings.favoriteSchedulersPlaceholder')}
/> />
); );
} }

View File

@ -5,7 +5,6 @@ import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap'; import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes'; import { AddNewModelType, UIState } from './uiTypes';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { DEFAULT_SCHEDULER_NAME, SCHEDULER_NAMES } from 'app/constants';
export const initialUIState: UIState = { export const initialUIState: UIState = {
activeTab: 0, activeTab: 0,
@ -21,7 +20,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true, shouldShowGallery: true,
shouldHidePreview: false, shouldHidePreview: false,
shouldShowProgressInViewer: true, shouldShowProgressInViewer: true,
enabledSchedulers: SCHEDULER_NAMES, favoriteSchedulers: [],
}; };
export const uiSlice = createSlice({ export const uiSlice = createSlice({
@ -95,16 +94,11 @@ export const uiSlice = createSlice({
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => { setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
state.shouldShowProgressInViewer = action.payload; state.shouldShowProgressInViewer = action.payload;
}, },
enabledSchedulersChanged: ( favoriteSchedulersChanged: (
state, state,
action: PayloadAction<SchedulerParam[]> action: PayloadAction<SchedulerParam[]>
) => { ) => {
if (action.payload.length === 0) { state.favoriteSchedulers = action.payload;
state.enabledSchedulers = [DEFAULT_SCHEDULER_NAME];
return;
}
state.enabledSchedulers = action.payload;
}, },
}, },
extraReducers(builder) { extraReducers(builder) {
@ -132,7 +126,7 @@ export const {
toggleParametersPanel, toggleParametersPanel,
toggleGalleryPanel, toggleGalleryPanel,
setShouldShowProgressInViewer, setShouldShowProgressInViewer,
enabledSchedulersChanged, favoriteSchedulersChanged,
} = uiSlice.actions; } = uiSlice.actions;
export default uiSlice.reducer; export default uiSlice.reducer;

View File

@ -28,5 +28,5 @@ export interface UIState {
shouldPinGallery: boolean; shouldPinGallery: boolean;
shouldShowGallery: boolean; shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean; shouldShowProgressInViewer: boolean;
enabledSchedulers: SchedulerParam[]; favoriteSchedulers: SchedulerParam[];
} }