Revert "feat: Port Schedulers to Mantine"

This reverts commit e0c105f413.
This commit is contained in:
blessedcoolant 2023-06-18 22:22:56 +12:00
parent 809ec7163e
commit 9fda21cf40
8 changed files with 75 additions and 178 deletions

View File

@ -1,28 +1,6 @@
import { SelectItem } from '@mantine/core';
// TODO: use Enums? // TODO: use Enums?
export const SCHEDULERS: SelectItem[] = [
{ label: 'euler', value: 'euler', group: 'Standard' },
{ label: 'deis', value: 'deis', group: 'Standard' },
{ label: 'ddim', value: 'ddim', group: 'Standard' },
{ label: 'ddpm', value: 'ddpm', group: 'Standard' },
{ label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' },
{ label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' },
{ label: 'heun', value: 'heun', group: 'Standard' },
{ label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' },
{ label: 'lms', value: 'lms', group: 'Standard' },
{ label: 'pndm', value: 'pndm', group: 'Standard' },
{ label: 'unipc', value: 'unipc', group: 'Standard' },
{ label: 'euler_k', value: 'euler_k', group: 'Karras' },
{ label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' },
{ label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' },
{ label: 'heun_k', value: 'heun_k', group: 'Karras' },
{ label: 'lms_k', value: 'lms_k', group: 'Karras' },
{ label: 'euler_a', value: 'euler_a', group: 'Ancestral' },
{ label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' },
];
export const SCHEDULER_ITEMS = [ export const SCHEDULERS = [
'ddim', 'ddim',
'lms', 'lms',
'lms_k', 'lms_k',
@ -41,9 +19,9 @@ export const SCHEDULER_ITEMS = [
'heun', 'heun',
'heun_k', 'heun_k',
'unipc', 'unipc',
]; ] as const;
export type Scheduler = typeof SCHEDULERS; export type Scheduler = (typeof SCHEDULERS)[number];
// Valid upscaling levels // Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [

View File

@ -1,79 +0,0 @@
import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { memo } from 'react';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect
searchable={searchable}
styles={() => ({
label: {
color: 'var(--invokeai-colors-base-300)',
fontWeight: 'normal',
},
input: {
backgroundColor: 'var(--invokeai-colors-base-900)',
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
},
value: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
'&:hover': {
backgroundColor: 'var(--invokeai-colors-base-700)',
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
borderColor: 'var(--invokeai-colors-base-700)',
},
item: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-200)',
padding: 6,
'&[data-hovered]': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
'&[data-active]': {
backgroundColor: 'var(--invokeai-colors-base-750)',
'&:hover': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
},
'&[data-selected]': {
color: 'var(--invokeai-colors-base-50)',
backgroundColor: 'var(--invokeai-colors-accent-650)',
fontWeight: 600,
'&:hover': {
backgroundColor: 'var(--invokeai-colors-accent-600)',
},
},
},
rightSection: {
width: 24,
},
})}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineMultiSelect);

View File

@ -1,44 +1,43 @@
import { SCHEDULER_ITEMS } from 'app/constants'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
const allSchedulers: string[] = ui.schedulers
.slice()
.sort((a, b) => a.localeCompare(b));
return {
scheduler: generation.scheduler,
allSchedulers,
};
},
defaultSelectorOptions
);
const ParamScheduler = () => { const ParamScheduler = () => {
const scheduler = useAppSelector( const { allSchedulers, scheduler } = useAppSelector(selector);
(state: RootState) => state.generation.scheduler
);
const selectedSchedulers = useAppSelector(
(state: RootState) => state.ui.selectedSchedulers
);
const activeSchedulers = useAppSelector(
(state: RootState) => state.ui.activeSchedulers
);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
useEffect(() => {
if (selectedSchedulers.length === 0)
dispatch(setSelectedSchedulers(SCHEDULER_ITEMS));
const schedulerFound = activeSchedulers.find(
(activeSchedulers) => activeSchedulers.label === scheduler
);
if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value));
}, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
const handleChange = useCallback( const handleChange = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
return; return;
} }
dispatch(setScheduler(v)); dispatch(setScheduler(v as Scheduler));
}, },
[dispatch] [dispatch]
); );
@ -47,7 +46,7 @@ const ParamScheduler = () => {
<IAIMantineSelect <IAIMantineSelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
value={scheduler} value={scheduler}
data={activeSchedulers} data={allSchedulers}
onChange={handleChange} onChange={handleChange}
/> />
); );

View File

@ -1,5 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { ModelLoaderTypes } from 'features/system/components/ModelSelect'; import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es'; import { clamp, sortBy } from 'lodash-es';
@ -135,7 +136,7 @@ export const generationSlice = createSlice({
setWidth: (state, action: PayloadAction<number>) => { setWidth: (state, action: PayloadAction<number>) => {
state.width = action.payload; state.width = action.payload;
}, },
setScheduler: (state, action: PayloadAction<string>) => { setScheduler: (state, action: PayloadAction<Scheduler>) => {
state.scheduler = action.payload; state.scheduler = action.payload;
}, },
setSeed: (state, action: PayloadAction<number>) => { setSeed: (state, action: PayloadAction<number>) => {

View File

@ -1,4 +1,4 @@
import { NUMPY_RAND_MAX } from 'app/constants'; import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
import { z } from 'zod'; import { z } from 'zod';
/** /**
@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
/** /**
* Zod schema for scheduler parameter * Zod schema for scheduler parameter
*/ */
export const zScheduler = z.string(); export const zScheduler = z.enum(SCHEDULERS);
/** /**
* Type alias for scheduler parameter, inferred from its zod schema * Type alias for scheduler parameter, inferred from its zod schema
*/ */

View File

@ -1,33 +1,47 @@
import {
Menu,
MenuButton,
MenuItemOption,
MenuList,
MenuOptionGroup,
} from '@chakra-ui/react';
import { SCHEDULERS } from 'app/constants'; import { SCHEDULERS } from 'app/constants';
import { RootState } from 'app/store/store';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIButton from 'common/components/IAIButton';
import { setSelectedSchedulers } from 'features/ui/store/uiSlice'; import { setSchedulers } from 'features/ui/store/uiSlice';
import { isArray } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
export default function SettingsSchedulers() { export default function SettingsSchedulers() {
const schedulers = useAppSelector((state: RootState) => state.ui.schedulers);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedSchedulers = useAppSelector(
(state: RootState) => state.ui.selectedSchedulers
);
const { t } = useTranslation(); const { t } = useTranslation();
const schedulerSettingsHandler = (v: string[]) => { const schedulerSettingsHandler = (v: string | string[]) => {
dispatch(setSelectedSchedulers(v)); if (isArray(v)) dispatch(setSchedulers(v.sort()));
}; };
return ( return (
<IAIMantineMultiSelect <Menu closeOnSelect={false}>
label={t('settings.availableSchedulers')} <MenuButton as={IAIButton}>
value={selectedSchedulers} {t('settings.availableSchedulers')}
data={SCHEDULERS} </MenuButton>
onChange={schedulerSettingsHandler} <MenuList maxHeight={64} overflowY="scroll">
clearable <MenuOptionGroup
searchable value={schedulers}
maxSelectedValues={99} type="checkbox"
/> onChange={schedulerSettingsHandler}
>
{SCHEDULERS.map((scheduler) => (
<MenuItemOption key={scheduler} value={scheduler}>
{scheduler}
</MenuItemOption>
))}
</MenuOptionGroup>
</MenuList>
</Menu>
); );
} }

View File

@ -1,11 +1,10 @@
import { SelectItem } from '@mantine/core';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { SCHEDULERS } from 'app/constants';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { setActiveTabReducer } from './extraReducers'; import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap'; import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes'; import { AddNewModelType, UIState } from './uiTypes';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SCHEDULERS } from 'app/constants';
export const initialUIState: UIState = { export const initialUIState: UIState = {
activeTab: 0, activeTab: 0,
@ -21,8 +20,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true, shouldShowGallery: true,
shouldHidePreview: false, shouldHidePreview: false,
shouldShowProgressInViewer: true, shouldShowProgressInViewer: true,
activeSchedulers: [], schedulers: SCHEDULERS,
selectedSchedulers: [],
}; };
export const uiSlice = createSlice({ export const uiSlice = createSlice({
@ -96,20 +94,9 @@ export const uiSlice = createSlice({
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => { setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
state.shouldShowProgressInViewer = action.payload; state.shouldShowProgressInViewer = action.payload;
}, },
setSelectedSchedulers: (state, action: PayloadAction<string[]>) => { setSchedulers: (state, action: PayloadAction<string[]>) => {
const selectedSchedulerData: SelectItem[] = []; state.schedulers = [];
state.schedulers = action.payload;
if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value];
action.payload.forEach((item) => {
const schedulerData = SCHEDULERS.find(
(scheduler) => scheduler.value === item
);
if (schedulerData) selectedSchedulerData.push(schedulerData);
});
state.activeSchedulers = selectedSchedulerData;
state.selectedSchedulers = action.payload;
}, },
}, },
extraReducers(builder) { extraReducers(builder) {
@ -137,7 +124,7 @@ export const {
toggleParametersPanel, toggleParametersPanel,
toggleGalleryPanel, toggleGalleryPanel,
setShouldShowProgressInViewer, setShouldShowProgressInViewer,
setSelectedSchedulers, setSchedulers,
} = uiSlice.actions; } = uiSlice.actions;
export default uiSlice.reducer; export default uiSlice.reducer;

View File

@ -1,5 +1,3 @@
import { SelectItem } from '@mantine/core';
export type AddNewModelType = 'ckpt' | 'diffusers' | null; export type AddNewModelType = 'ckpt' | 'diffusers' | null;
export type Coordinates = { export type Coordinates = {
@ -28,6 +26,5 @@ export interface UIState {
shouldPinGallery: boolean; shouldPinGallery: boolean;
shouldShowGallery: boolean; shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean; shouldShowProgressInViewer: boolean;
activeSchedulers: SelectItem[]; schedulers: string[];
selectedSchedulers: string[];
} }