diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 93cc19f832..dc1c25c015 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,8 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice'; -import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice'; +import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice'; +import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -22,8 +22,8 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - sd1models: sd1InitialModelsState, - sd2models: sd2InitialModelsState, + sd1pipelinemodels: sd1InitialPipelineModelsState, + sd2pipelinemodels: sd2InitialPipelineModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 2ce1ba45e6..14263b643b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -15,7 +15,8 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { sd1models, sd2models, nodes, config, images } = getState(); + const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } = + getState(); const { disabledTabs } = config; @@ -23,11 +24,11 @@ export const addSocketConnectedEventListener = () => { dispatch(receivedPageOfImages()); } - if (!sd1models.ids.length) { + if (!sd1pipelinemodels.ids.length) { dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' })); } - if (!sd2models.ids.length) { + if (!sd2pipelinemodels.ids.length) { dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' })); } diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 7ff3fb8dc5..06aa6d3535 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -29,8 +29,8 @@ import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { stateSanitizer } from './middleware/devtools/stateSanitizer'; // Model Reducers -import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice'; -import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice'; +import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice'; +import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; @@ -41,8 +41,8 @@ const allReducers = { gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - sd1models: sd1ModelReducer, - sd2models: sd2ModelReducer, + sd1pipelinemodels: sd1PipelineModelReducer, + sd2pipelinemodels: sd2PipelineModelReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 3842e8da3a..480c8591bb 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -17,7 +17,7 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); - const { sd1ModelDropDownData, sd2ModelDropdownData } = + const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } = useAppSelector(modelSelector); const handleValueChanged = (e: ChangeEvent) => { @@ -33,8 +33,8 @@ const ModelInputFieldComponent = ( return ( ); }; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 43de144991..813bd9fb70 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -20,8 +20,11 @@ const MODEL_LOADER_MAP = { const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { selectedModel, sd1ModelDropDownData, sd2ModelDropdownData } = - useAppSelector(modelSelector); + const { + selectedModel, + sd1PipelineModelDropDownData, + sd2PipelineModelDropdownData, + } = useAppSelector(modelSelector); useEffect(() => { if (selectedModel) @@ -48,7 +51,7 @@ const ModelSelect = () => { label={t('modelManager.model')} value={selectedModel?.name ?? ''} placeholder="Pick one" - data={sd1ModelDropDownData.concat(sd2ModelDropdownData)} + data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)} onChange={handleChangeModel} /> ); diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts index 6e101da5f5..b63c6d256c 100644 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts @@ -3,26 +3,30 @@ import { RootState } from 'app/store/store'; import { IAISelectDataType } from 'common/components/IAIMantineSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { isEqual } from 'lodash-es'; + import { - selectAllSD1Models, - selectByIdSD1Models, -} from './models/sd1ModelSlice'; + selectAllSD1PipelineModels, + selectByIdSD1PipelineModels, +} from './models/sd1PipelineModelSlice'; + import { - selectAllSD2Models, - selectByIdSD2Models, -} from './models/sd2ModelSlice'; + selectAllSD2PipelineModels, + selectByIdSD2PipelineModels, +} from './models/sd2PipelineModelSlice'; export const modelSelector = createSelector( [(state: RootState) => state, generationSelector], (state, generation) => { - let selectedModel = selectByIdSD1Models(state, generation.model); + let selectedModel = selectByIdSD1PipelineModels(state, generation.model); if (selectedModel === undefined) - selectedModel = selectByIdSD2Models(state, generation.model); + selectedModel = selectByIdSD2PipelineModels(state, generation.model); - const sd1Models = selectAllSD1Models(state); - const sd2Models = selectAllSD2Models(state); + const sd1PipelineModels = selectAllSD1PipelineModels(state); + const sd2PipelineModels = selectAllSD2PipelineModels(state); - const sd1ModelDropDownData = selectAllSD1Models(state) + const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels); + + const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state) .map((m) => ({ value: m.name, label: m.name, @@ -30,7 +34,7 @@ export const modelSelector = createSelector( })) .sort((a, b) => a.label.localeCompare(b.label)); - const sd2ModelDropdownData = selectAllSD2Models(state) + const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state) .map((m) => ({ value: m.name, label: m.name, @@ -40,10 +44,11 @@ export const modelSelector = createSelector( return { selectedModel, - sd1Models, - sd2Models, - sd1ModelDropDownData, - sd2ModelDropdownData, + allPipelineModels, + sd1PipelineModels, + sd2PipelineModels, + sd1PipelineModelDropDownData, + sd2PipelineModelDropdownData, }; }, { diff --git a/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts deleted file mode 100644 index 9f62fde264..0000000000 --- a/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - StableDiffusion1ModelCheckpointConfig, - StableDiffusion1ModelDiffusersConfig, -} from 'services/api'; - -import { getModels } from 'services/thunks/model'; - -export type SD1ModelType = ( - | StableDiffusion1ModelCheckpointConfig - | StableDiffusion1ModelDiffusersConfig -) & { - name: string; -}; - -export const sd1ModelsAdapter = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState(); - -export type SD1ModelState = typeof sd1InitialModelsState; - -export const sd1ModelsSlice = createSlice({ - name: 'sd1models', - initialState: sd1InitialModelsState, - reducers: { - modelAdded: sd1ModelsAdapter.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(getModels.fulfilled, (state, action) => { - if (action.meta.arg.baseModel !== 'sd-1') return; - sd1ModelsAdapter.setAll(state, action.payload); - }); - }, -}); - -export const { - selectAll: selectAllSD1Models, - selectById: selectByIdSD1Models, - selectEntities: selectEntitiesSD1Models, - selectIds: selectIdsSD1Models, - selectTotal: selectTotalSD1Models, -} = sd1ModelsAdapter.getSelectors((state) => state.sd1models); - -export const { modelAdded } = sd1ModelsSlice.actions; - -export default sd1ModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts new file mode 100644 index 0000000000..5755b14886 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts @@ -0,0 +1,57 @@ +import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { + StableDiffusion1ModelCheckpointConfig, + StableDiffusion1ModelDiffusersConfig, +} from 'services/api'; + +import { getModels } from 'services/thunks/model'; + +export type SD1PipelineModelType = ( + | StableDiffusion1ModelCheckpointConfig + | StableDiffusion1ModelDiffusersConfig +) & { + name: string; +}; + +export const sd1PipelineModelsAdapter = + createEntityAdapter({ + selectId: (model) => model.name, + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); + +export const sd1InitialPipelineModelsState = + sd1PipelineModelsAdapter.getInitialState(); + +export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState; + +export const sd1PipelineModelsSlice = createSlice({ + name: 'sd1models', + initialState: sd1InitialPipelineModelsState, + reducers: { + modelAdded: sd1PipelineModelsAdapter.upsertOne, + }, + extraReducers(builder) { + /** + * Received Models - FULFILLED + */ + builder.addCase(getModels.fulfilled, (state, action) => { + if (action.meta.arg.baseModel !== 'sd-1') return; + sd1PipelineModelsAdapter.setAll(state, action.payload); + }); + }, +}); + +export const { + selectAll: selectAllSD1PipelineModels, + selectById: selectByIdSD1PipelineModels, + selectEntities: selectEntitiesSD1PipelineModels, + selectIds: selectIdsSD1PipelineModels, + selectTotal: selectTotalSD1PipelineModels, +} = sd1PipelineModelsAdapter.getSelectors( + (state) => state.sd1pipelinemodels +); + +export const { modelAdded } = sd1PipelineModelsSlice.actions; + +export default sd1PipelineModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts deleted file mode 100644 index e8e1f5bedf..0000000000 --- a/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - StableDiffusion2ModelCheckpointConfig, - StableDiffusion2ModelDiffusersConfig, -} from 'services/api'; - -import { getModels } from 'services/thunks/model'; - -export type SD2ModelType = ( - | StableDiffusion2ModelCheckpointConfig - | StableDiffusion2ModelDiffusersConfig -) & { - name: string; -}; - -export const sd2ModelsAdapater = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState(); - -export type SD2ModelState = typeof sd2InitialModelsState; - -export const sd2ModelsSlice = createSlice({ - name: 'sd2models', - initialState: sd2InitialModelsState, - reducers: { - modelAdded: sd2ModelsAdapater.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(getModels.fulfilled, (state, action) => { - if (action.meta.arg.baseModel !== 'sd-2') return; - sd2ModelsAdapater.setAll(state, action.payload); - }); - }, -}); - -export const { - selectAll: selectAllSD2Models, - selectById: selectByIdSD2Models, - selectEntities: selectEntitiesSD2Models, - selectIds: selectIdsSD2Models, - selectTotal: selectTotalSD2Models, -} = sd2ModelsAdapater.getSelectors((state) => state.sd2models); - -export const { modelAdded } = sd2ModelsSlice.actions; - -export default sd2ModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts new file mode 100644 index 0000000000..0c307e23cc --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts @@ -0,0 +1,57 @@ +import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { + StableDiffusion2ModelCheckpointConfig, + StableDiffusion2ModelDiffusersConfig, +} from 'services/api'; + +import { getModels } from 'services/thunks/model'; + +export type SD2PipelineModelType = ( + | StableDiffusion2ModelCheckpointConfig + | StableDiffusion2ModelDiffusersConfig +) & { + name: string; +}; + +export const sd2PipelineModelsAdapater = + createEntityAdapter({ + selectId: (model) => model.name, + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); + +export const sd2InitialPipelineModelsState = + sd2PipelineModelsAdapater.getInitialState(); + +export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState; + +export const sd2PipelineModelsSlice = createSlice({ + name: 'sd2models', + initialState: sd2InitialPipelineModelsState, + reducers: { + modelAdded: sd2PipelineModelsAdapater.upsertOne, + }, + extraReducers(builder) { + /** + * Received Models - FULFILLED + */ + builder.addCase(getModels.fulfilled, (state, action) => { + if (action.meta.arg.baseModel !== 'sd-2') return; + sd2PipelineModelsAdapater.setAll(state, action.payload); + }); + }, +}); + +export const { + selectAll: selectAllSD2PipelineModels, + selectById: selectByIdSD2PipelineModels, + selectEntities: selectEntitiesSD2PipelineModels, + selectIds: selectIdsSD2PipelineModels, + selectTotal: selectTotalSD2PipelineModels, +} = sd2PipelineModelsAdapater.getSelectors( + (state) => state.sd2pipelinemodels +); + +export const { modelAdded } = sd2PipelineModelsSlice.actions; + +export default sd2PipelineModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts index 7b0d78d37e..417a399cf2 100644 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts @@ -1,9 +1,9 @@ -import { SD1ModelState } from './models/sd1ModelSlice'; -import { SD2ModelState } from './models/sd2ModelSlice'; +import { SD1PipelineModelState } from './models/sd1PipelineModelSlice'; +import { SD2PipelineModelState } from './models/sd2PipelineModelSlice'; /** * Models slice persist denylist */ export const modelsPersistDenylist: - | (keyof SD1ModelState)[] - | (keyof SD2ModelState)[] = ['entities', 'ids']; + | (keyof SD1PipelineModelState)[] + | (keyof SD2PipelineModelState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts index 4d134439f7..039748fa3f 100644 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ b/invokeai/frontend/web/src/services/thunks/model.ts @@ -1,6 +1,7 @@ import { log } from 'app/logging/useLogger'; import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice'; +import { SD1PipelineModelType } from 'features/system/store/models/sd1PipelineModelSlice'; +import { SD2PipelineModelType } from 'features/system/store/models/sd2PipelineModelSlice'; import { reduce, size } from 'lodash-es'; import { BaseModelType, ModelType, ModelsService } from 'services/api'; @@ -30,7 +31,7 @@ export const getModels = createAppAsyncThunk( modelsAccumulator[modelName] = { ...model, name: modelName }; return modelsAccumulator; }, - {} as Record + {} as Record ); } @@ -41,7 +42,7 @@ export const getModels = createAppAsyncThunk( modelsAccumulator[modelName] = { ...model, name: modelName }; return modelsAccumulator; }, - {} as Record + {} as Record ); }