feat: Port Schedulers to Mantine (#3552)

- Ports Schedulers to use IAIMantineSelect.
- Adds ability to favorite schedulers in Settings. Favorited schedulers
show up at the top of the list.
- Adds IAIMantineMultiSelect component.
- Change SettingsSchedulers component to use IAIMantineMultiSelect
instead of Chakra Menus.
This commit is contained in:
blessedcoolant 2023-06-18 22:22:03 +12:00 committed by GitHub
commit a11946f0ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 212 additions and 84 deletions

View File

@ -547,7 +547,8 @@
"general": "General", "general": "General",
"generation": "Generation", "generation": "Generation",
"ui": "User Interface", "ui": "User Interface",
"availableSchedulers": "Available Schedulers" "favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
}, },
"toast": { "toast": {
"serverError": "Server Error", "serverError": "Server Error",

View File

@ -1,27 +1,54 @@
// TODO: use Enums? import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
export const SCHEDULERS = [ // zod needs the array to be `as const` to infer the type correctly
'ddim', // this is the source of the `SchedulerParam` type, which is generated by zod
'lms', export const SCHEDULER_NAMES_AS_CONST = [
'lms_k',
'euler', 'euler',
'euler_k',
'euler_a',
'dpmpp_2s',
'dpmpp_2s_k',
'dpmpp_2m',
'dpmpp_2m_k',
'kdpm_2',
'kdpm_2_a',
'deis', 'deis',
'ddim',
'ddpm', 'ddpm',
'pndm', 'dpmpp_2s',
'dpmpp_2m',
'heun', 'heun',
'heun_k', 'kdpm_2',
'lms',
'pndm',
'unipc', 'unipc',
'euler_k',
'dpmpp_2s_k',
'dpmpp_2m_k',
'heun_k',
'lms_k',
'euler_a',
'kdpm_2_a',
] as const; ] as const;
export type Scheduler = (typeof SCHEDULERS)[number]; export const DEFAULT_SCHEDULER_NAME = 'euler';
export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];
export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
euler: 'Euler',
deis: 'DEIS',
ddim: 'DDIM',
ddpm: 'DDPM',
dpmpp_2s: 'DPM++ 2S',
dpmpp_2m: 'DPM++ 2M',
heun: 'Heun',
kdpm_2: 'KDPM 2',
lms: 'LMS',
pndm: 'PNDM',
unipc: 'UniPC',
euler_k: 'Euler Karras',
dpmpp_2s_k: 'DPM++ 2S Karras',
dpmpp_2m_k: 'DPM++ 2M Karras',
heun_k: 'Heun Karras',
lms_k: 'LMS Karras',
euler_a: 'Euler Ancestral',
kdpm_2_a: 'KDPM 2 Ancestral',
};
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
// Valid upscaling levels // Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [

View File

@ -0,0 +1,94 @@
import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { memo } from 'react';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect
searchable={searchable}
styles={() => ({
label: {
color: 'var(--invokeai-colors-base-300)',
fontWeight: 'normal',
},
searchInput: {
'::placeholder': {
color: 'var(--invokeai-colors-base-700)',
},
},
input: {
backgroundColor: 'var(--invokeai-colors-base-900)',
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
'&:focus-within': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
},
value: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
button: {
color: 'var(--invokeai-colors-base-100)',
},
'&:hover': {
backgroundColor: 'var(--invokeai-colors-base-700)',
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
borderColor: 'var(--invokeai-colors-base-700)',
},
item: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-200)',
padding: 6,
'&[data-hovered]': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
'&[data-active]': {
backgroundColor: 'var(--invokeai-colors-base-750)',
'&:hover': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
},
'&[data-selected]': {
color: 'var(--invokeai-colors-base-50)',
backgroundColor: 'var(--invokeai-colors-accent-650)',
fontWeight: 600,
'&:hover': {
backgroundColor: 'var(--invokeai-colors-accent-600)',
},
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: 'var(--invokeai-colors-base-100)',
},
},
})}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineMultiSelect);

View File

@ -1,12 +1,11 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect, { import IAIMantineSelect from 'common/components/IAIMantineSelect';
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -14,30 +13,36 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
[uiSelector, generationSelector], [uiSelector, generationSelector],
(ui, generation) => { (ui, generation) => {
const allSchedulers: string[] = ui.schedulers const { scheduler } = generation;
.slice() const { favoriteSchedulers: enabledSchedulers } = ui;
.sort((a, b) => a.localeCompare(b));
const data = SCHEDULER_NAMES.map((schedulerName) => ({
value: schedulerName,
label: SCHEDULER_LABEL_MAP[schedulerName as SchedulerParam],
group: enabledSchedulers.includes(schedulerName)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return { return {
scheduler: generation.scheduler, scheduler,
allSchedulers, data,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamScheduler = () => { const ParamScheduler = () => {
const { allSchedulers, scheduler } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { scheduler, data } = useAppSelector(selector);
const handleChange = useCallback( const handleChange = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
return; return;
} }
dispatch(setScheduler(v as Scheduler)); dispatch(setScheduler(v as SchedulerParam));
}, },
[dispatch] [dispatch]
); );
@ -46,7 +51,7 @@ const ParamScheduler = () => {
<IAIMantineSelect <IAIMantineSelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
value={scheduler} value={scheduler}
data={allSchedulers} data={data}
onChange={handleChange} onChange={handleChange}
/> />
); );

View File

@ -1,12 +1,12 @@
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ModelSelect from 'features/system/components/ModelSelect'; import ModelSelect from 'features/system/components/ModelSelect';
import { memo } from 'react';
import ParamScheduler from './ParamScheduler'; import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => { const ParamSchedulerAndModel = () => {
return ( return (
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box w="16rem"> <Box w="20rem">
<ParamScheduler /> <ParamScheduler />
</Box> </Box>
<Box w="full"> <Box w="full">

View File

@ -1,10 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { clamp, sortBy } from 'lodash-es';
import { receivedModels } from 'services/thunks/model';
import { Scheduler } from 'app/constants';
import { ImageDTO } from 'services/api';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es';
import { ImageDTO } from 'services/api';
import { imageUrlsReceived } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
@ -17,7 +17,7 @@ import {
StrengthParam, StrengthParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
import { imageUrlsReceived } from 'services/thunks/image'; import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState { export interface GenerationState {
cfgScale: CfgScaleParam; cfgScale: CfgScaleParam;
@ -63,7 +63,7 @@ export const initialGenerationState: GenerationState = {
perlin: 0, perlin: 0,
positivePrompt: '', positivePrompt: '',
negativePrompt: '', negativePrompt: '',
scheduler: 'euler', scheduler: DEFAULT_SCHEDULER_NAME,
seamBlur: 16, seamBlur: 16,
seamSize: 96, seamSize: 96,
seamSteps: 30, seamSteps: 30,
@ -133,7 +133,7 @@ export const generationSlice = createSlice({
setWidth: (state, action: PayloadAction<number>) => { setWidth: (state, action: PayloadAction<number>) => {
state.width = action.payload; state.width = action.payload;
}, },
setScheduler: (state, action: PayloadAction<Scheduler>) => { setScheduler: (state, action: PayloadAction<SchedulerParam>) => {
state.scheduler = action.payload; state.scheduler = action.payload;
}, },
setSeed: (state, action: PayloadAction<number>) => { setSeed: (state, action: PayloadAction<number>) => {

View File

@ -1,4 +1,4 @@
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants'; import { NUMPY_RAND_MAX, SCHEDULER_NAMES_AS_CONST } from 'app/constants';
import { z } from 'zod'; import { z } from 'zod';
/** /**
@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
/** /**
* Zod schema for scheduler parameter * Zod schema for scheduler parameter
*/ */
export const zScheduler = z.enum(SCHEDULERS); export const zScheduler = z.enum(SCHEDULER_NAMES_AS_CONST);
/** /**
* Type alias for scheduler parameter, inferred from its zod schema * Type alias for scheduler parameter, inferred from its zod schema
*/ */

View File

@ -1,47 +1,44 @@
import { import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
Menu,
MenuButton,
MenuItemOption,
MenuList,
MenuOptionGroup,
} from '@chakra-ui/react';
import { SCHEDULERS } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { setSchedulers } from 'features/ui/store/uiSlice'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { isArray } from 'lodash-es'; import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
import { map } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
export default function SettingsSchedulers() { const data = map(SCHEDULER_NAMES, (s) => ({
const schedulers = useAppSelector((state: RootState) => state.ui.schedulers); value: s,
label: SCHEDULER_LABEL_MAP[s],
})).sort((a, b) => a.label.localeCompare(b.label));
export default function SettingsSchedulers() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const schedulerSettingsHandler = (v: string | string[]) => { const enabledSchedulers = useAppSelector(
if (isArray(v)) dispatch(setSchedulers(v.sort())); (state: RootState) => state.ui.favoriteSchedulers
}; );
const handleChange = useCallback(
(v: string[]) => {
dispatch(favoriteSchedulersChanged(v as SchedulerParam[]));
},
[dispatch]
);
return ( return (
<Menu closeOnSelect={false}> <IAIMantineMultiSelect
<MenuButton as={IAIButton}> label={t('settings.favoriteSchedulers')}
{t('settings.availableSchedulers')} value={enabledSchedulers}
</MenuButton> data={data}
<MenuList maxHeight={64} overflowY="scroll"> onChange={handleChange}
<MenuOptionGroup clearable
value={schedulers} searchable
type="checkbox" maxSelectedValues={99}
onChange={schedulerSettingsHandler} placeholder={t('settings.favoriteSchedulersPlaceholder')}
> />
{SCHEDULERS.map((scheduler) => (
<MenuItemOption key={scheduler} value={scheduler}>
{scheduler}
</MenuItemOption>
))}
</MenuOptionGroup>
</MenuList>
</Menu>
); );
} }

View File

@ -1,10 +1,10 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { setActiveTabReducer } from './extraReducers'; import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap'; import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes'; import { AddNewModelType, UIState } from './uiTypes';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
import { SCHEDULERS } from 'app/constants';
export const initialUIState: UIState = { export const initialUIState: UIState = {
activeTab: 0, activeTab: 0,
@ -20,7 +20,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true, shouldShowGallery: true,
shouldHidePreview: false, shouldHidePreview: false,
shouldShowProgressInViewer: true, shouldShowProgressInViewer: true,
schedulers: SCHEDULERS, favoriteSchedulers: [],
}; };
export const uiSlice = createSlice({ export const uiSlice = createSlice({
@ -94,9 +94,11 @@ export const uiSlice = createSlice({
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => { setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
state.shouldShowProgressInViewer = action.payload; state.shouldShowProgressInViewer = action.payload;
}, },
setSchedulers: (state, action: PayloadAction<string[]>) => { favoriteSchedulersChanged: (
state.schedulers = []; state,
state.schedulers = action.payload; action: PayloadAction<SchedulerParam[]>
) => {
state.favoriteSchedulers = action.payload;
}, },
}, },
extraReducers(builder) { extraReducers(builder) {
@ -124,7 +126,7 @@ export const {
toggleParametersPanel, toggleParametersPanel,
toggleGalleryPanel, toggleGalleryPanel,
setShouldShowProgressInViewer, setShouldShowProgressInViewer,
setSchedulers, favoriteSchedulersChanged,
} = uiSlice.actions; } = uiSlice.actions;
export default uiSlice.reducer; export default uiSlice.reducer;

View File

@ -1,3 +1,5 @@
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
export type AddNewModelType = 'ckpt' | 'diffusers' | null; export type AddNewModelType = 'ckpt' | 'diffusers' | null;
export type Coordinates = { export type Coordinates = {
@ -26,5 +28,5 @@ export interface UIState {
shouldPinGallery: boolean; shouldPinGallery: boolean;
shouldShowGallery: boolean; shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean; shouldShowProgressInViewer: boolean;
schedulers: string[]; favoriteSchedulers: SchedulerParam[];
} }