mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: Add 2.x Models to the Model List
This commit is contained in:
@ -10,22 +10,42 @@ import IAIMantineSelect, {
|
||||
} from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
||||
import {
|
||||
selectAllSD1Models,
|
||||
selectByIdSD1Models,
|
||||
} from '../store/models/sd1ModelSlice';
|
||||
import {
|
||||
selectAllSD2Models,
|
||||
selectByIdSD2Models,
|
||||
} from '../store/models/sd2ModelSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
export const modelSelector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
const selectedModel = selectModelsById(state, generation.model);
|
||||
let selectedModel = selectByIdSD1Models(state, generation.model);
|
||||
if (selectedModel === undefined)
|
||||
selectedModel = selectByIdSD2Models(state, generation.model);
|
||||
|
||||
const modelData = selectModelsAll(state)
|
||||
const sd1ModelData = selectAllSD1Models(state)
|
||||
.map<IAISelectDataType>((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
group: '1.x Models',
|
||||
}))
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
const sd2ModelData = selectAllSD2Models(state)
|
||||
.map<IAISelectDataType>((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
group: '2.x Models',
|
||||
}))
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
selectedModel,
|
||||
modelData,
|
||||
sd1ModelData,
|
||||
sd2ModelData,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -38,7 +58,9 @@ const selector = createSelector(
|
||||
const ModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { selectedModel, modelData } = useAppSelector(selector);
|
||||
const { selectedModel, sd1ModelData, sd2ModelData } =
|
||||
useAppSelector(modelSelector);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
@ -55,7 +77,7 @@ const ModelSelect = () => {
|
||||
label={t('modelManager.model')}
|
||||
value={selectedModel?.name ?? ''}
|
||||
placeholder="Pick one"
|
||||
data={modelData}
|
||||
data={sd1ModelData.concat(sd2ModelData)}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
|
@ -1,47 +0,0 @@
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
|
||||
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const modelsAdapter = createEntityAdapter<Model>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const initialModelsState = modelsAdapter.getInitialState();
|
||||
|
||||
export type ModelsState = typeof initialModelsState;
|
||||
|
||||
export const modelsSlice = createSlice({
|
||||
name: 'models',
|
||||
initialState: initialModelsState,
|
||||
reducers: {
|
||||
modelAdded: modelsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
||||
const models = action.payload;
|
||||
modelsAdapter.setAll(state, models);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectModelsAll,
|
||||
selectById: selectModelsById,
|
||||
selectEntities: selectModelsEntities,
|
||||
selectIds: selectModelsIds,
|
||||
selectTotal: selectModelsTotal,
|
||||
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
|
||||
|
||||
export const { modelAdded } = modelsSlice.actions;
|
||||
|
||||
export default modelsSlice.reducer;
|
@ -0,0 +1,53 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion1ModelCheckpointConfig,
|
||||
StableDiffusion1ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { getModels } from 'services/thunks/model';
|
||||
|
||||
export type SD1ModelType = (
|
||||
| StableDiffusion1ModelCheckpointConfig
|
||||
| StableDiffusion1ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd1ModelsAdapter = createEntityAdapter<SD1ModelType>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState();
|
||||
|
||||
export type SD1ModelState = typeof sd1InitialModelsState;
|
||||
|
||||
export const sd1ModelsSlice = createSlice({
|
||||
name: 'sd1models',
|
||||
initialState: sd1InitialModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd1ModelsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-1') return;
|
||||
sd1ModelsAdapter.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD1Models,
|
||||
selectById: selectByIdSD1Models,
|
||||
selectEntities: selectEntitiesSD1Models,
|
||||
selectIds: selectIdsSD1Models,
|
||||
selectTotal: selectTotalSD1Models,
|
||||
} = sd1ModelsAdapter.getSelectors<RootState>((state) => state.sd1models);
|
||||
|
||||
export const { modelAdded } = sd1ModelsSlice.actions;
|
||||
|
||||
export default sd1ModelsSlice.reducer;
|
@ -0,0 +1,53 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion2ModelCheckpointConfig,
|
||||
StableDiffusion2ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { getModels } from 'services/thunks/model';
|
||||
|
||||
export type SD2ModelType = (
|
||||
| StableDiffusion2ModelCheckpointConfig
|
||||
| StableDiffusion2ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd2ModelsAdapater = createEntityAdapter<SD2ModelType>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState();
|
||||
|
||||
export type SD2ModelState = typeof sd2InitialModelsState;
|
||||
|
||||
export const sd2ModelsSlice = createSlice({
|
||||
name: 'sd2models',
|
||||
initialState: sd2InitialModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd2ModelsAdapater.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-2') return;
|
||||
sd2ModelsAdapater.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD2Models,
|
||||
selectById: selectByIdSD2Models,
|
||||
selectEntities: selectEntitiesSD2Models,
|
||||
selectIds: selectIdsSD2Models,
|
||||
selectTotal: selectTotalSD2Models,
|
||||
} = sd2ModelsAdapater.getSelectors<RootState>((state) => state.sd2models);
|
||||
|
||||
export const { modelAdded } = sd2ModelsSlice.actions;
|
||||
|
||||
export default sd2ModelsSlice.reducer;
|
@ -1,6 +1,9 @@
|
||||
import { ModelsState } from './modelSlice';
|
||||
import { SD1ModelState } from './models/sd1ModelSlice';
|
||||
import { SD2ModelState } from './models/sd2ModelSlice';
|
||||
|
||||
/**
|
||||
* Models slice persist denylist
|
||||
*/
|
||||
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];
|
||||
export const modelsPersistDenylist:
|
||||
| (keyof SD1ModelState)[]
|
||||
| (keyof SD2ModelState)[] = ['entities', 'ids'];
|
||||
|
@ -1,20 +1,12 @@
|
||||
import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { makeToast } from '../../../app/components/Toaster';
|
||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||
import { TFuncKey } from 'i18next';
|
||||
import { t } from 'i18next';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { LANGUAGES } from '../components/LanguagePicker';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||
import { TFuncKey, t } from 'i18next';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import {
|
||||
appSocketConnected,
|
||||
appSocketDisconnected,
|
||||
@ -26,6 +18,12 @@ import {
|
||||
appSocketSubscribed,
|
||||
appSocketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { getModels } from 'services/thunks/model';
|
||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||
import { makeToast } from '../../../app/components/Toaster';
|
||||
import { LANGUAGES } from '../components/LanguagePicker';
|
||||
|
||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||
|
||||
@ -379,7 +377,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Received available models from the backend
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state) => {
|
||||
builder.addCase(getModels.fulfilled, (state) => {
|
||||
state.wereModelsReceived = true;
|
||||
});
|
||||
|
||||
|
Reference in New Issue
Block a user