diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 39f38e386c..96a6070ad2 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,6 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { Scheduler } from 'app/constants'; +import { ModelLoaderTypes } from 'features/system/components/ModelSelect'; import { configChanged } from 'features/system/store/configSlice'; import { clamp, sortBy } from 'lodash-es'; import { ImageDTO } from 'services/api'; @@ -49,6 +50,7 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; + currentModelType: ModelLoaderTypes; shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; @@ -83,6 +85,7 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', + currentModelType: 'sd1_model_loader', shouldUseSeamless: false, seamlessXAxis: true, seamlessYAxis: true, @@ -218,6 +221,9 @@ export const generationSlice = createSlice({ modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, + setCurrentModelType: (state, action: PayloadAction) => { + state.currentModelType = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(getModels.fulfilled, (state, action) => { @@ -278,6 +284,7 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, + setCurrentModelType, setShouldUseNoiseSettings, setSeamless, setSeamlessXAxis, diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index a65c8501dc..bf0775d52e 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,6 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { isEqual } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; import { RootState } from 'app/store/store'; @@ -9,7 +9,11 @@ import IAIMantineSelect, { IAISelectDataType, } from 'common/components/IAIMantineSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { modelSelected } from 'features/parameters/store/generationSlice'; +import { + modelSelected, + setCurrentModelType, +} from 'features/parameters/store/generationSlice'; + import { selectAllSD1Models, selectByIdSD1Models, @@ -55,12 +59,28 @@ export const modelSelector = createSelector( } ); +export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader'; + +const MODEL_LOADER_MAP = { + 'sd-1': 'sd1_model_loader', + 'sd-2': 'sd2_model_loader', +}; + const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const { selectedModel, sd1ModelData, sd2ModelData } = useAppSelector(modelSelector); + useEffect(() => { + if (selectedModel) + dispatch( + setCurrentModelType( + MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes + ) + ); + }, [dispatch, selectedModel]); + const handleChangeModel = useCallback( (v: string | null) => { if (!v) {