mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add modelSelected middleware to clear submodels on base_model change
This commit is contained in:
parent
b9a1aa38e3
commit
a9a4081f51
@ -86,6 +86,7 @@ import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesD
|
|||||||
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
|
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
|
||||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||||
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
|
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
|
||||||
|
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -220,3 +221,6 @@ addSelectionAddedToBatchListener();
|
|||||||
|
|
||||||
// DND
|
// DND
|
||||||
addImageDroppedListener();
|
addImageDroppedListener();
|
||||||
|
|
||||||
|
// Models
|
||||||
|
addModelSelectedListener();
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
import {
|
||||||
|
modelChanged,
|
||||||
|
vaeSelected,
|
||||||
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
||||||
|
|
||||||
|
export const addModelSelectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: modelSelected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const state = getState();
|
||||||
|
const [base_model, type, name] = action.payload.split('/');
|
||||||
|
|
||||||
|
if (state.generation.model?.base_model !== base_model) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Base model changed, clearing submodels',
|
||||||
|
status: 'warning',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
dispatch(vaeSelected('auto'));
|
||||||
|
dispatch(lorasCleared());
|
||||||
|
// TODO: controlnet cleared
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -48,7 +48,7 @@ const ParamLoraSelect = () => {
|
|||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: lora.name,
|
label: lora.name,
|
||||||
description: 'This is a lora',
|
description: lora.description,
|
||||||
...(currentMainModel?.base_model !== lora.base_model
|
...(currentMainModel?.base_model !== lora.base_model
|
||||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||||
: {}),
|
: {}),
|
||||||
|
@ -31,6 +31,9 @@ export const loraSlice = createSlice({
|
|||||||
const id = action.payload;
|
const id = action.payload;
|
||||||
delete state.loras[id];
|
delete state.loras[id];
|
||||||
},
|
},
|
||||||
|
lorasCleared: (state, action: PayloadAction<>) => {
|
||||||
|
state.loras = {};
|
||||||
|
},
|
||||||
loraWeightChanged: (
|
loraWeightChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ id: string; weight: number }>
|
action: PayloadAction<{ id: string; weight: number }>
|
||||||
@ -45,7 +48,12 @@ export const loraSlice = createSlice({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } =
|
export const {
|
||||||
loraSlice.actions;
|
loraAdded,
|
||||||
|
loraRemoved,
|
||||||
|
loraWeightChanged,
|
||||||
|
loraWeightReset,
|
||||||
|
lorasCleared,
|
||||||
|
} = loraSlice.actions;
|
||||||
|
|
||||||
export default loraSlice.reducer;
|
export default loraSlice.reducer;
|
||||||
|
@ -2,7 +2,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
modelSelected,
|
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
setImg2imgStrength,
|
setImg2imgStrength,
|
||||||
@ -14,7 +13,7 @@ import {
|
|||||||
setWidth,
|
setWidth,
|
||||||
} from '../store/generationSlice';
|
} from '../store/generationSlice';
|
||||||
import { isImageField } from 'services/api/guards';
|
import { isImageField } from 'services/api/guards';
|
||||||
import { initialImageSelected } from '../store/actions';
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
@ -163,7 +162,7 @@ export const useRecallParameters = () => {
|
|||||||
parameterNotSetToast();
|
parameterNotSetToast();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(modelSelected(model));
|
dispatch(modelSelected(model?.id || ''));
|
||||||
parameterSetToast();
|
parameterSetToast();
|
||||||
},
|
},
|
||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
|
@ -4,3 +4,5 @@ import { ImageDTO } from 'services/api/types';
|
|||||||
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
|
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
|
||||||
'generation/initialImageSelected'
|
'generation/initialImageSelected'
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const modelSelected = createAction<string>('generation/modelSelected');
|
||||||
|
@ -223,6 +223,9 @@ export const generationSlice = createSlice({
|
|||||||
|
|
||||||
state.model = { id: action.payload, base_model, name, type };
|
state.model = { id: action.payload, base_model, name, type };
|
||||||
},
|
},
|
||||||
|
modelChanged: (state, action: PayloadAction<ModelParam>) => {
|
||||||
|
state.model = action.payload;
|
||||||
|
},
|
||||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.vae = action.payload;
|
state.vae = action.payload;
|
||||||
},
|
},
|
||||||
@ -282,7 +285,7 @@ export const {
|
|||||||
setHorizontalSymmetrySteps,
|
setHorizontalSymmetrySteps,
|
||||||
setVerticalSymmetrySteps,
|
setVerticalSymmetrySteps,
|
||||||
initialImageChanged,
|
initialImageChanged,
|
||||||
modelSelected,
|
modelChanged,
|
||||||
vaeSelected,
|
vaeSelected,
|
||||||
setShouldUseNoiseSettings,
|
setShouldUseNoiseSettings,
|
||||||
setSeamlessXAxis,
|
setSeamlessXAxis,
|
||||||
|
@ -3,12 +3,12 @@ import { useTranslation } from 'react-i18next';
|
|||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
|
||||||
|
|
||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { modelSelected } from '../../parameters/store/actions';
|
||||||
|
|
||||||
export const MODEL_TYPE_MAP = {
|
export const MODEL_TYPE_MAP = {
|
||||||
'sd-1': 'Stable Diffusion 1.x',
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
|
Loading…
Reference in New Issue
Block a user