mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): improve scheduler selection logic
- remove UI-specific state (the enabled schedulers) from redux, instead derive it in a selector - simplify logic by putting schedulers in an object instead of an array - rename `activeSchedulers` to `enabledSchedulers` - remove need for `useEffect()` when `enabledSchedulers` changes by adding a listener for the `enabledSchedulersChanged` action/event to `generationSlice` - increase type safety by making `enabledSchedulers` an array of `SchedulerParam`, which is created by the zod schema for scheduler
This commit is contained in:
@ -1,48 +1,47 @@
|
||||
import { SCHEDULER_NAMES, Scheduler } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { SCHEDULER_SELECT_ITEMS } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const { scheduler } = generation;
|
||||
const { enabledSchedulers } = ui;
|
||||
|
||||
const data = enabledSchedulers
|
||||
.map(
|
||||
(schedulerName) =>
|
||||
SCHEDULER_SELECT_ITEMS[schedulerName as SchedulerParam]
|
||||
)
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
scheduler,
|
||||
data,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamScheduler = () => {
|
||||
const scheduler = useAppSelector(
|
||||
(state: RootState) => state.generation.scheduler
|
||||
);
|
||||
|
||||
const selectedSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.selectedSchedulers
|
||||
);
|
||||
|
||||
const activeSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.activeSchedulers
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedSchedulers.length === 0) {
|
||||
dispatch(setSelectedSchedulers(SCHEDULER_NAMES));
|
||||
}
|
||||
|
||||
const schedulerFound = activeSchedulers.find(
|
||||
(activeSchedulers) => activeSchedulers.value === scheduler
|
||||
);
|
||||
|
||||
if (!schedulerFound) {
|
||||
dispatch(setScheduler(activeSchedulers[0].value as Scheduler));
|
||||
}
|
||||
}, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
|
||||
const { scheduler, data } = useAppSelector(selector);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setScheduler(v as Scheduler));
|
||||
dispatch(setScheduler(v as SchedulerParam));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -51,7 +50,7 @@ const ParamScheduler = () => {
|
||||
<IAIMantineSelect
|
||||
label={t('parameters.scheduler')}
|
||||
value={scheduler}
|
||||
data={activeSchedulers}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
|
@ -1,6 +1,5 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
import { ImageDTO } from 'services/api';
|
||||
@ -18,6 +17,8 @@ import {
|
||||
StrengthParam,
|
||||
WidthParam,
|
||||
} from './parameterZodSchemas';
|
||||
import { enabledSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||
|
||||
export interface GenerationState {
|
||||
cfgScale: CfgScaleParam;
|
||||
@ -63,7 +64,7 @@ export const initialGenerationState: GenerationState = {
|
||||
perlin: 0,
|
||||
positivePrompt: '',
|
||||
negativePrompt: '',
|
||||
scheduler: 'euler',
|
||||
scheduler: DEFAULT_SCHEDULER_NAME,
|
||||
seamBlur: 16,
|
||||
seamSize: 96,
|
||||
seamSteps: 30,
|
||||
@ -133,7 +134,7 @@ export const generationSlice = createSlice({
|
||||
setWidth: (state, action: PayloadAction<number>) => {
|
||||
state.width = action.payload;
|
||||
},
|
||||
setScheduler: (state, action: PayloadAction<Scheduler>) => {
|
||||
setScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||
state.scheduler = action.payload;
|
||||
},
|
||||
setSeed: (state, action: PayloadAction<number>) => {
|
||||
@ -241,6 +242,21 @@ export const generationSlice = createSlice({
|
||||
state.initialImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(enabledSchedulersChanged, (state, action) => {
|
||||
const enabledSchedulers = action.payload;
|
||||
|
||||
if (!action.payload.length) {
|
||||
// This means the user cleared the enabled schedulers multi-select. We need to set the scheduler to the default
|
||||
state.scheduler = DEFAULT_SCHEDULER_NAME;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!enabledSchedulers.includes(state.scheduler)) {
|
||||
// The current scheduler is now disabled, change it to the first enabled one
|
||||
state.scheduler = action.payload[0];
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
Reference in New Issue
Block a user