mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup: Updated model slice names to be more descriptive
Basically updated all slices to be more descriptive in their names. Did so in order to make sure theres good naming scheme available for secondary models.
This commit is contained in:
parent
604cc1adcd
commit
0c3616229e
@ -7,8 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||
import { initialConfigState } from 'features/system/store/configSlice';
|
||||
import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice';
|
||||
import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice';
|
||||
import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice';
|
||||
import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice';
|
||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||
@ -22,8 +22,8 @@ const initialStates: {
|
||||
gallery: initialGalleryState,
|
||||
generation: initialGenerationState,
|
||||
lightbox: initialLightboxState,
|
||||
sd1models: sd1InitialModelsState,
|
||||
sd2models: sd2InitialModelsState,
|
||||
sd1pipelinemodels: sd1InitialPipelineModelsState,
|
||||
sd2pipelinemodels: sd2InitialPipelineModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
system: initialSystemState,
|
||||
|
@ -15,7 +15,8 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { sd1models, sd2models, nodes, config, images } = getState();
|
||||
const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } =
|
||||
getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
@ -28,11 +29,11 @@ export const addSocketConnectedEventListener = () => {
|
||||
);
|
||||
}
|
||||
|
||||
if (!sd1models.ids.length) {
|
||||
if (!sd1pipelinemodels.ids.length) {
|
||||
dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
|
||||
}
|
||||
|
||||
if (!sd2models.ids.length) {
|
||||
if (!sd2pipelinemodels.ids.length) {
|
||||
dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
|
||||
}
|
||||
|
||||
|
@ -30,8 +30,8 @@ import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
|
||||
// Model Reducers
|
||||
import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice';
|
||||
import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice';
|
||||
import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice';
|
||||
import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice';
|
||||
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
@ -43,8 +43,8 @@ const allReducers = {
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
lightbox: lightboxReducer,
|
||||
sd1models: sd1ModelReducer,
|
||||
sd2models: sd2ModelReducer,
|
||||
sd1pipelinemodels: sd1PipelineModelReducer,
|
||||
sd2pipelinemodels: sd2PipelineModelReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
system: systemReducer,
|
||||
|
@ -17,7 +17,7 @@ const ModelInputFieldComponent = (
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { sd1ModelDropDownData, sd2ModelDropdownData } =
|
||||
const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } =
|
||||
useAppSelector(modelSelector);
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
@ -33,8 +33,8 @@ const ModelInputFieldComponent = (
|
||||
return (
|
||||
<NativeSelect
|
||||
onChange={handleValueChanged}
|
||||
value={field.value || sd1ModelDropDownData[0].value}
|
||||
data={sd1ModelDropDownData.concat(sd2ModelDropdownData)}
|
||||
value={field.value || sd1PipelineModelDropDownData[0].value}
|
||||
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
|
||||
></NativeSelect>
|
||||
);
|
||||
};
|
||||
|
@ -20,8 +20,11 @@ const MODEL_LOADER_MAP = {
|
||||
const ModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { selectedModel, sd1ModelDropDownData, sd2ModelDropdownData } =
|
||||
useAppSelector(modelSelector);
|
||||
const {
|
||||
selectedModel,
|
||||
sd1PipelineModelDropDownData,
|
||||
sd2PipelineModelDropdownData,
|
||||
} = useAppSelector(modelSelector);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModel)
|
||||
@ -48,7 +51,7 @@ const ModelSelect = () => {
|
||||
label={t('modelManager.model')}
|
||||
value={selectedModel?.name ?? ''}
|
||||
placeholder="Pick one"
|
||||
data={sd1ModelDropDownData.concat(sd2ModelDropdownData)}
|
||||
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
|
@ -3,26 +3,30 @@ import { RootState } from 'app/store/store';
|
||||
import { IAISelectDataType } from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import {
|
||||
selectAllSD1Models,
|
||||
selectByIdSD1Models,
|
||||
} from './models/sd1ModelSlice';
|
||||
selectAllSD1PipelineModels,
|
||||
selectByIdSD1PipelineModels,
|
||||
} from './models/sd1PipelineModelSlice';
|
||||
|
||||
import {
|
||||
selectAllSD2Models,
|
||||
selectByIdSD2Models,
|
||||
} from './models/sd2ModelSlice';
|
||||
selectAllSD2PipelineModels,
|
||||
selectByIdSD2PipelineModels,
|
||||
} from './models/sd2PipelineModelSlice';
|
||||
|
||||
export const modelSelector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
let selectedModel = selectByIdSD1Models(state, generation.model);
|
||||
let selectedModel = selectByIdSD1PipelineModels(state, generation.model);
|
||||
if (selectedModel === undefined)
|
||||
selectedModel = selectByIdSD2Models(state, generation.model);
|
||||
selectedModel = selectByIdSD2PipelineModels(state, generation.model);
|
||||
|
||||
const sd1Models = selectAllSD1Models(state);
|
||||
const sd2Models = selectAllSD2Models(state);
|
||||
const sd1PipelineModels = selectAllSD1PipelineModels(state);
|
||||
const sd2PipelineModels = selectAllSD2PipelineModels(state);
|
||||
|
||||
const sd1ModelDropDownData = selectAllSD1Models(state)
|
||||
const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels);
|
||||
|
||||
const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state)
|
||||
.map<IAISelectDataType>((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
@ -30,7 +34,7 @@ export const modelSelector = createSelector(
|
||||
}))
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
const sd2ModelDropdownData = selectAllSD2Models(state)
|
||||
const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state)
|
||||
.map<IAISelectDataType>((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
@ -40,10 +44,11 @@ export const modelSelector = createSelector(
|
||||
|
||||
return {
|
||||
selectedModel,
|
||||
sd1Models,
|
||||
sd2Models,
|
||||
sd1ModelDropDownData,
|
||||
sd2ModelDropdownData,
|
||||
allPipelineModels,
|
||||
sd1PipelineModels,
|
||||
sd2PipelineModels,
|
||||
sd1PipelineModelDropDownData,
|
||||
sd2PipelineModelDropdownData,
|
||||
};
|
||||
},
|
||||
{
|
||||
|
@ -1,53 +0,0 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion1ModelCheckpointConfig,
|
||||
StableDiffusion1ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { getModels } from 'services/thunks/model';
|
||||
|
||||
export type SD1ModelType = (
|
||||
| StableDiffusion1ModelCheckpointConfig
|
||||
| StableDiffusion1ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd1ModelsAdapter = createEntityAdapter<SD1ModelType>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState();
|
||||
|
||||
export type SD1ModelState = typeof sd1InitialModelsState;
|
||||
|
||||
export const sd1ModelsSlice = createSlice({
|
||||
name: 'sd1models',
|
||||
initialState: sd1InitialModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd1ModelsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-1') return;
|
||||
sd1ModelsAdapter.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD1Models,
|
||||
selectById: selectByIdSD1Models,
|
||||
selectEntities: selectEntitiesSD1Models,
|
||||
selectIds: selectIdsSD1Models,
|
||||
selectTotal: selectTotalSD1Models,
|
||||
} = sd1ModelsAdapter.getSelectors<RootState>((state) => state.sd1models);
|
||||
|
||||
export const { modelAdded } = sd1ModelsSlice.actions;
|
||||
|
||||
export default sd1ModelsSlice.reducer;
|
@ -0,0 +1,57 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion1ModelCheckpointConfig,
|
||||
StableDiffusion1ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { getModels } from 'services/thunks/model';
|
||||
|
||||
export type SD1PipelineModelType = (
|
||||
| StableDiffusion1ModelCheckpointConfig
|
||||
| StableDiffusion1ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd1PipelineModelsAdapter =
|
||||
createEntityAdapter<SD1PipelineModelType>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd1InitialPipelineModelsState =
|
||||
sd1PipelineModelsAdapter.getInitialState();
|
||||
|
||||
export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState;
|
||||
|
||||
export const sd1PipelineModelsSlice = createSlice({
|
||||
name: 'sd1models',
|
||||
initialState: sd1InitialPipelineModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd1PipelineModelsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-1') return;
|
||||
sd1PipelineModelsAdapter.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD1PipelineModels,
|
||||
selectById: selectByIdSD1PipelineModels,
|
||||
selectEntities: selectEntitiesSD1PipelineModels,
|
||||
selectIds: selectIdsSD1PipelineModels,
|
||||
selectTotal: selectTotalSD1PipelineModels,
|
||||
} = sd1PipelineModelsAdapter.getSelectors<RootState>(
|
||||
(state) => state.sd1pipelinemodels
|
||||
);
|
||||
|
||||
export const { modelAdded } = sd1PipelineModelsSlice.actions;
|
||||
|
||||
export default sd1PipelineModelsSlice.reducer;
|
@ -1,53 +0,0 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion2ModelCheckpointConfig,
|
||||
StableDiffusion2ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { getModels } from 'services/thunks/model';
|
||||
|
||||
export type SD2ModelType = (
|
||||
| StableDiffusion2ModelCheckpointConfig
|
||||
| StableDiffusion2ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd2ModelsAdapater = createEntityAdapter<SD2ModelType>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState();
|
||||
|
||||
export type SD2ModelState = typeof sd2InitialModelsState;
|
||||
|
||||
export const sd2ModelsSlice = createSlice({
|
||||
name: 'sd2models',
|
||||
initialState: sd2InitialModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd2ModelsAdapater.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-2') return;
|
||||
sd2ModelsAdapater.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD2Models,
|
||||
selectById: selectByIdSD2Models,
|
||||
selectEntities: selectEntitiesSD2Models,
|
||||
selectIds: selectIdsSD2Models,
|
||||
selectTotal: selectTotalSD2Models,
|
||||
} = sd2ModelsAdapater.getSelectors<RootState>((state) => state.sd2models);
|
||||
|
||||
export const { modelAdded } = sd2ModelsSlice.actions;
|
||||
|
||||
export default sd2ModelsSlice.reducer;
|
@ -0,0 +1,57 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
StableDiffusion2ModelCheckpointConfig,
|
||||
StableDiffusion2ModelDiffusersConfig,
|
||||
} from 'services/api';
|
||||
|
||||
import { getModels } from 'services/thunks/model';
|
||||
|
||||
export type SD2PipelineModelType = (
|
||||
| StableDiffusion2ModelCheckpointConfig
|
||||
| StableDiffusion2ModelDiffusersConfig
|
||||
) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const sd2PipelineModelsAdapater =
|
||||
createEntityAdapter<SD2PipelineModelType>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const sd2InitialPipelineModelsState =
|
||||
sd2PipelineModelsAdapater.getInitialState();
|
||||
|
||||
export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState;
|
||||
|
||||
export const sd2PipelineModelsSlice = createSlice({
|
||||
name: 'sd2models',
|
||||
initialState: sd2InitialPipelineModelsState,
|
||||
reducers: {
|
||||
modelAdded: sd2PipelineModelsAdapater.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||
if (action.meta.arg.baseModel !== 'sd-2') return;
|
||||
sd2PipelineModelsAdapater.setAll(state, action.payload);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectAllSD2PipelineModels,
|
||||
selectById: selectByIdSD2PipelineModels,
|
||||
selectEntities: selectEntitiesSD2PipelineModels,
|
||||
selectIds: selectIdsSD2PipelineModels,
|
||||
selectTotal: selectTotalSD2PipelineModels,
|
||||
} = sd2PipelineModelsAdapater.getSelectors<RootState>(
|
||||
(state) => state.sd2pipelinemodels
|
||||
);
|
||||
|
||||
export const { modelAdded } = sd2PipelineModelsSlice.actions;
|
||||
|
||||
export default sd2PipelineModelsSlice.reducer;
|
@ -1,9 +1,9 @@
|
||||
import { SD1ModelState } from './models/sd1ModelSlice';
|
||||
import { SD2ModelState } from './models/sd2ModelSlice';
|
||||
import { SD1PipelineModelState } from './models/sd1PipelineModelSlice';
|
||||
import { SD2PipelineModelState } from './models/sd2PipelineModelSlice';
|
||||
|
||||
/**
|
||||
* Models slice persist denylist
|
||||
*/
|
||||
export const modelsPersistDenylist:
|
||||
| (keyof SD1ModelState)[]
|
||||
| (keyof SD2ModelState)[] = ['entities', 'ids'];
|
||||
| (keyof SD1PipelineModelState)[]
|
||||
| (keyof SD2PipelineModelState)[] = ['entities', 'ids'];
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { createAppAsyncThunk } from 'app/store/storeUtils';
|
||||
import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice';
|
||||
import { SD1PipelineModelType } from 'features/system/store/models/sd1PipelineModelSlice';
|
||||
import { SD2PipelineModelType } from 'features/system/store/models/sd2PipelineModelSlice';
|
||||
import { reduce, size } from 'lodash-es';
|
||||
import { BaseModelType, ModelType, ModelsService } from 'services/api';
|
||||
|
||||
@ -30,7 +31,7 @@ export const getModels = createAppAsyncThunk(
|
||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||
return modelsAccumulator;
|
||||
},
|
||||
{} as Record<string, SD1ModelType>
|
||||
{} as Record<string, SD1PipelineModelType>
|
||||
);
|
||||
}
|
||||
|
||||
@ -41,7 +42,7 @@ export const getModels = createAppAsyncThunk(
|
||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||
return modelsAccumulator;
|
||||
},
|
||||
{} as Record<string, SD1ModelType>
|
||||
{} as Record<string, SD2PipelineModelType>
|
||||
);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user