fix(ui): fix canvas model switching

There was no check at all to see if the canvas had a valid model already selected. The first model in the list was selected every time.

Now, we check if its valid. If not, we go through the logic to try and pick the first valid model.

If there are no valid models, or there was a problem listing models, the model selection is cleared.
This commit is contained in:
psychedelicious 2023-08-10 15:20:37 +10:00
parent 49cce1eec6
commit da0efeaa7f

View File

@ -1,55 +1,58 @@
import { modelChanged } from 'features/parameters/store/generationSlice'; import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addTabChangedListener = () => { export const addTabChangedListener = () => {
startAppListening({ startAppListening({
actionCreator: setActiveTab, actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const activeTabName = action.payload; const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') { if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache const currentBaseModel = getState().generation.model?.base_model;
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
if (!data) { if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
// no models yet, so we can't do anything // if we're already on a valid model, no change needed
return;
}
try {
// just grab fresh models
const modelsRequest = dispatch(
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
);
const models = await modelsRequest.unwrap();
// cancel this cache subscription
modelsRequest.unsubscribe();
if (!models.ids.length) {
// no valid canvas models
dispatch(modelChanged(null)); dispatch(modelChanged(null));
return; return;
} }
// need to filter out all the invalid canvas models (currently, this is just sdxl) // need to filter out all the invalid canvas models (currently sdxl & refiner)
const validCanvasModels: MainModelConfigEntity[] = []; const validCanvasModels = mainModelsAdapter
.getSelectors()
.selectAll(models)
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
forEach(data.entities, (entity) => {
if (!entity) {
return;
}
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity);
}
});
// this could still be undefined even tho TS doesn't say so
const firstValidCanvasModel = validCanvasModels[0]; const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) { if (!firstValidCanvasModel) {
// uh oh, we have no models that are valid for canvas // no valid canvas models
dispatch(modelChanged(null)); dispatch(modelChanged(null));
return; return;
} }
// only store the model name and base model in redux
const { base_model, model_name, model_type } = firstValidCanvasModel; const { base_model, model_name, model_type } = firstValidCanvasModel;
dispatch(modelChanged({ base_model, model_name, model_type })); dispatch(modelChanged({ base_model, model_name, model_type }));
} catch {
// network request failed, bail
dispatch(modelChanged(null));
}
} }
}, },
}); });