From a9a4081f51badb0cdef7913698d0e540cb9a3d1f Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 6 Jul 2023 15:07:47 -0400 Subject: [PATCH] add modelSelected middleware to clear submodels on base_model change --- .../middleware/listenerMiddleware/index.ts | 4 +++ .../listeners/modelSelected.ts | 35 +++++++++++++++++++ .../lora/components/ParamLoraSelect.tsx | 2 +- .../web/src/features/lora/store/loraSlice.ts | 12 +++++-- .../parameters/hooks/useRecallParameters.ts | 5 ++- .../src/features/parameters/store/actions.ts | 2 ++ .../parameters/store/generationSlice.ts | 5 ++- .../system/components/ModelSelect.tsx | 2 +- 8 files changed, 59 insertions(+), 8 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 900fabfee9..59fa48a9b7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -86,6 +86,7 @@ import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesD import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch'; import { addImageDroppedListener } from './listeners/imageDropped'; import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected'; +import { addModelSelectedListener } from './listeners/modelSelected'; export const listenerMiddleware = createListenerMiddleware(); @@ -220,3 +221,6 @@ addSelectionAddedToBatchListener(); // DND addImageDroppedListener(); + +// Models +addModelSelectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts new file mode 100644 index 0000000000..d10a4e25e2 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -0,0 +1,35 @@ +import { + modelChanged, + vaeSelected, +} from 'features/parameters/store/generationSlice'; +import { addToast } from 'features/system/store/systemSlice'; +import { startAppListening } from '..'; +import { modelSelected } from 'features/parameters/store/actions'; +import { makeToast } from 'app/components/Toaster'; +import { lorasCleared } from '../../../../../features/lora/store/loraSlice'; + +export const addModelSelectedListener = () => { + startAppListening({ + actionCreator: modelSelected, + effect: (action, { getState, dispatch }) => { + const state = getState(); + const [base_model, type, name] = action.payload.split('/'); + + if (state.generation.model?.base_model !== base_model) { + dispatch( + addToast( + makeToast({ + title: 'Base model changed, clearing submodels', + status: 'warning', + }) + ) + ); + dispatch(vaeSelected('auto')); + dispatch(lorasCleared()); + // TODO: controlnet cleared + } + + dispatch(modelChanged({ id: action.payload, base_model, name, type })); + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx index 4d5aa81738..b2455ed706 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -48,7 +48,7 @@ const ParamLoraSelect = () => { data.push({ value: id, label: lora.name, - description: 'This is a lora', + description: lora.description, ...(currentMainModel?.base_model !== lora.base_model ? { disabled: true, tooltip: 'Incompatible base model' } : {}), diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 7da6018e58..bab6f2f7e1 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -31,6 +31,9 @@ export const loraSlice = createSlice({ const id = action.payload; delete state.loras[id]; }, + lorasCleared: (state, action: PayloadAction<>) => { + state.loras = {}; + }, loraWeightChanged: ( state, action: PayloadAction<{ id: string; weight: number }> @@ -45,7 +48,12 @@ export const loraSlice = createSlice({ }, }); -export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } = - loraSlice.actions; +export const { + loraAdded, + loraRemoved, + loraWeightChanged, + loraWeightReset, + lorasCleared, +} = loraSlice.actions; export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index f504c62ed6..71c054c40d 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -2,7 +2,6 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { - modelSelected, setCfgScale, setHeight, setImg2imgStrength, @@ -14,7 +13,7 @@ import { setWidth, } from '../store/generationSlice'; import { isImageField } from 'services/api/guards'; -import { initialImageSelected } from '../store/actions'; +import { initialImageSelected, modelSelected } from '../store/actions'; import { useAppToaster } from 'app/components/Toaster'; import { ImageDTO } from 'services/api/types'; import { @@ -163,7 +162,7 @@ export const useRecallParameters = () => { parameterNotSetToast(); return; } - dispatch(modelSelected(model)); + dispatch(modelSelected(model?.id || '')); parameterSetToast(); }, [dispatch, parameterSetToast, parameterNotSetToast] diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index 2fb56c0883..a74a2f633d 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -4,3 +4,5 @@ import { ImageDTO } from 'services/api/types'; export const initialImageSelected = createAction( 'generation/initialImageSelected' ); + +export const modelSelected = createAction('generation/modelSelected'); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 4f93cf43e5..83262d3aa8 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -223,6 +223,9 @@ export const generationSlice = createSlice({ state.model = { id: action.payload, base_model, name, type }; }, + modelChanged: (state, action: PayloadAction) => { + state.model = action.payload; + }, vaeSelected: (state, action: PayloadAction) => { state.vae = action.payload; }, @@ -282,7 +285,7 @@ export const { setHorizontalSymmetrySteps, setVerticalSymmetrySteps, initialImageChanged, - modelSelected, + modelChanged, vaeSelected, setShouldUseNoiseSettings, setSeamlessXAxis, diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 4336792858..0e49cae1df 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -3,12 +3,12 @@ import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { modelSelected } from 'features/parameters/store/generationSlice'; import { SelectItem } from '@mantine/core'; import { RootState } from 'app/store/store'; import { forEach, isString } from 'lodash-es'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { modelSelected } from '../../parameters/store/actions'; export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x',