wip: Add 2.x Models to the Model List

This commit is contained in:
blessedcoolant 2023-06-18 07:01:44 +12:00 committed by psychedelicious
parent e374211313
commit f8d7477c7a
13 changed files with 228 additions and 130 deletions

View File

@ -18,7 +18,8 @@ const serializationDenylist: {
gallery: galleryPersistDenylist, gallery: galleryPersistDenylist,
generation: generationPersistDenylist, generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist, lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist, sd1models: modelsPersistDenylist,
sd2models: modelsPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,

View File

@ -7,7 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice'; import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice'; import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice';
import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice';
import { initialSystemState } from 'features/system/store/systemSlice'; import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice'; import { initialUIState } from 'features/ui/store/uiSlice';
@ -21,7 +22,8 @@ const initialStates: {
gallery: initialGalleryState, gallery: initialGalleryState,
generation: initialGenerationState, generation: initialGenerationState,
lightbox: initialLightboxState, lightbox: initialLightboxState,
models: initialModelsState, sd1models: sd1InitialModelsState,
sd2models: sd2InitialModelsState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
system: initialSystemState, system: initialSystemState,

View File

@ -1,9 +1,9 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image'; import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model'; import { getModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
@ -15,7 +15,7 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { models, nodes, config, images } = getState(); const { sd1models, sd2models, nodes, config, images } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
@ -28,8 +28,12 @@ export const addSocketConnectedEventListener = () => {
); );
} }
if (!models.ids.length) { if (!sd1models.ids.length) {
dispatch(receivedModels()); dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
}
if (!sd2models.ids.length) {
dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
} }
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {

View File

@ -5,34 +5,37 @@ import {
configureStore, configureStore,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice'; import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice'; // import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice'; import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
// Model Reducers
import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice';
import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants';
import { api } from 'services/apiSlice'; import { api } from 'services/apiSlice';
const allReducers = { const allReducers = {
@ -40,7 +43,8 @@ const allReducers = {
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
lightbox: lightboxReducer, lightbox: lightboxReducer,
models: modelsReducer, sd1models: sd1ModelReducer,
sd2models: sd2ModelReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
system: systemReducer, system: systemReducer,
@ -63,7 +67,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'gallery', 'gallery',
'generation', 'generation',
'lightbox', 'lightbox',
// 'models',
'nodes', 'nodes',
'postprocessing', 'postprocessing',
'system', 'system',

View File

@ -1,29 +1,14 @@
import { Select } from '@chakra-ui/react'; import { NativeSelect } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
ModelInputFieldTemplate, ModelInputFieldTemplate,
ModelInputFieldValue, ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { selectModelsIds } from 'features/system/store/modelSlice'; import { modelSelector } from 'features/system/components/ModelSelect';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo } from 'react'; import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
[selectModelsIds],
(allModelNames) => {
return { allModelNames };
// return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
) => { ) => {
@ -31,7 +16,7 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { allModelNames } = useAppSelector(availableModelsSelector); const { sd1ModelData, sd2ModelData } = useAppSelector(modelSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => { const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch( dispatch(
@ -44,14 +29,11 @@ const ModelInputFieldComponent = (
}; };
return ( return (
<Select <NativeSelect
onChange={handleValueChanged} onChange={handleValueChanged}
value={field.value || allModelNames[0]} value={field.value || sd1ModelData[0].value}
> data={sd1ModelData.concat(sd2ModelData)}
{allModelNames.map((option) => ( ></NativeSelect>
<option key={option}>{option}</option>
))}
</Select>
); );
}; };

View File

@ -1,10 +1,11 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es'; import { clamp, sortBy } from 'lodash-es';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { imageUrlsReceived } from 'services/thunks/image'; import { imageUrlsReceived } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model'; import { getModels } from 'services/thunks/model';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
@ -17,7 +18,6 @@ import {
StrengthParam, StrengthParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState { export interface GenerationState {
cfgScale: CfgScaleParam; cfgScale: CfgScaleParam;
@ -220,7 +220,7 @@ export const generationSlice = createSlice({
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => { builder.addCase(getModels.fulfilled, (state, action) => {
if (!state.model) { if (!state.model) {
const firstModel = sortBy(action.payload, 'name')[0]; const firstModel = sortBy(action.payload, 'name')[0];
state.model = firstModel.name; state.model = firstModel.name;

View File

@ -10,22 +10,42 @@ import IAIMantineSelect, {
} from 'common/components/IAIMantineSelect'; } from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice'; import {
selectAllSD1Models,
selectByIdSD1Models,
} from '../store/models/sd1ModelSlice';
import {
selectAllSD2Models,
selectByIdSD2Models,
} from '../store/models/sd2ModelSlice';
const selector = createSelector( export const modelSelector = createSelector(
[(state: RootState) => state, generationSelector], [(state: RootState) => state, generationSelector],
(state, generation) => { (state, generation) => {
const selectedModel = selectModelsById(state, generation.model); let selectedModel = selectByIdSD1Models(state, generation.model);
if (selectedModel === undefined)
selectedModel = selectByIdSD2Models(state, generation.model);
const modelData = selectModelsAll(state) const sd1ModelData = selectAllSD1Models(state)
.map<IAISelectDataType>((m) => ({ .map<IAISelectDataType>((m) => ({
value: m.name, value: m.name,
label: m.name, label: m.name,
group: '1.x Models',
})) }))
.sort((a, b) => a.label.localeCompare(b.label)); .sort((a, b) => a.label.localeCompare(b.label));
const sd2ModelData = selectAllSD2Models(state)
.map<IAISelectDataType>((m) => ({
value: m.name,
label: m.name,
group: '2.x Models',
}))
.sort((a, b) => a.label.localeCompare(b.label));
return { return {
selectedModel, selectedModel,
modelData, sd1ModelData,
sd2ModelData,
}; };
}, },
{ {
@ -38,7 +58,9 @@ const selector = createSelector(
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector); const { selectedModel, sd1ModelData, sd2ModelData } =
useAppSelector(modelSelector);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
@ -55,7 +77,7 @@ const ModelSelect = () => {
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.name ?? ''} value={selectedModel?.name ?? ''}
placeholder="Pick one" placeholder="Pick one"
data={modelData} data={sd1ModelData.concat(sd2ModelData)}
onChange={handleChangeModel} onChange={handleChangeModel}
/> />
); );

View File

@ -1,47 +0,0 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
name: string;
};
export const modelsAdapter = createEntityAdapter<Model>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const initialModelsState = modelsAdapter.getInitialState();
export type ModelsState = typeof initialModelsState;
export const modelsSlice = createSlice({
name: 'models',
initialState: initialModelsState,
reducers: {
modelAdded: modelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
const models = action.payload;
modelsAdapter.setAll(state, models);
});
},
});
export const {
selectAll: selectModelsAll,
selectById: selectModelsById,
selectEntities: selectModelsEntities,
selectIds: selectModelsIds,
selectTotal: selectModelsTotal,
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
export const { modelAdded } = modelsSlice.actions;
export default modelsSlice.reducer;

View File

@ -0,0 +1,53 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion1ModelCheckpointConfig,
StableDiffusion1ModelDiffusersConfig,
} from 'services/api';
import { getModels } from 'services/thunks/model';
export type SD1ModelType = (
| StableDiffusion1ModelCheckpointConfig
| StableDiffusion1ModelDiffusersConfig
) & {
name: string;
};
export const sd1ModelsAdapter = createEntityAdapter<SD1ModelType>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState();
export type SD1ModelState = typeof sd1InitialModelsState;
export const sd1ModelsSlice = createSlice({
name: 'sd1models',
initialState: sd1InitialModelsState,
reducers: {
modelAdded: sd1ModelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(getModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-1') return;
sd1ModelsAdapter.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD1Models,
selectById: selectByIdSD1Models,
selectEntities: selectEntitiesSD1Models,
selectIds: selectIdsSD1Models,
selectTotal: selectTotalSD1Models,
} = sd1ModelsAdapter.getSelectors<RootState>((state) => state.sd1models);
export const { modelAdded } = sd1ModelsSlice.actions;
export default sd1ModelsSlice.reducer;

View File

@ -0,0 +1,53 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion2ModelCheckpointConfig,
StableDiffusion2ModelDiffusersConfig,
} from 'services/api';
import { getModels } from 'services/thunks/model';
export type SD2ModelType = (
| StableDiffusion2ModelCheckpointConfig
| StableDiffusion2ModelDiffusersConfig
) & {
name: string;
};
export const sd2ModelsAdapater = createEntityAdapter<SD2ModelType>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState();
export type SD2ModelState = typeof sd2InitialModelsState;
export const sd2ModelsSlice = createSlice({
name: 'sd2models',
initialState: sd2InitialModelsState,
reducers: {
modelAdded: sd2ModelsAdapater.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(getModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-2') return;
sd2ModelsAdapater.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD2Models,
selectById: selectByIdSD2Models,
selectEntities: selectEntitiesSD2Models,
selectIds: selectIdsSD2Models,
selectTotal: selectTotalSD2Models,
} = sd2ModelsAdapater.getSelectors<RootState>((state) => state.sd2models);
export const { modelAdded } = sd2ModelsSlice.actions;
export default sd2ModelsSlice.reducer;

View File

@ -1,6 +1,9 @@
import { ModelsState } from './modelSlice'; import { SD1ModelState } from './models/sd1ModelSlice';
import { SD2ModelState } from './models/sd2ModelSlice';
/** /**
* Models slice persist denylist * Models slice persist denylist
*/ */
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids']; export const modelsPersistDenylist:
| (keyof SD1ModelState)[]
| (keyof SD2ModelState)[] = ['entities', 'ids'];

View File

@ -1,20 +1,12 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { receivedModels } from 'services/thunks/model';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { LogLevelName } from 'roarr';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { TFuncKey } from 'i18next';
import { t } from 'i18next';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { LANGUAGES } from '../components/LanguagePicker'; import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { imageUploaded } from 'services/thunks/image'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr';
import { import {
appSocketConnected, appSocketConnected,
appSocketDisconnected, appSocketDisconnected,
@ -26,6 +18,12 @@ import {
appSocketSubscribed, appSocketSubscribed,
appSocketUnsubscribed, appSocketUnsubscribed,
} from 'services/events/actions'; } from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { imageUploaded } from 'services/thunks/image';
import { getModels } from 'services/thunks/model';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -382,7 +380,7 @@ export const systemSlice = createSlice({
/** /**
* Received available models from the backend * Received available models from the backend
*/ */
builder.addCase(receivedModels.fulfilled, (state) => { builder.addCase(getModels.fulfilled, (state) => {
state.wereModelsReceived = true; state.wereModelsReceived = true;
}); });

View File

@ -1,31 +1,55 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { createAppAsyncThunk } from 'app/store/storeUtils'; import { createAppAsyncThunk } from 'app/store/storeUtils';
import { Model } from 'features/system/store/modelSlice'; import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice';
import { reduce, size } from 'lodash-es'; import { reduce, size } from 'lodash-es';
import { ModelsService } from 'services/api'; import { BaseModelType, ModelType, ModelsService } from 'services/api';
const models = log.child({ namespace: 'model' }); const models = log.child({ namespace: 'model' });
export const IMAGES_PER_PAGE = 20; export const IMAGES_PER_PAGE = 20;
export const receivedModels = createAppAsyncThunk( type getModelsArg = {
'models/receivedModels', baseModel: BaseModelType | undefined;
async (_) => { modelType: ModelType | undefined;
const response = await ModelsService.listModels(); };
const deserializedModels = reduce( export const getModels = createAppAsyncThunk(
response.models['sd-1']['pipeline'], 'models/getModels',
async (arg: getModelsArg) => {
const response = await ModelsService.listModels(arg);
let deserializedModels = {};
if (arg.baseModel === undefined) return response.models;
if (arg.modelType === undefined) return response.models;
if (arg.baseModel === 'sd-1') {
deserializedModels = reduce(
response.models[arg.baseModel][arg.modelType],
(modelsAccumulator, model, modelName) => { (modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName }; modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator; return modelsAccumulator;
}, },
{} as Record<string, Model> {} as Record<string, SD1ModelType>
); );
}
if (arg.baseModel === 'sd-2') {
deserializedModels = reduce(
response.models[arg.baseModel][arg.modelType],
(modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator;
},
{} as Record<string, SD1ModelType>
);
}
models.info( models.info(
{ response }, { response },
`Received ${size(response.models['sd-1']['pipeline'])} models` `Received ${size(response.models[arg.baseModel][arg.modelType])} ${[
arg.baseModel,
]} models`
); );
return deserializedModels; return deserializedModels;