updates for defaultModel (#5866)

* move defaultModel logic to modelsLoaded and update to work for key instead of name/base/type string

* lint fix

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
This commit is contained in:
Mary Hipp Rogers 2024-03-05 09:55:22 -05:00 committed by GitHub
parent ba1f6bf926
commit e30cb4b52f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 23 deletions

View File

@ -7,8 +7,10 @@ import {
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@ -24,7 +26,9 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
const currentModel = getState().generation.model;
const state = getState();
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
@ -39,6 +43,29 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {

View File

@ -16,7 +16,6 @@ import type {
ParameterScheduler,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { configChanged } from 'features/system/store/configSlice';
import { clamp } from 'lodash-es';
@ -210,26 +209,6 @@ export const generationSlice = createSlice({
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
const result = zParameterModel.safeParse({
model_name,
base_model,
model_type,
});
if (result.success) {
state.model = result.data;
const optimalDimension = getOptimalDimension(result.data);
state.width = optimalDimension;
state.height = optimalDimension;
}
}
if (action.payload.sd?.scheduler) {
state.scheduler = action.payload.sd.scheduler;
}