Revert "feat: Port Schedulers to Mantine"

This reverts commit e0c105f413.
This commit is contained in:
blessedcoolant 2023-06-18 22:22:56 +12:00
parent 809ec7163e
commit 9fda21cf40
8 changed files with 75 additions and 178 deletions

View File

@ -1,28 +1,6 @@
import { SelectItem } from '@mantine/core';
// TODO: use Enums?
export const SCHEDULERS: SelectItem[] = [
{ label: 'euler', value: 'euler', group: 'Standard' },
{ label: 'deis', value: 'deis', group: 'Standard' },
{ label: 'ddim', value: 'ddim', group: 'Standard' },
{ label: 'ddpm', value: 'ddpm', group: 'Standard' },
{ label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' },
{ label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' },
{ label: 'heun', value: 'heun', group: 'Standard' },
{ label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' },
{ label: 'lms', value: 'lms', group: 'Standard' },
{ label: 'pndm', value: 'pndm', group: 'Standard' },
{ label: 'unipc', value: 'unipc', group: 'Standard' },
{ label: 'euler_k', value: 'euler_k', group: 'Karras' },
{ label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' },
{ label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' },
{ label: 'heun_k', value: 'heun_k', group: 'Karras' },
{ label: 'lms_k', value: 'lms_k', group: 'Karras' },
{ label: 'euler_a', value: 'euler_a', group: 'Ancestral' },
{ label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' },
];
export const SCHEDULER_ITEMS = [
export const SCHEDULERS = [
'ddim',
'lms',
'lms_k',
@ -41,9 +19,9 @@ export const SCHEDULER_ITEMS = [
'heun',
'heun_k',
'unipc',
];
] as const;
export type Scheduler = typeof SCHEDULERS;
export type Scheduler = (typeof SCHEDULERS)[number];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [

View File

@ -1,79 +0,0 @@
import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { memo } from 'react';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props;
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect
searchable={searchable}
styles={() => ({
label: {
color: 'var(--invokeai-colors-base-300)',
fontWeight: 'normal',
},
input: {
backgroundColor: 'var(--invokeai-colors-base-900)',
borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
'&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)',
},
},
value: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)',
'&:hover': {
backgroundColor: 'var(--invokeai-colors-base-700)',
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)',
borderColor: 'var(--invokeai-colors-base-700)',
},
item: {
backgroundColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-200)',
padding: 6,
'&[data-hovered]': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
'&[data-active]': {
backgroundColor: 'var(--invokeai-colors-base-750)',
'&:hover': {
color: 'var(--invokeai-colors-base-100)',
backgroundColor: 'var(--invokeai-colors-base-750)',
},
},
'&[data-selected]': {
color: 'var(--invokeai-colors-base-50)',
backgroundColor: 'var(--invokeai-colors-accent-650)',
fontWeight: 600,
'&:hover': {
backgroundColor: 'var(--invokeai-colors-accent-600)',
},
},
},
rightSection: {
width: 24,
},
})}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineMultiSelect);

View File

@ -1,44 +1,43 @@
import { SCHEDULER_ITEMS } from 'app/constants';
import { RootState } from 'app/store/store';
import { createSelector } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
import { memo, useCallback, useEffect } from 'react';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
const allSchedulers: string[] = ui.schedulers
.slice()
.sort((a, b) => a.localeCompare(b));
return {
scheduler: generation.scheduler,
allSchedulers,
};
},
defaultSelectorOptions
);
const ParamScheduler = () => {
const scheduler = useAppSelector(
(state: RootState) => state.generation.scheduler
);
const selectedSchedulers = useAppSelector(
(state: RootState) => state.ui.selectedSchedulers
);
const activeSchedulers = useAppSelector(
(state: RootState) => state.ui.activeSchedulers
);
const { allSchedulers, scheduler } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
useEffect(() => {
if (selectedSchedulers.length === 0)
dispatch(setSelectedSchedulers(SCHEDULER_ITEMS));
const schedulerFound = activeSchedulers.find(
(activeSchedulers) => activeSchedulers.label === scheduler
);
if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value));
}, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(setScheduler(v));
dispatch(setScheduler(v as Scheduler));
},
[dispatch]
);
@ -47,7 +46,7 @@ const ParamScheduler = () => {
<IAIMantineSelect
label={t('parameters.scheduler')}
value={scheduler}
data={activeSchedulers}
data={allSchedulers}
onChange={handleChange}
/>
);

View File

@ -1,5 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es';
@ -135,7 +136,7 @@ export const generationSlice = createSlice({
setWidth: (state, action: PayloadAction<number>) => {
state.width = action.payload;
},
setScheduler: (state, action: PayloadAction<string>) => {
setScheduler: (state, action: PayloadAction<Scheduler>) => {
state.scheduler = action.payload;
},
setSeed: (state, action: PayloadAction<number>) => {

View File

@ -1,4 +1,4 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
import { z } from 'zod';
/**
@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
/**
* Zod schema for scheduler parameter
*/
export const zScheduler = z.string();
export const zScheduler = z.enum(SCHEDULERS);
/**
* Type alias for scheduler parameter, inferred from its zod schema
*/

View File

@ -1,33 +1,47 @@
import {
Menu,
MenuButton,
MenuItemOption,
MenuList,
MenuOptionGroup,
} from '@chakra-ui/react';
import { SCHEDULERS } from 'app/constants';
import { RootState } from 'app/store/store';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
import IAIButton from 'common/components/IAIButton';
import { setSchedulers } from 'features/ui/store/uiSlice';
import { isArray } from 'lodash-es';
import { useTranslation } from 'react-i18next';
export default function SettingsSchedulers() {
const schedulers = useAppSelector((state: RootState) => state.ui.schedulers);
const dispatch = useAppDispatch();
const selectedSchedulers = useAppSelector(
(state: RootState) => state.ui.selectedSchedulers
);
const { t } = useTranslation();
const schedulerSettingsHandler = (v: string[]) => {
dispatch(setSelectedSchedulers(v));
const schedulerSettingsHandler = (v: string | string[]) => {
if (isArray(v)) dispatch(setSchedulers(v.sort()));
};
return (
<IAIMantineMultiSelect
label={t('settings.availableSchedulers')}
value={selectedSchedulers}
data={SCHEDULERS}
<Menu closeOnSelect={false}>
<MenuButton as={IAIButton}>
{t('settings.availableSchedulers')}
</MenuButton>
<MenuList maxHeight={64} overflowY="scroll">
<MenuOptionGroup
value={schedulers}
type="checkbox"
onChange={schedulerSettingsHandler}
clearable
searchable
maxSelectedValues={99}
/>
>
{SCHEDULERS.map((scheduler) => (
<MenuItemOption key={scheduler} value={scheduler}>
{scheduler}
</MenuItemOption>
))}
</MenuOptionGroup>
</MenuList>
</Menu>
);
}

View File

@ -1,11 +1,10 @@
import { SelectItem } from '@mantine/core';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { SCHEDULERS } from 'app/constants';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SCHEDULERS } from 'app/constants';
export const initialUIState: UIState = {
activeTab: 0,
@ -21,8 +20,7 @@ export const initialUIState: UIState = {
shouldShowGallery: true,
shouldHidePreview: false,
shouldShowProgressInViewer: true,
activeSchedulers: [],
selectedSchedulers: [],
schedulers: SCHEDULERS,
};
export const uiSlice = createSlice({
@ -96,20 +94,9 @@ export const uiSlice = createSlice({
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
state.shouldShowProgressInViewer = action.payload;
},
setSelectedSchedulers: (state, action: PayloadAction<string[]>) => {
const selectedSchedulerData: SelectItem[] = [];
if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value];
action.payload.forEach((item) => {
const schedulerData = SCHEDULERS.find(
(scheduler) => scheduler.value === item
);
if (schedulerData) selectedSchedulerData.push(schedulerData);
});
state.activeSchedulers = selectedSchedulerData;
state.selectedSchedulers = action.payload;
setSchedulers: (state, action: PayloadAction<string[]>) => {
state.schedulers = [];
state.schedulers = action.payload;
},
},
extraReducers(builder) {
@ -137,7 +124,7 @@ export const {
toggleParametersPanel,
toggleGalleryPanel,
setShouldShowProgressInViewer,
setSelectedSchedulers,
setSchedulers,
} = uiSlice.actions;
export default uiSlice.reducer;

View File

@ -1,5 +1,3 @@
import { SelectItem } from '@mantine/core';
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
export type Coordinates = {
@ -28,6 +26,5 @@ export interface UIState {
shouldPinGallery: boolean;
shouldShowGallery: boolean;
shouldShowProgressInViewer: boolean;
activeSchedulers: SelectItem[];
selectedSchedulers: string[];
schedulers: string[];
}