mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Revert "feat: Port Schedulers to Mantine"
This reverts commit e0c105f413
.
This commit is contained in:
parent
809ec7163e
commit
9fda21cf40
@ -1,28 +1,6 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
|
||||
// TODO: use Enums?
|
||||
export const SCHEDULERS: SelectItem[] = [
|
||||
{ label: 'euler', value: 'euler', group: 'Standard' },
|
||||
{ label: 'deis', value: 'deis', group: 'Standard' },
|
||||
{ label: 'ddim', value: 'ddim', group: 'Standard' },
|
||||
{ label: 'ddpm', value: 'ddpm', group: 'Standard' },
|
||||
{ label: 'dpmpp_2s', value: 'dpmpp_2s', group: 'Standard' },
|
||||
{ label: 'dpmpp_2m', value: 'dpmpp_2m', group: 'Standard' },
|
||||
{ label: 'heun', value: 'heun', group: 'Standard' },
|
||||
{ label: 'kdpm_2', value: 'kdpm_2', group: 'Standard' },
|
||||
{ label: 'lms', value: 'lms', group: 'Standard' },
|
||||
{ label: 'pndm', value: 'pndm', group: 'Standard' },
|
||||
{ label: 'unipc', value: 'unipc', group: 'Standard' },
|
||||
{ label: 'euler_k', value: 'euler_k', group: 'Karras' },
|
||||
{ label: 'dpmpp_2s_k', value: 'dpmpp_2s_k', group: 'Karras' },
|
||||
{ label: 'dpmpp_2m_k', value: 'dpmpp_2m_k', group: 'Karras' },
|
||||
{ label: 'heun_k', value: 'heun_k', group: 'Karras' },
|
||||
{ label: 'lms_k', value: 'lms_k', group: 'Karras' },
|
||||
{ label: 'euler_a', value: 'euler_a', group: 'Ancestral' },
|
||||
{ label: 'kdpm_2_a', value: 'kdpm_2_a', group: 'Ancestral' },
|
||||
];
|
||||
|
||||
export const SCHEDULER_ITEMS = [
|
||||
export const SCHEDULERS = [
|
||||
'ddim',
|
||||
'lms',
|
||||
'lms_k',
|
||||
@ -41,9 +19,9 @@ export const SCHEDULER_ITEMS = [
|
||||
'heun',
|
||||
'heun_k',
|
||||
'unipc',
|
||||
];
|
||||
] as const;
|
||||
|
||||
export type Scheduler = typeof SCHEDULERS;
|
||||
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
|
||||
|
@ -1,79 +0,0 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { memo } from 'react';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, ...rest } = props;
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<MultiSelect
|
||||
searchable={searchable}
|
||||
styles={() => ({
|
||||
label: {
|
||||
color: 'var(--invokeai-colors-base-300)',
|
||||
fontWeight: 'normal',
|
||||
},
|
||||
input: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-900)',
|
||||
borderWidth: '2px',
|
||||
borderColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
padding: 10,
|
||||
paddingRight: 24,
|
||||
fontWeight: 600,
|
||||
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
|
||||
'&:focus': {
|
||||
borderColor: 'var(--invokeai-colors-accent-600)',
|
||||
},
|
||||
},
|
||||
value: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
'&:hover': {
|
||||
backgroundColor: 'var(--invokeai-colors-base-700)',
|
||||
cursor: 'pointer',
|
||||
},
|
||||
},
|
||||
dropdown: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-800)',
|
||||
borderColor: 'var(--invokeai-colors-base-700)',
|
||||
},
|
||||
item: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-200)',
|
||||
padding: 6,
|
||||
'&[data-hovered]': {
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
backgroundColor: 'var(--invokeai-colors-base-750)',
|
||||
},
|
||||
'&[data-active]': {
|
||||
backgroundColor: 'var(--invokeai-colors-base-750)',
|
||||
'&:hover': {
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
backgroundColor: 'var(--invokeai-colors-base-750)',
|
||||
},
|
||||
},
|
||||
'&[data-selected]': {
|
||||
color: 'var(--invokeai-colors-base-50)',
|
||||
backgroundColor: 'var(--invokeai-colors-accent-650)',
|
||||
fontWeight: 600,
|
||||
'&:hover': {
|
||||
backgroundColor: 'var(--invokeai-colors-accent-600)',
|
||||
},
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 24,
|
||||
},
|
||||
})}
|
||||
{...rest}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIMantineMultiSelect);
|
@ -1,44 +1,43 @@
|
||||
import { SCHEDULER_ITEMS } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const allSchedulers: string[] = ui.schedulers
|
||||
.slice()
|
||||
.sort((a, b) => a.localeCompare(b));
|
||||
|
||||
return {
|
||||
scheduler: generation.scheduler,
|
||||
allSchedulers,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamScheduler = () => {
|
||||
const scheduler = useAppSelector(
|
||||
(state: RootState) => state.generation.scheduler
|
||||
);
|
||||
|
||||
const selectedSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.selectedSchedulers
|
||||
);
|
||||
|
||||
const activeSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.activeSchedulers
|
||||
);
|
||||
const { allSchedulers, scheduler } = useAppSelector(selector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedSchedulers.length === 0)
|
||||
dispatch(setSelectedSchedulers(SCHEDULER_ITEMS));
|
||||
|
||||
const schedulerFound = activeSchedulers.find(
|
||||
(activeSchedulers) => activeSchedulers.label === scheduler
|
||||
);
|
||||
if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value));
|
||||
}, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setScheduler(v));
|
||||
dispatch(setScheduler(v as Scheduler));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -47,7 +46,7 @@ const ParamScheduler = () => {
|
||||
<IAIMantineSelect
|
||||
label={t('parameters.scheduler')}
|
||||
value={scheduler}
|
||||
data={activeSchedulers}
|
||||
data={allSchedulers}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
@ -135,7 +136,7 @@ export const generationSlice = createSlice({
|
||||
setWidth: (state, action: PayloadAction<number>) => {
|
||||
state.width = action.payload;
|
||||
},
|
||||
setScheduler: (state, action: PayloadAction<string>) => {
|
||||
setScheduler: (state, action: PayloadAction<Scheduler>) => {
|
||||
state.scheduler = action.payload;
|
||||
},
|
||||
setSeed: (state, action: PayloadAction<number>) => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
|
||||
import { z } from 'zod';
|
||||
|
||||
/**
|
||||
@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
|
||||
/**
|
||||
* Zod schema for scheduler parameter
|
||||
*/
|
||||
export const zScheduler = z.string();
|
||||
export const zScheduler = z.enum(SCHEDULERS);
|
||||
/**
|
||||
* Type alias for scheduler parameter, inferred from its zod schema
|
||||
*/
|
||||
|
@ -1,33 +1,47 @@
|
||||
import {
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItemOption,
|
||||
MenuList,
|
||||
MenuOptionGroup,
|
||||
} from '@chakra-ui/react';
|
||||
import { SCHEDULERS } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { setSchedulers } from 'features/ui/store/uiSlice';
|
||||
import { isArray } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function SettingsSchedulers() {
|
||||
const schedulers = useAppSelector((state: RootState) => state.ui.schedulers);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selectedSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.selectedSchedulers
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const schedulerSettingsHandler = (v: string[]) => {
|
||||
dispatch(setSelectedSchedulers(v));
|
||||
const schedulerSettingsHandler = (v: string | string[]) => {
|
||||
if (isArray(v)) dispatch(setSchedulers(v.sort()));
|
||||
};
|
||||
|
||||
return (
|
||||
<IAIMantineMultiSelect
|
||||
label={t('settings.availableSchedulers')}
|
||||
value={selectedSchedulers}
|
||||
data={SCHEDULERS}
|
||||
<Menu closeOnSelect={false}>
|
||||
<MenuButton as={IAIButton}>
|
||||
{t('settings.availableSchedulers')}
|
||||
</MenuButton>
|
||||
<MenuList maxHeight={64} overflowY="scroll">
|
||||
<MenuOptionGroup
|
||||
value={schedulers}
|
||||
type="checkbox"
|
||||
onChange={schedulerSettingsHandler}
|
||||
clearable
|
||||
searchable
|
||||
maxSelectedValues={99}
|
||||
/>
|
||||
>
|
||||
{SCHEDULERS.map((scheduler) => (
|
||||
<MenuItemOption key={scheduler} value={scheduler}>
|
||||
{scheduler}
|
||||
</MenuItemOption>
|
||||
))}
|
||||
</MenuOptionGroup>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
}
|
||||
|
@ -1,11 +1,10 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { SCHEDULERS } from 'app/constants';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { SCHEDULERS } from 'app/constants';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
@ -21,8 +20,7 @@ export const initialUIState: UIState = {
|
||||
shouldShowGallery: true,
|
||||
shouldHidePreview: false,
|
||||
shouldShowProgressInViewer: true,
|
||||
activeSchedulers: [],
|
||||
selectedSchedulers: [],
|
||||
schedulers: SCHEDULERS,
|
||||
};
|
||||
|
||||
export const uiSlice = createSlice({
|
||||
@ -96,20 +94,9 @@ export const uiSlice = createSlice({
|
||||
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowProgressInViewer = action.payload;
|
||||
},
|
||||
setSelectedSchedulers: (state, action: PayloadAction<string[]>) => {
|
||||
const selectedSchedulerData: SelectItem[] = [];
|
||||
|
||||
if (action.payload.length === 0) action.payload = [SCHEDULERS[0].value];
|
||||
|
||||
action.payload.forEach((item) => {
|
||||
const schedulerData = SCHEDULERS.find(
|
||||
(scheduler) => scheduler.value === item
|
||||
);
|
||||
if (schedulerData) selectedSchedulerData.push(schedulerData);
|
||||
});
|
||||
|
||||
state.activeSchedulers = selectedSchedulerData;
|
||||
state.selectedSchedulers = action.payload;
|
||||
setSchedulers: (state, action: PayloadAction<string[]>) => {
|
||||
state.schedulers = [];
|
||||
state.schedulers = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
@ -137,7 +124,7 @@ export const {
|
||||
toggleParametersPanel,
|
||||
toggleGalleryPanel,
|
||||
setShouldShowProgressInViewer,
|
||||
setSelectedSchedulers,
|
||||
setSchedulers,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
@ -1,5 +1,3 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
|
||||
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
||||
|
||||
export type Coordinates = {
|
||||
@ -28,6 +26,5 @@ export interface UIState {
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
shouldShowProgressInViewer: boolean;
|
||||
activeSchedulers: SelectItem[];
|
||||
selectedSchedulers: string[];
|
||||
schedulers: string[];
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user