mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Port Schedulers to Mantine
This commit is contained in:
@ -1,43 +1,44 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { SCHEDULER_ITEMS } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { setSelectedSchedulers } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const allSchedulers: string[] = ui.schedulers
|
||||
.slice()
|
||||
.sort((a, b) => a.localeCompare(b));
|
||||
|
||||
return {
|
||||
scheduler: generation.scheduler,
|
||||
allSchedulers,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamScheduler = () => {
|
||||
const { allSchedulers, scheduler } = useAppSelector(selector);
|
||||
const scheduler = useAppSelector(
|
||||
(state: RootState) => state.generation.scheduler
|
||||
);
|
||||
|
||||
const selectedSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.selectedSchedulers
|
||||
);
|
||||
|
||||
const activeSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.activeSchedulers
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedSchedulers.length === 0)
|
||||
dispatch(setSelectedSchedulers(SCHEDULER_ITEMS));
|
||||
|
||||
const schedulerFound = activeSchedulers.find(
|
||||
(activeSchedulers) => activeSchedulers.label === scheduler
|
||||
);
|
||||
if (!schedulerFound) dispatch(setScheduler(activeSchedulers[0].value));
|
||||
}, [dispatch, selectedSchedulers, scheduler, activeSchedulers]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setScheduler(v as Scheduler));
|
||||
dispatch(setScheduler(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -46,7 +47,7 @@ const ParamScheduler = () => {
|
||||
<IAIMantineSelect
|
||||
label={t('parameters.scheduler')}
|
||||
value={scheduler}
|
||||
data={allSchedulers}
|
||||
data={activeSchedulers}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
|
@ -1,10 +1,10 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
@ -17,7 +17,6 @@ import {
|
||||
StrengthParam,
|
||||
WidthParam,
|
||||
} from './parameterZodSchemas';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
|
||||
export interface GenerationState {
|
||||
cfgScale: CfgScaleParam;
|
||||
@ -133,7 +132,7 @@ export const generationSlice = createSlice({
|
||||
setWidth: (state, action: PayloadAction<number>) => {
|
||||
state.width = action.payload;
|
||||
},
|
||||
setScheduler: (state, action: PayloadAction<Scheduler>) => {
|
||||
setScheduler: (state, action: PayloadAction<string>) => {
|
||||
state.scheduler = action.payload;
|
||||
},
|
||||
setSeed: (state, action: PayloadAction<number>) => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { z } from 'zod';
|
||||
|
||||
/**
|
||||
@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
|
||||
/**
|
||||
* Zod schema for scheduler parameter
|
||||
*/
|
||||
export const zScheduler = z.enum(SCHEDULERS);
|
||||
export const zScheduler = z.string();
|
||||
/**
|
||||
* Type alias for scheduler parameter, inferred from its zod schema
|
||||
*/
|
||||
|
Reference in New Issue
Block a user