add modelSelected middleware to clear submodels on base_model change

This commit is contained in:
Mary Hipp 2023-07-06 15:07:47 -04:00 committed by psychedelicious
parent b9a1aa38e3
commit a9a4081f51
8 changed files with 59 additions and 8 deletions

View File

@ -86,6 +86,7 @@ import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesD
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch'; import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import { addImageDroppedListener } from './listeners/imageDropped'; import { addImageDroppedListener } from './listeners/imageDropped';
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected'; import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -220,3 +221,6 @@ addSelectionAddedToBatchListener();
// DND // DND
addImageDroppedListener(); addImageDroppedListener();
// Models
addModelSelectedListener();

View File

@ -0,0 +1,35 @@
import {
modelChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { modelSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
export const addModelSelectedListener = () => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
const state = getState();
const [base_model, type, name] = action.payload.split('/');
if (state.generation.model?.base_model !== base_model) {
dispatch(
addToast(
makeToast({
title: 'Base model changed, clearing submodels',
status: 'warning',
})
)
);
dispatch(vaeSelected('auto'));
dispatch(lorasCleared());
// TODO: controlnet cleared
}
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
},
});
};

View File

@ -48,7 +48,7 @@ const ParamLoraSelect = () => {
data.push({ data.push({
value: id, value: id,
label: lora.name, label: lora.name,
description: 'This is a lora', description: lora.description,
...(currentMainModel?.base_model !== lora.base_model ...(currentMainModel?.base_model !== lora.base_model
? { disabled: true, tooltip: 'Incompatible base model' } ? { disabled: true, tooltip: 'Incompatible base model' }
: {}), : {}),

View File

@ -31,6 +31,9 @@ export const loraSlice = createSlice({
const id = action.payload; const id = action.payload;
delete state.loras[id]; delete state.loras[id];
}, },
lorasCleared: (state, action: PayloadAction<>) => {
state.loras = {};
},
loraWeightChanged: ( loraWeightChanged: (
state, state,
action: PayloadAction<{ id: string; weight: number }> action: PayloadAction<{ id: string; weight: number }>
@ -45,7 +48,12 @@ export const loraSlice = createSlice({
}, },
}); });
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } = export const {
loraSlice.actions; loraAdded,
loraRemoved,
loraWeightChanged,
loraWeightReset,
lorasCleared,
} = loraSlice.actions;
export default loraSlice.reducer; export default loraSlice.reducer;

View File

@ -2,7 +2,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import {
modelSelected,
setCfgScale, setCfgScale,
setHeight, setHeight,
setImg2imgStrength, setImg2imgStrength,
@ -14,7 +13,7 @@ import {
setWidth, setWidth,
} from '../store/generationSlice'; } from '../store/generationSlice';
import { isImageField } from 'services/api/guards'; import { isImageField } from 'services/api/guards';
import { initialImageSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { import {
@ -163,7 +162,7 @@ export const useRecallParameters = () => {
parameterNotSetToast(); parameterNotSetToast();
return; return;
} }
dispatch(modelSelected(model)); dispatch(modelSelected(model?.id || ''));
parameterSetToast(); parameterSetToast();
}, },
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]

View File

@ -4,3 +4,5 @@ import { ImageDTO } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | string | undefined>( export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected' 'generation/initialImageSelected'
); );
export const modelSelected = createAction<string>('generation/modelSelected');

View File

@ -223,6 +223,9 @@ export const generationSlice = createSlice({
state.model = { id: action.payload, base_model, name, type }; state.model = { id: action.payload, base_model, name, type };
}, },
modelChanged: (state, action: PayloadAction<ModelParam>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<string>) => { vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload; state.vae = action.payload;
}, },
@ -282,7 +285,7 @@ export const {
setHorizontalSymmetrySteps, setHorizontalSymmetrySteps,
setVerticalSymmetrySteps, setVerticalSymmetrySteps,
initialImageChanged, initialImageChanged,
modelSelected, modelChanged,
vaeSelected, vaeSelected,
setShouldUseNoiseSettings, setShouldUseNoiseSettings,
setSeamlessXAxis, setSeamlessXAxis,

View File

@ -3,12 +3,12 @@ import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { forEach, isString } from 'lodash-es'; import { forEach, isString } from 'lodash-es';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { modelSelected } from '../../parameters/store/actions';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x', 'sd-1': 'Stable Diffusion 1.x',