mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add modelSelected middleware to clear submodels on base_model change
This commit is contained in:
parent
b9a1aa38e3
commit
a9a4081f51
@ -86,6 +86,7 @@ import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesD
|
||||
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
|
||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
|
||||
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -220,3 +221,6 @@ addSelectionAddedToBatchListener();
|
||||
|
||||
// DND
|
||||
addImageDroppedListener();
|
||||
|
||||
// Models
|
||||
addModelSelectedListener();
|
||||
|
@ -0,0 +1,35 @@
|
||||
import {
|
||||
modelChanged,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
||||
|
||||
export const addModelSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: modelSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const [base_model, type, name] = action.payload.split('/');
|
||||
|
||||
if (state.generation.model?.base_model !== base_model) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Base model changed, clearing submodels',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
dispatch(vaeSelected('auto'));
|
||||
dispatch(lorasCleared());
|
||||
// TODO: controlnet cleared
|
||||
}
|
||||
|
||||
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
|
||||
},
|
||||
});
|
||||
};
|
@ -48,7 +48,7 @@ const ParamLoraSelect = () => {
|
||||
data.push({
|
||||
value: id,
|
||||
label: lora.name,
|
||||
description: 'This is a lora',
|
||||
description: lora.description,
|
||||
...(currentMainModel?.base_model !== lora.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
|
@ -31,6 +31,9 @@ export const loraSlice = createSlice({
|
||||
const id = action.payload;
|
||||
delete state.loras[id];
|
||||
},
|
||||
lorasCleared: (state, action: PayloadAction<>) => {
|
||||
state.loras = {};
|
||||
},
|
||||
loraWeightChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ id: string; weight: number }>
|
||||
@ -45,7 +48,12 @@ export const loraSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
|
||||
loraSlice.actions;
|
||||
export const {
|
||||
loraAdded,
|
||||
loraRemoved,
|
||||
loraWeightChanged,
|
||||
loraWeightReset,
|
||||
lorasCleared,
|
||||
} = loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
||||
|
@ -2,7 +2,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
modelSelected,
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
@ -14,7 +13,7 @@ import {
|
||||
setWidth,
|
||||
} from '../store/generationSlice';
|
||||
import { isImageField } from 'services/api/guards';
|
||||
import { initialImageSelected } from '../store/actions';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
@ -163,7 +162,7 @@ export const useRecallParameters = () => {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(model));
|
||||
dispatch(modelSelected(model?.id || ''));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
|
@ -4,3 +4,5 @@ import { ImageDTO } from 'services/api/types';
|
||||
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
|
||||
'generation/initialImageSelected'
|
||||
);
|
||||
|
||||
export const modelSelected = createAction<string>('generation/modelSelected');
|
||||
|
@ -223,6 +223,9 @@ export const generationSlice = createSlice({
|
||||
|
||||
state.model = { id: action.payload, base_model, name, type };
|
||||
},
|
||||
modelChanged: (state, action: PayloadAction<ModelParam>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||
state.vae = action.payload;
|
||||
},
|
||||
@ -282,7 +285,7 @@ export const {
|
||||
setHorizontalSymmetrySteps,
|
||||
setVerticalSymmetrySteps,
|
||||
initialImageChanged,
|
||||
modelSelected,
|
||||
modelChanged,
|
||||
vaeSelected,
|
||||
setShouldUseNoiseSettings,
|
||||
setSeamlessXAxis,
|
||||
|
@ -3,12 +3,12 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { modelSelected } from '../../parameters/store/actions';
|
||||
|
||||
export const MODEL_TYPE_MAP = {
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
|
Loading…
Reference in New Issue
Block a user