mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: Add 2.x Models to the Model List
This commit is contained in:
parent
e374211313
commit
f8d7477c7a
@ -18,7 +18,8 @@ const serializationDenylist: {
|
|||||||
gallery: galleryPersistDenylist,
|
gallery: galleryPersistDenylist,
|
||||||
generation: generationPersistDenylist,
|
generation: generationPersistDenylist,
|
||||||
lightbox: lightboxPersistDenylist,
|
lightbox: lightboxPersistDenylist,
|
||||||
models: modelsPersistDenylist,
|
sd1models: modelsPersistDenylist,
|
||||||
|
sd2models: modelsPersistDenylist,
|
||||||
nodes: nodesPersistDenylist,
|
nodes: nodesPersistDenylist,
|
||||||
postprocessing: postprocessingPersistDenylist,
|
postprocessing: postprocessingPersistDenylist,
|
||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
|
@ -7,7 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
|||||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||||
import { initialConfigState } from 'features/system/store/configSlice';
|
import { initialConfigState } from 'features/system/store/configSlice';
|
||||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice';
|
||||||
|
import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice';
|
||||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||||
@ -21,7 +22,8 @@ const initialStates: {
|
|||||||
gallery: initialGalleryState,
|
gallery: initialGalleryState,
|
||||||
generation: initialGenerationState,
|
generation: initialGenerationState,
|
||||||
lightbox: initialLightboxState,
|
lightbox: initialLightboxState,
|
||||||
models: initialModelsState,
|
sd1models: sd1InitialModelsState,
|
||||||
|
sd2models: sd2InitialModelsState,
|
||||||
nodes: initialNodesState,
|
nodes: initialNodesState,
|
||||||
postprocessing: initialPostprocessingState,
|
postprocessing: initialPostprocessingState,
|
||||||
system: initialSystemState,
|
system: initialSystemState,
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import { startAppListening } from '../..';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { receivedPageOfImages } from 'services/thunks/image';
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
import { getModels } from 'services/thunks/model';
|
||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'socketio' });
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
|
|
||||||
moduleLog.debug({ timestamp }, 'Connected');
|
moduleLog.debug({ timestamp }, 'Connected');
|
||||||
|
|
||||||
const { models, nodes, config, images } = getState();
|
const { sd1models, sd2models, nodes, config, images } = getState();
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
@ -28,8 +28,12 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!models.ids.length) {
|
if (!sd1models.ids.length) {
|
||||||
dispatch(receivedModels());
|
dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sd2models.ids.length) {
|
||||||
|
dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
|
@ -5,34 +5,37 @@ import {
|
|||||||
configureStore,
|
configureStore,
|
||||||
} from '@reduxjs/toolkit';
|
} from '@reduxjs/toolkit';
|
||||||
|
|
||||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
|
||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
|
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||||
|
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
|
||||||
import modelsReducer from 'features/system/store/modelSlice';
|
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||||
|
import configReducer from 'features/system/store/configSlice';
|
||||||
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
|
|
||||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
|
|
||||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
|
||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
|
|
||||||
|
// Model Reducers
|
||||||
|
import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice';
|
||||||
|
import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice';
|
||||||
|
|
||||||
|
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
|
||||||
import { api } from 'services/apiSlice';
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
const allReducers = {
|
const allReducers = {
|
||||||
@ -40,7 +43,8 @@ const allReducers = {
|
|||||||
gallery: galleryReducer,
|
gallery: galleryReducer,
|
||||||
generation: generationReducer,
|
generation: generationReducer,
|
||||||
lightbox: lightboxReducer,
|
lightbox: lightboxReducer,
|
||||||
models: modelsReducer,
|
sd1models: sd1ModelReducer,
|
||||||
|
sd2models: sd2ModelReducer,
|
||||||
nodes: nodesReducer,
|
nodes: nodesReducer,
|
||||||
postprocessing: postprocessingReducer,
|
postprocessing: postprocessingReducer,
|
||||||
system: systemReducer,
|
system: systemReducer,
|
||||||
@ -63,7 +67,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'gallery',
|
'gallery',
|
||||||
'generation',
|
'generation',
|
||||||
'lightbox',
|
'lightbox',
|
||||||
// 'models',
|
|
||||||
'nodes',
|
'nodes',
|
||||||
'postprocessing',
|
'postprocessing',
|
||||||
'system',
|
'system',
|
||||||
|
@ -1,29 +1,14 @@
|
|||||||
import { Select } from '@chakra-ui/react';
|
import { NativeSelect } from '@mantine/core';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { selectModelsIds } from 'features/system/store/modelSlice';
|
import { modelSelector } from 'features/system/components/ModelSelect';
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { ChangeEvent, memo } from 'react';
|
import { ChangeEvent, memo } from 'react';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const availableModelsSelector = createSelector(
|
|
||||||
[selectModelsIds],
|
|
||||||
(allModelNames) => {
|
|
||||||
return { allModelNames };
|
|
||||||
// return map(modelList, (_, name) => name);
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||||
) => {
|
) => {
|
||||||
@ -31,7 +16,7 @@ const ModelInputFieldComponent = (
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
const { sd1ModelData, sd2ModelData } = useAppSelector(modelSelector);
|
||||||
|
|
||||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -44,14 +29,11 @@ const ModelInputFieldComponent = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Select
|
<NativeSelect
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value || allModelNames[0]}
|
value={field.value || sd1ModelData[0].value}
|
||||||
>
|
data={sd1ModelData.concat(sd2ModelData)}
|
||||||
{allModelNames.map((option) => (
|
></NativeSelect>
|
||||||
<option key={option}>{option}</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { Scheduler } from 'app/constants';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp, sortBy } from 'lodash-es';
|
import { clamp, sortBy } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
import { imageUrlsReceived } from 'services/thunks/image';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
import { getModels } from 'services/thunks/model';
|
||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
@ -17,7 +18,6 @@ import {
|
|||||||
StrengthParam,
|
StrengthParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: CfgScaleParam;
|
cfgScale: CfgScaleParam;
|
||||||
@ -220,7 +220,7 @@ export const generationSlice = createSlice({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||||
if (!state.model) {
|
if (!state.model) {
|
||||||
const firstModel = sortBy(action.payload, 'name')[0];
|
const firstModel = sortBy(action.payload, 'name')[0];
|
||||||
state.model = firstModel.name;
|
state.model = firstModel.name;
|
||||||
|
@ -10,22 +10,42 @@ import IAIMantineSelect, {
|
|||||||
} from 'common/components/IAIMantineSelect';
|
} from 'common/components/IAIMantineSelect';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
import {
|
||||||
|
selectAllSD1Models,
|
||||||
|
selectByIdSD1Models,
|
||||||
|
} from '../store/models/sd1ModelSlice';
|
||||||
|
import {
|
||||||
|
selectAllSD2Models,
|
||||||
|
selectByIdSD2Models,
|
||||||
|
} from '../store/models/sd2ModelSlice';
|
||||||
|
|
||||||
const selector = createSelector(
|
export const modelSelector = createSelector(
|
||||||
[(state: RootState) => state, generationSelector],
|
[(state: RootState) => state, generationSelector],
|
||||||
(state, generation) => {
|
(state, generation) => {
|
||||||
const selectedModel = selectModelsById(state, generation.model);
|
let selectedModel = selectByIdSD1Models(state, generation.model);
|
||||||
|
if (selectedModel === undefined)
|
||||||
|
selectedModel = selectByIdSD2Models(state, generation.model);
|
||||||
|
|
||||||
const modelData = selectModelsAll(state)
|
const sd1ModelData = selectAllSD1Models(state)
|
||||||
.map<IAISelectDataType>((m) => ({
|
.map<IAISelectDataType>((m) => ({
|
||||||
value: m.name,
|
value: m.name,
|
||||||
label: m.name,
|
label: m.name,
|
||||||
|
group: '1.x Models',
|
||||||
}))
|
}))
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
|
const sd2ModelData = selectAllSD2Models(state)
|
||||||
|
.map<IAISelectDataType>((m) => ({
|
||||||
|
value: m.name,
|
||||||
|
label: m.name,
|
||||||
|
group: '2.x Models',
|
||||||
|
}))
|
||||||
|
.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
return {
|
return {
|
||||||
selectedModel,
|
selectedModel,
|
||||||
modelData,
|
sd1ModelData,
|
||||||
|
sd2ModelData,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -38,7 +58,9 @@ const selector = createSelector(
|
|||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { selectedModel, modelData } = useAppSelector(selector);
|
const { selectedModel, sd1ModelData, sd2ModelData } =
|
||||||
|
useAppSelector(modelSelector);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
@ -55,7 +77,7 @@ const ModelSelect = () => {
|
|||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
value={selectedModel?.name ?? ''}
|
value={selectedModel?.name ?? ''}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={modelData}
|
data={sd1ModelData.concat(sd2ModelData)}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
|
|
||||||
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
|
|
||||||
name: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const modelsAdapter = createEntityAdapter<Model>({
|
|
||||||
selectId: (model) => model.name,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
|
|
||||||
export const initialModelsState = modelsAdapter.getInitialState();
|
|
||||||
|
|
||||||
export type ModelsState = typeof initialModelsState;
|
|
||||||
|
|
||||||
export const modelsSlice = createSlice({
|
|
||||||
name: 'models',
|
|
||||||
initialState: initialModelsState,
|
|
||||||
reducers: {
|
|
||||||
modelAdded: modelsAdapter.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers(builder) {
|
|
||||||
/**
|
|
||||||
* Received Models - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
const models = action.payload;
|
|
||||||
modelsAdapter.setAll(state, models);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectModelsAll,
|
|
||||||
selectById: selectModelsById,
|
|
||||||
selectEntities: selectModelsEntities,
|
|
||||||
selectIds: selectModelsIds,
|
|
||||||
selectTotal: selectModelsTotal,
|
|
||||||
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
|
|
||||||
|
|
||||||
export const { modelAdded } = modelsSlice.actions;
|
|
||||||
|
|
||||||
export default modelsSlice.reducer;
|
|
@ -0,0 +1,53 @@
|
|||||||
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import {
|
||||||
|
StableDiffusion1ModelCheckpointConfig,
|
||||||
|
StableDiffusion1ModelDiffusersConfig,
|
||||||
|
} from 'services/api';
|
||||||
|
|
||||||
|
import { getModels } from 'services/thunks/model';
|
||||||
|
|
||||||
|
export type SD1ModelType = (
|
||||||
|
| StableDiffusion1ModelCheckpointConfig
|
||||||
|
| StableDiffusion1ModelDiffusersConfig
|
||||||
|
) & {
|
||||||
|
name: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const sd1ModelsAdapter = createEntityAdapter<SD1ModelType>({
|
||||||
|
selectId: (model) => model.name,
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState();
|
||||||
|
|
||||||
|
export type SD1ModelState = typeof sd1InitialModelsState;
|
||||||
|
|
||||||
|
export const sd1ModelsSlice = createSlice({
|
||||||
|
name: 'sd1models',
|
||||||
|
initialState: sd1InitialModelsState,
|
||||||
|
reducers: {
|
||||||
|
modelAdded: sd1ModelsAdapter.upsertOne,
|
||||||
|
},
|
||||||
|
extraReducers(builder) {
|
||||||
|
/**
|
||||||
|
* Received Models - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||||
|
if (action.meta.arg.baseModel !== 'sd-1') return;
|
||||||
|
sd1ModelsAdapter.setAll(state, action.payload);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
selectAll: selectAllSD1Models,
|
||||||
|
selectById: selectByIdSD1Models,
|
||||||
|
selectEntities: selectEntitiesSD1Models,
|
||||||
|
selectIds: selectIdsSD1Models,
|
||||||
|
selectTotal: selectTotalSD1Models,
|
||||||
|
} = sd1ModelsAdapter.getSelectors<RootState>((state) => state.sd1models);
|
||||||
|
|
||||||
|
export const { modelAdded } = sd1ModelsSlice.actions;
|
||||||
|
|
||||||
|
export default sd1ModelsSlice.reducer;
|
@ -0,0 +1,53 @@
|
|||||||
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import {
|
||||||
|
StableDiffusion2ModelCheckpointConfig,
|
||||||
|
StableDiffusion2ModelDiffusersConfig,
|
||||||
|
} from 'services/api';
|
||||||
|
|
||||||
|
import { getModels } from 'services/thunks/model';
|
||||||
|
|
||||||
|
export type SD2ModelType = (
|
||||||
|
| StableDiffusion2ModelCheckpointConfig
|
||||||
|
| StableDiffusion2ModelDiffusersConfig
|
||||||
|
) & {
|
||||||
|
name: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const sd2ModelsAdapater = createEntityAdapter<SD2ModelType>({
|
||||||
|
selectId: (model) => model.name,
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState();
|
||||||
|
|
||||||
|
export type SD2ModelState = typeof sd2InitialModelsState;
|
||||||
|
|
||||||
|
export const sd2ModelsSlice = createSlice({
|
||||||
|
name: 'sd2models',
|
||||||
|
initialState: sd2InitialModelsState,
|
||||||
|
reducers: {
|
||||||
|
modelAdded: sd2ModelsAdapater.upsertOne,
|
||||||
|
},
|
||||||
|
extraReducers(builder) {
|
||||||
|
/**
|
||||||
|
* Received Models - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||||
|
if (action.meta.arg.baseModel !== 'sd-2') return;
|
||||||
|
sd2ModelsAdapater.setAll(state, action.payload);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
selectAll: selectAllSD2Models,
|
||||||
|
selectById: selectByIdSD2Models,
|
||||||
|
selectEntities: selectEntitiesSD2Models,
|
||||||
|
selectIds: selectIdsSD2Models,
|
||||||
|
selectTotal: selectTotalSD2Models,
|
||||||
|
} = sd2ModelsAdapater.getSelectors<RootState>((state) => state.sd2models);
|
||||||
|
|
||||||
|
export const { modelAdded } = sd2ModelsSlice.actions;
|
||||||
|
|
||||||
|
export default sd2ModelsSlice.reducer;
|
@ -1,6 +1,9 @@
|
|||||||
import { ModelsState } from './modelSlice';
|
import { SD1ModelState } from './models/sd1ModelSlice';
|
||||||
|
import { SD2ModelState } from './models/sd2ModelSlice';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Models slice persist denylist
|
* Models slice persist denylist
|
||||||
*/
|
*/
|
||||||
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];
|
export const modelsPersistDenylist:
|
||||||
|
| (keyof SD1ModelState)[]
|
||||||
|
| (keyof SD2ModelState)[] = ['entities', 'ids'];
|
||||||
|
@ -1,20 +1,12 @@
|
|||||||
import { UseToastOptions } from '@chakra-ui/react';
|
import { UseToastOptions } from '@chakra-ui/react';
|
||||||
import { PayloadAction } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
|
|
||||||
import { ProgressImage } from 'services/events/types';
|
|
||||||
import { makeToast } from '../../../app/components/Toaster';
|
|
||||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { LogLevelName } from 'roarr';
|
|
||||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||||
import { TFuncKey } from 'i18next';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { LANGUAGES } from '../components/LanguagePicker';
|
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { TFuncKey, t } from 'i18next';
|
||||||
|
import { LogLevelName } from 'roarr';
|
||||||
import {
|
import {
|
||||||
appSocketConnected,
|
appSocketConnected,
|
||||||
appSocketDisconnected,
|
appSocketDisconnected,
|
||||||
@ -26,6 +18,12 @@ import {
|
|||||||
appSocketSubscribed,
|
appSocketSubscribed,
|
||||||
appSocketUnsubscribed,
|
appSocketUnsubscribed,
|
||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
|
import { ProgressImage } from 'services/events/types';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
import { getModels } from 'services/thunks/model';
|
||||||
|
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { makeToast } from '../../../app/components/Toaster';
|
||||||
|
import { LANGUAGES } from '../components/LanguagePicker';
|
||||||
|
|
||||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||||
|
|
||||||
@ -382,7 +380,7 @@ export const systemSlice = createSlice({
|
|||||||
/**
|
/**
|
||||||
* Received available models from the backend
|
* Received available models from the backend
|
||||||
*/
|
*/
|
||||||
builder.addCase(receivedModels.fulfilled, (state) => {
|
builder.addCase(getModels.fulfilled, (state) => {
|
||||||
state.wereModelsReceived = true;
|
state.wereModelsReceived = true;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,31 +1,55 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { createAppAsyncThunk } from 'app/store/storeUtils';
|
import { createAppAsyncThunk } from 'app/store/storeUtils';
|
||||||
import { Model } from 'features/system/store/modelSlice';
|
import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice';
|
||||||
import { reduce, size } from 'lodash-es';
|
import { reduce, size } from 'lodash-es';
|
||||||
import { ModelsService } from 'services/api';
|
import { BaseModelType, ModelType, ModelsService } from 'services/api';
|
||||||
|
|
||||||
const models = log.child({ namespace: 'model' });
|
const models = log.child({ namespace: 'model' });
|
||||||
|
|
||||||
export const IMAGES_PER_PAGE = 20;
|
export const IMAGES_PER_PAGE = 20;
|
||||||
|
|
||||||
export const receivedModels = createAppAsyncThunk(
|
type getModelsArg = {
|
||||||
'models/receivedModels',
|
baseModel: BaseModelType | undefined;
|
||||||
async (_) => {
|
modelType: ModelType | undefined;
|
||||||
const response = await ModelsService.listModels();
|
};
|
||||||
|
|
||||||
const deserializedModels = reduce(
|
export const getModels = createAppAsyncThunk(
|
||||||
response.models['sd-1']['pipeline'],
|
'models/getModels',
|
||||||
|
async (arg: getModelsArg) => {
|
||||||
|
const response = await ModelsService.listModels(arg);
|
||||||
|
|
||||||
|
let deserializedModels = {};
|
||||||
|
|
||||||
|
if (arg.baseModel === undefined) return response.models;
|
||||||
|
if (arg.modelType === undefined) return response.models;
|
||||||
|
|
||||||
|
if (arg.baseModel === 'sd-1') {
|
||||||
|
deserializedModels = reduce(
|
||||||
|
response.models[arg.baseModel][arg.modelType],
|
||||||
(modelsAccumulator, model, modelName) => {
|
(modelsAccumulator, model, modelName) => {
|
||||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||||
|
|
||||||
return modelsAccumulator;
|
return modelsAccumulator;
|
||||||
},
|
},
|
||||||
{} as Record<string, Model>
|
{} as Record<string, SD1ModelType>
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (arg.baseModel === 'sd-2') {
|
||||||
|
deserializedModels = reduce(
|
||||||
|
response.models[arg.baseModel][arg.modelType],
|
||||||
|
(modelsAccumulator, model, modelName) => {
|
||||||
|
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||||
|
return modelsAccumulator;
|
||||||
|
},
|
||||||
|
{} as Record<string, SD1ModelType>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
models.info(
|
models.info(
|
||||||
{ response },
|
{ response },
|
||||||
`Received ${size(response.models['sd-1']['pipeline'])} models`
|
`Received ${size(response.models[arg.baseModel][arg.modelType])} ${[
|
||||||
|
arg.baseModel,
|
||||||
|
]} models`
|
||||||
);
|
);
|
||||||
|
|
||||||
return deserializedModels;
|
return deserializedModels;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user