cleanup: Updated model slice names to be more descriptive

Basically updated all slices to be more descriptive in their names. Did so in order to make sure theres good naming scheme available for secondary models.
This commit is contained in:
blessedcoolant 2023-06-18 17:36:23 +12:00 committed by psychedelicious
parent 604cc1adcd
commit 0c3616229e
12 changed files with 164 additions and 146 deletions

View File

@ -7,8 +7,8 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { sd1InitialModelsState } from 'features/system/store/models/sd1ModelSlice';
import { sd2InitialModelsState } from 'features/system/store/models/sd2ModelSlice';
import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice';
import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
@ -22,8 +22,8 @@ const initialStates: {
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
sd1models: sd1InitialModelsState,
sd2models: sd2InitialModelsState,
sd1pipelinemodels: sd1InitialPipelineModelsState,
sd2pipelinemodels: sd2InitialPipelineModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
system: initialSystemState,

View File

@ -15,7 +15,8 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected');
const { sd1models, sd2models, nodes, config, images } = getState();
const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } =
getState();
const { disabledTabs } = config;
@ -28,11 +29,11 @@ export const addSocketConnectedEventListener = () => {
);
}
if (!sd1models.ids.length) {
if (!sd1pipelinemodels.ids.length) {
dispatch(getModels({ baseModel: 'sd-1', modelType: 'pipeline' }));
}
if (!sd2models.ids.length) {
if (!sd2pipelinemodels.ids.length) {
dispatch(getModels({ baseModel: 'sd-2', modelType: 'pipeline' }));
}

View File

@ -30,8 +30,8 @@ import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
// Model Reducers
import sd1ModelReducer from 'features/system/store/models/sd1ModelSlice';
import sd2ModelReducer from 'features/system/store/models/sd2ModelSlice';
import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice';
import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
@ -43,8 +43,8 @@ const allReducers = {
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
sd1models: sd1ModelReducer,
sd2models: sd2ModelReducer,
sd1pipelinemodels: sd1PipelineModelReducer,
sd2pipelinemodels: sd2PipelineModelReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
system: systemReducer,

View File

@ -17,7 +17,7 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { sd1ModelDropDownData, sd2ModelDropdownData } =
const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } =
useAppSelector(modelSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
@ -33,8 +33,8 @@ const ModelInputFieldComponent = (
return (
<NativeSelect
onChange={handleValueChanged}
value={field.value || sd1ModelDropDownData[0].value}
data={sd1ModelDropDownData.concat(sd2ModelDropdownData)}
value={field.value || sd1PipelineModelDropDownData[0].value}
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
></NativeSelect>
);
};

View File

@ -20,8 +20,11 @@ const MODEL_LOADER_MAP = {
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { selectedModel, sd1ModelDropDownData, sd2ModelDropdownData } =
useAppSelector(modelSelector);
const {
selectedModel,
sd1PipelineModelDropDownData,
sd2PipelineModelDropdownData,
} = useAppSelector(modelSelector);
useEffect(() => {
if (selectedModel)
@ -48,7 +51,7 @@ const ModelSelect = () => {
label={t('modelManager.model')}
value={selectedModel?.name ?? ''}
placeholder="Pick one"
data={sd1ModelDropDownData.concat(sd2ModelDropdownData)}
data={sd1PipelineModelDropDownData.concat(sd2PipelineModelDropdownData)}
onChange={handleChangeModel}
/>
);

View File

@ -3,26 +3,30 @@ import { RootState } from 'app/store/store';
import { IAISelectDataType } from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { isEqual } from 'lodash-es';
import {
selectAllSD1Models,
selectByIdSD1Models,
} from './models/sd1ModelSlice';
selectAllSD1PipelineModels,
selectByIdSD1PipelineModels,
} from './models/sd1PipelineModelSlice';
import {
selectAllSD2Models,
selectByIdSD2Models,
} from './models/sd2ModelSlice';
selectAllSD2PipelineModels,
selectByIdSD2PipelineModels,
} from './models/sd2PipelineModelSlice';
export const modelSelector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
let selectedModel = selectByIdSD1Models(state, generation.model);
let selectedModel = selectByIdSD1PipelineModels(state, generation.model);
if (selectedModel === undefined)
selectedModel = selectByIdSD2Models(state, generation.model);
selectedModel = selectByIdSD2PipelineModels(state, generation.model);
const sd1Models = selectAllSD1Models(state);
const sd2Models = selectAllSD2Models(state);
const sd1PipelineModels = selectAllSD1PipelineModels(state);
const sd2PipelineModels = selectAllSD2PipelineModels(state);
const sd1ModelDropDownData = selectAllSD1Models(state)
const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels);
const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state)
.map<IAISelectDataType>((m) => ({
value: m.name,
label: m.name,
@ -30,7 +34,7 @@ export const modelSelector = createSelector(
}))
.sort((a, b) => a.label.localeCompare(b.label));
const sd2ModelDropdownData = selectAllSD2Models(state)
const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state)
.map<IAISelectDataType>((m) => ({
value: m.name,
label: m.name,
@ -40,10 +44,11 @@ export const modelSelector = createSelector(
return {
selectedModel,
sd1Models,
sd2Models,
sd1ModelDropDownData,
sd2ModelDropdownData,
allPipelineModels,
sd1PipelineModels,
sd2PipelineModels,
sd1PipelineModelDropDownData,
sd2PipelineModelDropdownData,
};
},
{

View File

@ -1,53 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion1ModelCheckpointConfig,
StableDiffusion1ModelDiffusersConfig,
} from 'services/api';
import { getModels } from 'services/thunks/model';
export type SD1ModelType = (
| StableDiffusion1ModelCheckpointConfig
| StableDiffusion1ModelDiffusersConfig
) & {
name: string;
};
export const sd1ModelsAdapter = createEntityAdapter<SD1ModelType>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd1InitialModelsState = sd1ModelsAdapter.getInitialState();
export type SD1ModelState = typeof sd1InitialModelsState;
export const sd1ModelsSlice = createSlice({
name: 'sd1models',
initialState: sd1InitialModelsState,
reducers: {
modelAdded: sd1ModelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(getModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-1') return;
sd1ModelsAdapter.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD1Models,
selectById: selectByIdSD1Models,
selectEntities: selectEntitiesSD1Models,
selectIds: selectIdsSD1Models,
selectTotal: selectTotalSD1Models,
} = sd1ModelsAdapter.getSelectors<RootState>((state) => state.sd1models);
export const { modelAdded } = sd1ModelsSlice.actions;
export default sd1ModelsSlice.reducer;

View File

@ -0,0 +1,57 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion1ModelCheckpointConfig,
StableDiffusion1ModelDiffusersConfig,
} from 'services/api';
import { getModels } from 'services/thunks/model';
export type SD1PipelineModelType = (
| StableDiffusion1ModelCheckpointConfig
| StableDiffusion1ModelDiffusersConfig
) & {
name: string;
};
export const sd1PipelineModelsAdapter =
createEntityAdapter<SD1PipelineModelType>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd1InitialPipelineModelsState =
sd1PipelineModelsAdapter.getInitialState();
export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState;
export const sd1PipelineModelsSlice = createSlice({
name: 'sd1models',
initialState: sd1InitialPipelineModelsState,
reducers: {
modelAdded: sd1PipelineModelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(getModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-1') return;
sd1PipelineModelsAdapter.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD1PipelineModels,
selectById: selectByIdSD1PipelineModels,
selectEntities: selectEntitiesSD1PipelineModels,
selectIds: selectIdsSD1PipelineModels,
selectTotal: selectTotalSD1PipelineModels,
} = sd1PipelineModelsAdapter.getSelectors<RootState>(
(state) => state.sd1pipelinemodels
);
export const { modelAdded } = sd1PipelineModelsSlice.actions;
export default sd1PipelineModelsSlice.reducer;

View File

@ -1,53 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion2ModelCheckpointConfig,
StableDiffusion2ModelDiffusersConfig,
} from 'services/api';
import { getModels } from 'services/thunks/model';
export type SD2ModelType = (
| StableDiffusion2ModelCheckpointConfig
| StableDiffusion2ModelDiffusersConfig
) & {
name: string;
};
export const sd2ModelsAdapater = createEntityAdapter<SD2ModelType>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd2InitialModelsState = sd2ModelsAdapater.getInitialState();
export type SD2ModelState = typeof sd2InitialModelsState;
export const sd2ModelsSlice = createSlice({
name: 'sd2models',
initialState: sd2InitialModelsState,
reducers: {
modelAdded: sd2ModelsAdapater.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(getModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-2') return;
sd2ModelsAdapater.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD2Models,
selectById: selectByIdSD2Models,
selectEntities: selectEntitiesSD2Models,
selectIds: selectIdsSD2Models,
selectTotal: selectTotalSD2Models,
} = sd2ModelsAdapater.getSelectors<RootState>((state) => state.sd2models);
export const { modelAdded } = sd2ModelsSlice.actions;
export default sd2ModelsSlice.reducer;

View File

@ -0,0 +1,57 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
StableDiffusion2ModelCheckpointConfig,
StableDiffusion2ModelDiffusersConfig,
} from 'services/api';
import { getModels } from 'services/thunks/model';
export type SD2PipelineModelType = (
| StableDiffusion2ModelCheckpointConfig
| StableDiffusion2ModelDiffusersConfig
) & {
name: string;
};
export const sd2PipelineModelsAdapater =
createEntityAdapter<SD2PipelineModelType>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const sd2InitialPipelineModelsState =
sd2PipelineModelsAdapater.getInitialState();
export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState;
export const sd2PipelineModelsSlice = createSlice({
name: 'sd2models',
initialState: sd2InitialPipelineModelsState,
reducers: {
modelAdded: sd2PipelineModelsAdapater.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(getModels.fulfilled, (state, action) => {
if (action.meta.arg.baseModel !== 'sd-2') return;
sd2PipelineModelsAdapater.setAll(state, action.payload);
});
},
});
export const {
selectAll: selectAllSD2PipelineModels,
selectById: selectByIdSD2PipelineModels,
selectEntities: selectEntitiesSD2PipelineModels,
selectIds: selectIdsSD2PipelineModels,
selectTotal: selectTotalSD2PipelineModels,
} = sd2PipelineModelsAdapater.getSelectors<RootState>(
(state) => state.sd2pipelinemodels
);
export const { modelAdded } = sd2PipelineModelsSlice.actions;
export default sd2PipelineModelsSlice.reducer;

View File

@ -1,9 +1,9 @@
import { SD1ModelState } from './models/sd1ModelSlice';
import { SD2ModelState } from './models/sd2ModelSlice';
import { SD1PipelineModelState } from './models/sd1PipelineModelSlice';
import { SD2PipelineModelState } from './models/sd2PipelineModelSlice';
/**
* Models slice persist denylist
*/
export const modelsPersistDenylist:
| (keyof SD1ModelState)[]
| (keyof SD2ModelState)[] = ['entities', 'ids'];
| (keyof SD1PipelineModelState)[]
| (keyof SD2PipelineModelState)[] = ['entities', 'ids'];

View File

@ -1,6 +1,7 @@
import { log } from 'app/logging/useLogger';
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { SD1ModelType } from 'features/system/store/models/sd1ModelSlice';
import { SD1PipelineModelType } from 'features/system/store/models/sd1PipelineModelSlice';
import { SD2PipelineModelType } from 'features/system/store/models/sd2PipelineModelSlice';
import { reduce, size } from 'lodash-es';
import { BaseModelType, ModelType, ModelsService } from 'services/api';
@ -30,7 +31,7 @@ export const getModels = createAppAsyncThunk(
modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator;
},
{} as Record<string, SD1ModelType>
{} as Record<string, SD1PipelineModelType>
);
}
@ -41,7 +42,7 @@ export const getModels = createAppAsyncThunk(
modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator;
},
{} as Record<string, SD1ModelType>
{} as Record<string, SD2PipelineModelType>
);
}