diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index 5025ca081a..e498ecb749 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -18,7 +18,8 @@ const serializationDenylist: { gallery: galleryPersistDenylist, generation: generationPersistDenylist, lightbox: lightboxPersistDenylist, - models: modelsPersistDenylist, + sd1models: modelsPersistDenylist, + sd2models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index c6af5f3612..93cc19f832 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,7 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { initialModelsState } from 'features/system/store/modelSlice'; +import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice'; +import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -21,7 +22,8 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - models: initialModelsState, + sd1models: sd1InitialModelsState, + sd2models: sd2InitialModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 3049d2c933..2ce1ba45e6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,9 +1,9 @@ -import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; import { receivedPageOfImages } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; +import { getModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; +import { startAppListening } from '../..'; const moduleLog = log.child({ namespace: 'socketio' }); @@ -15,7 +15,7 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { models, nodes, config, images } = getState(); + const { sd1models, sd2models, nodes, config, images } = getState(); const { disabledTabs } = config; @@ -23,8 +23,12 @@ export const addSocketConnectedEventListener = () => { dispatch(receivedPageOfImages()); } - if (!models.ids.length) { - dispatch(receivedModels()); + if (!sd1models.ids.length) { + dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' })); + } + + if (!sd2models.ids.length) { + dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' })); } if (!nodes.schema && !disabledTabs.includes('nodes')) { diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index f577b73895..7ff3fb8dc5 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -5,40 +5,44 @@ import { configureStore, } from '@reduxjs/toolkit'; -import { rememberReducer, rememberEnhancer } from 'redux-remember'; import dynamicMiddlewares from 'redux-dynamic-middlewares'; +import { rememberEnhancer, rememberReducer } from 'redux-remember'; import canvasReducer from 'features/canvas/store/canvasSlice'; +import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; import imagesReducer from 'features/gallery/store/imagesSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; -import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; // import sessionReducer from 'features/system/store/sessionSlice'; -import configReducer from 'features/system/store/configSlice'; -import uiReducer from 'features/ui/store/uiSlice'; -import hotkeysReducer from 'features/ui/store/hotkeysSlice'; -import modelsReducer from 'features/system/store/modelSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; +import configReducer from 'features/system/store/configSlice'; +import hotkeysReducer from 'features/ui/store/hotkeysSlice'; +import uiReducer from 'features/ui/store/uiSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +// Model Reducers +import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice'; +import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice'; + +import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { LOCALSTORAGE_PREFIX } from './constants'; const allReducers = { canvas: canvasReducer, gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - models: modelsReducer, + sd1models: sd1ModelReducer, + sd2models: sd2ModelReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, @@ -59,7 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'gallery', 'generation', 'lightbox', - // 'models', 'nodes', 'postprocessing', 'system', diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index a1ef69de01..d3d37765f4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -1,29 +1,14 @@ -import { Select } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; +import { NativeSelect } from '@mantine/core'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { ModelInputFieldTemplate, ModelInputFieldValue, } from 'features/nodes/types/types'; -import { selectModelsIds } from 'features/system/store/modelSlice'; -import { isEqual } from 'lodash-es'; +import { modelSelector } from 'features/system/components/ModelSelect'; import { ChangeEvent, memo } from 'react'; import { FieldComponentProps } from './types'; -const availableModelsSelector = createSelector( - [selectModelsIds], - (allModelNames) => { - return { allModelNames }; - // return map(modelList, (_, name) => name); - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - const ModelInputFieldComponent = ( props: FieldComponentProps ) => { @@ -31,7 +16,7 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); - const { allModelNames } = useAppSelector(availableModelsSelector); + const { sd1ModelData, sd2ModelData } = useAppSelector(modelSelector); const handleValueChanged = (e: ChangeEvent) => { dispatch( @@ -44,14 +29,11 @@ const ModelInputFieldComponent = ( }; return ( - + value={field.value || sd1ModelData[0].value} + data={sd1ModelData.concat(sd2ModelData)} + > ); }; diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index f516229efe..cdf26470d2 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,10 +1,11 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { clamp, sortBy } from 'lodash-es'; -import { receivedModels } from 'services/thunks/model'; import { Scheduler } from 'app/constants'; -import { ImageDTO } from 'services/api'; import { configChanged } from 'features/system/store/configSlice'; +import { clamp, sortBy } from 'lodash-es'; +import { ImageDTO } from 'services/api'; +import { imageUrlsReceived } from 'services/thunks/image'; +import { getModels } from 'services/thunks/model'; import { CfgScaleParam, HeightParam, @@ -17,7 +18,6 @@ import { StrengthParam, WidthParam, } from './parameterZodSchemas'; -import { imageUrlsReceived } from 'services/thunks/image'; export interface GenerationState { cfgScale: CfgScaleParam; @@ -219,7 +219,7 @@ export const generationSlice = createSlice({ }, }, extraReducers: (builder) => { - builder.addCase(receivedModels.fulfilled, (state, action) => { + builder.addCase(getModels.fulfilled, (state, action) => { if (!state.model) { const firstModel = sortBy(action.payload, 'name')[0]; state.model = firstModel.name; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index a38ab150dd..a65c8501dc 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -10,22 +10,42 @@ import IAIMantineSelect, { } from 'common/components/IAIMantineSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { modelSelected } from 'features/parameters/store/generationSlice'; -import { selectModelsAll, selectModelsById } from '../store/modelSlice'; +import { + selectAllSD1Models, + selectByIdSD1Models, +} from '../store/models/sd1ModelSlice'; +import { + selectAllSD2Models, + selectByIdSD2Models, +} from '../store/models/sd2ModelSlice'; -const selector = createSelector( +export const modelSelector = createSelector( [(state: RootState) => state, generationSelector], (state, generation) => { - const selectedModel = selectModelsById(state, generation.model); + let selectedModel = selectByIdSD1Models(state, generation.model); + if (selectedModel === undefined) + selectedModel = selectByIdSD2Models(state, generation.model); - const modelData = selectModelsAll(state) + const sd1ModelData = selectAllSD1Models(state) .map((m) => ({ value: m.name, label: m.name, + group: '1.x Models', })) .sort((a, b) => a.label.localeCompare(b.label)); + + const sd2ModelData = selectAllSD2Models(state) + .map((m) => ({ + value: m.name, + label: m.name, + group: '2.x Models', + })) + .sort((a, b) => a.label.localeCompare(b.label)); + return { selectedModel, - modelData, + sd1ModelData, + sd2ModelData, }; }, { @@ -38,7 +58,9 @@ const selector = createSelector( const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { selectedModel, modelData } = useAppSelector(selector); + const { selectedModel, sd1ModelData, sd2ModelData } = + useAppSelector(modelSelector); + const handleChangeModel = useCallback( (v: string | null) => { if (!v) { @@ -55,7 +77,7 @@ const ModelSelect = () => { label={t('modelManager.model')} value={selectedModel?.name ?? ''} placeholder="Pick one" - data={modelData} + data={sd1ModelData.concat(sd2ModelData)} onChange={handleChangeModel} /> ); diff --git a/invokeai/frontend/web/src/features/system/store/modelSlice.ts b/invokeai/frontend/web/src/features/system/store/modelSlice.ts deleted file mode 100644 index ed38425872..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSlice.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { createEntityAdapter } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { CkptModelInfo, DiffusersModelInfo } from 'services/api'; -import { receivedModels } from 'services/thunks/model'; - -export type Model = (CkptModelInfo | DiffusersModelInfo) & { - name: string; -}; - -export const modelsAdapter = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const initialModelsState = modelsAdapter.getInitialState(); - -export type ModelsState = typeof initialModelsState; - -export const modelsSlice = createSlice({ - name: 'models', - initialState: initialModelsState, - reducers: { - modelAdded: modelsAdapter.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(receivedModels.fulfilled, (state, action) => { - const models = action.payload; - modelsAdapter.setAll(state, models); - }); - }, -}); - -export const { - selectAll: selectModelsAll, - selectById: selectModelsById, - selectEntities: selectModelsEntities, - selectIds: selectModelsIds, - selectTotal: selectModelsTotal, -} = modelsAdapter.getSelectors((state) => state.models); - -export const { modelAdded } = modelsSlice.actions; - -export default modelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts new file mode 100644 index 0000000000..9f62fde264 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/models/sd1ModelSlice.ts @@ -0,0 +1,53 @@ +import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { + StableDiffusion1ModelCheckpointConfig, + StableDiffusion1ModelDiffusersConfig, +} from 'services/api'; + +import { getModels } from 'services/thunks/model'; + +export type SD1ModelType = ( + | StableDiffusion1ModelCheckpointConfig + | StableDiffusion1ModelDiffusersConfig +) & { + name: string; +}; + +export const sd1ModelsAdapter = createEntityAdapter({ + selectId: (model) => model.name, + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); + +export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState(); + +export type SD1ModelState = typeof sd1InitialModelsState; + +export const sd1ModelsSlice = createSlice({ + name: 'sd1models', + initialState: sd1InitialModelsState, + reducers: { + modelAdded: sd1ModelsAdapter.upsertOne, + }, + extraReducers(builder) { + /** + * Received Models - FULFILLED + */ + builder.addCase(getModels.fulfilled, (state, action) => { + if (action.meta.arg.baseModel !== 'sd-1') return; + sd1ModelsAdapter.setAll(state, action.payload); + }); + }, +}); + +export const { + selectAll: selectAllSD1Models, + selectById: selectByIdSD1Models, + selectEntities: selectEntitiesSD1Models, + selectIds: selectIdsSD1Models, + selectTotal: selectTotalSD1Models, +} = sd1ModelsAdapter.getSelectors((state) => state.sd1models); + +export const { modelAdded } = sd1ModelsSlice.actions; + +export default sd1ModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts new file mode 100644 index 0000000000..e8e1f5bedf --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/models/sd2ModelSlice.ts @@ -0,0 +1,53 @@ +import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { + StableDiffusion2ModelCheckpointConfig, + StableDiffusion2ModelDiffusersConfig, +} from 'services/api'; + +import { getModels } from 'services/thunks/model'; + +export type SD2ModelType = ( + | StableDiffusion2ModelCheckpointConfig + | StableDiffusion2ModelDiffusersConfig +) & { + name: string; +}; + +export const sd2ModelsAdapater = createEntityAdapter({ + selectId: (model) => model.name, + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); + +export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState(); + +export type SD2ModelState = typeof sd2InitialModelsState; + +export const sd2ModelsSlice = createSlice({ + name: 'sd2models', + initialState: sd2InitialModelsState, + reducers: { + modelAdded: sd2ModelsAdapater.upsertOne, + }, + extraReducers(builder) { + /** + * Received Models - FULFILLED + */ + builder.addCase(getModels.fulfilled, (state, action) => { + if (action.meta.arg.baseModel !== 'sd-2') return; + sd2ModelsAdapater.setAll(state, action.payload); + }); + }, +}); + +export const { + selectAll: selectAllSD2Models, + selectById: selectByIdSD2Models, + selectEntities: selectEntitiesSD2Models, + selectIds: selectIdsSD2Models, + selectTotal: selectTotalSD2Models, +} = sd2ModelsAdapater.getSelectors((state) => state.sd2models); + +export const { modelAdded } = sd2ModelsSlice.actions; + +export default sd2ModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts index aa9fb057e1..7b0d78d37e 100644 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts @@ -1,6 +1,9 @@ -import { ModelsState } from './modelSlice'; +import { SD1ModelState } from './models/sd1ModelSlice'; +import { SD2ModelState } from './models/sd2ModelSlice'; /** * Models slice persist denylist */ -export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids']; +export const modelsPersistDenylist: + | (keyof SD1ModelState)[] + | (keyof SD2ModelState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index b17f497f6c..7085099c31 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,20 +1,12 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import * as InvokeAI from 'app/types/invokeai'; -import { ProgressImage } from 'services/events/types'; -import { makeToast } from '../../../app/components/Toaster'; -import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; -import { receivedModels } from 'services/thunks/model'; -import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; -import { LogLevelName } from 'roarr'; import { InvokeLogLevel } from 'app/logging/useLogger'; -import { TFuncKey } from 'i18next'; -import { t } from 'i18next'; import { userInvoked } from 'app/store/actions'; -import { LANGUAGES } from '../components/LanguagePicker'; -import { imageUploaded } from 'services/thunks/image'; +import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; +import { TFuncKey, t } from 'i18next'; +import { LogLevelName } from 'roarr'; import { appSocketConnected, appSocketDisconnected, @@ -26,6 +18,12 @@ import { appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; +import { ProgressImage } from 'services/events/types'; +import { imageUploaded } from 'services/thunks/image'; +import { getModels } from 'services/thunks/model'; +import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; +import { makeToast } from '../../../app/components/Toaster'; +import { LANGUAGES } from '../components/LanguagePicker'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -379,7 +377,7 @@ export const systemSlice = createSlice({ /** * Received available models from the backend */ - builder.addCase(receivedModels.fulfilled, (state) => { + builder.addCase(getModels.fulfilled, (state) => { state.wereModelsReceived = true; }); diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts index 97f2bd8016..4d134439f7 100644 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ b/invokeai/frontend/web/src/services/thunks/model.ts @@ -1,31 +1,55 @@ import { log } from 'app/logging/useLogger'; import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { Model } from 'features/system/store/modelSlice'; +import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice'; import { reduce, size } from 'lodash-es'; -import { ModelsService } from 'services/api'; +import { BaseModelType, ModelType, ModelsService } from 'services/api'; const models = log.child({ namespace: 'model' }); export const IMAGES_PER_PAGE = 20; -export const receivedModels = createAppAsyncThunk( - 'models/receivedModels', - async (_) => { - const response = await ModelsService.listModels(); +type getModelsArg = { + baseModel: BaseModelType | undefined; + modelType: ModelType | undefined; +}; - const deserializedModels = reduce( - response.models['sd-1']['pipeline'], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; +export const getModels = createAppAsyncThunk( + 'models/getModels', + async (arg: getModelsArg) => { + const response = await ModelsService.listModels(arg); - return modelsAccumulator; - }, - {} as Record - ); + let deserializedModels = {}; + + if (arg.baseModel === undefined) return response.models; + if (arg.modelType === undefined) return response.models; + + if (arg.baseModel === 'sd-1') { + deserializedModels = reduce( + response.models[arg.baseModel][arg.modelType], + (modelsAccumulator, model, modelName) => { + modelsAccumulator[modelName] = { ...model, name: modelName }; + return modelsAccumulator; + }, + {} as Record + ); + } + + if (arg.baseModel === 'sd-2') { + deserializedModels = reduce( + response.models[arg.baseModel][arg.modelType], + (modelsAccumulator, model, modelName) => { + modelsAccumulator[modelName] = { ...model, name: modelName }; + return modelsAccumulator; + }, + {} as Record + ); + } models.info( { response }, - `Received ${size(response.models['sd-1']['pipeline'])} models` + `Received ${size(response.models[arg.baseModel][arg.modelType])} ${[ + arg.baseModel, + ]} models` ); return deserializedModels;