add modelSelected middleware to clear submodels on base_model change

This commit is contained in:
Mary Hipp 2023-07-06 15:07:47 -04:00 committed by psychedelicious
parent b9a1aa38e3
commit a9a4081f51
8 changed files with 59 additions and 8 deletions

View File

@ -86,6 +86,7 @@ import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesD
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import { addImageDroppedListener } from './listeners/imageDropped';
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
export const listenerMiddleware = createListenerMiddleware();
@ -220,3 +221,6 @@ addSelectionAddedToBatchListener();
// DND
addImageDroppedListener();
// Models
addModelSelectedListener();

View File

@ -0,0 +1,35 @@
import {
modelChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { modelSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
export const addModelSelectedListener = () => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
const state = getState();
const [base_model, type, name] = action.payload.split('/');
if (state.generation.model?.base_model !== base_model) {
dispatch(
addToast(
makeToast({
title: 'Base model changed, clearing submodels',
status: 'warning',
})
)
);
dispatch(vaeSelected('auto'));
dispatch(lorasCleared());
// TODO: controlnet cleared
}
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
},
});
};

View File

@ -48,7 +48,7 @@ const ParamLoraSelect = () => {
data.push({
value: id,
label: lora.name,
description: 'This is a lora',
description: lora.description,
...(currentMainModel?.base_model !== lora.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),

View File

@ -31,6 +31,9 @@ export const loraSlice = createSlice({
const id = action.payload;
delete state.loras[id];
},
lorasCleared: (state, action: PayloadAction<>) => {
state.loras = {};
},
loraWeightChanged: (
state,
action: PayloadAction<{ id: string; weight: number }>
@ -45,7 +48,12 @@ export const loraSlice = createSlice({
},
});
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
loraSlice.actions;
export const {
loraAdded,
loraRemoved,
loraWeightChanged,
loraWeightReset,
lorasCleared,
} = loraSlice.actions;
export default loraSlice.reducer;

View File

@ -2,7 +2,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
modelSelected,
setCfgScale,
setHeight,
setImg2imgStrength,
@ -14,7 +13,7 @@ import {
setWidth,
} from '../store/generationSlice';
import { isImageField } from 'services/api/guards';
import { initialImageSelected } from '../store/actions';
import { initialImageSelected, modelSelected } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types';
import {
@ -163,7 +162,7 @@ export const useRecallParameters = () => {
parameterNotSetToast();
return;
}
dispatch(modelSelected(model));
dispatch(modelSelected(model?.id || ''));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]

View File

@ -4,3 +4,5 @@ import { ImageDTO } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected'
);
export const modelSelected = createAction<string>('generation/modelSelected');

View File

@ -223,6 +223,9 @@ export const generationSlice = createSlice({
state.model = { id: action.payload, base_model, name, type };
},
modelChanged: (state, action: PayloadAction<ModelParam>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload;
},
@ -282,7 +285,7 @@ export const {
setHorizontalSymmetrySteps,
setVerticalSymmetrySteps,
initialImageChanged,
modelSelected,
modelChanged,
vaeSelected,
setShouldUseNoiseSettings,
setSeamlessXAxis,

View File

@ -3,12 +3,12 @@ import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { forEach, isString } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { modelSelected } from '../../parameters/store/actions';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',