mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix canvas model switching
There was no check at all to see if the canvas had a valid model already selected. The first model in the list was selected every time. Now, we check if its valid. If not, we go through the logic to try and pick the first valid model. If there are no valid models, or there was a problem listing models, the model selection is cleared.
This commit is contained in:
parent
49cce1eec6
commit
da0efeaa7f
@ -1,55 +1,58 @@
|
|||||||
import { modelChanged } from 'features/parameters/store/generationSlice';
|
import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { forEach } from 'lodash-es';
|
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import {
|
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
|
||||||
MainModelConfigEntity,
|
|
||||||
modelsApi,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addTabChangedListener = () => {
|
export const addTabChangedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: setActiveTab,
|
actionCreator: setActiveTab,
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
const activeTabName = action.payload;
|
const activeTabName = action.payload;
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
if (activeTabName === 'unifiedCanvas') {
|
||||||
// grab the models from RTK Query cache
|
const currentBaseModel = getState().generation.model?.base_model;
|
||||||
const { data } = modelsApi.endpoints.getMainModels.select(
|
|
||||||
NON_REFINER_BASE_MODELS
|
|
||||||
)(getState());
|
|
||||||
|
|
||||||
if (!data) {
|
if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
|
||||||
// no models yet, so we can't do anything
|
// if we're already on a valid model, no change needed
|
||||||
dispatch(modelChanged(null));
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// need to filter out all the invalid canvas models (currently, this is just sdxl)
|
try {
|
||||||
const validCanvasModels: MainModelConfigEntity[] = [];
|
// just grab fresh models
|
||||||
|
const modelsRequest = dispatch(
|
||||||
|
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
|
||||||
|
);
|
||||||
|
const models = await modelsRequest.unwrap();
|
||||||
|
// cancel this cache subscription
|
||||||
|
modelsRequest.unsubscribe();
|
||||||
|
|
||||||
forEach(data.entities, (entity) => {
|
if (!models.ids.length) {
|
||||||
if (!entity) {
|
// no valid canvas models
|
||||||
|
dispatch(modelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
|
|
||||||
validCanvasModels.push(entity);
|
// need to filter out all the invalid canvas models (currently sdxl & refiner)
|
||||||
|
const validCanvasModels = mainModelsAdapter
|
||||||
|
.getSelectors()
|
||||||
|
.selectAll(models)
|
||||||
|
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
|
||||||
|
|
||||||
|
const firstValidCanvasModel = validCanvasModels[0];
|
||||||
|
|
||||||
|
if (!firstValidCanvasModel) {
|
||||||
|
// no valid canvas models
|
||||||
|
dispatch(modelChanged(null));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
// this could still be undefined even tho TS doesn't say so
|
const { base_model, model_name, model_type } = firstValidCanvasModel;
|
||||||
const firstValidCanvasModel = validCanvasModels[0];
|
|
||||||
|
|
||||||
if (!firstValidCanvasModel) {
|
dispatch(modelChanged({ base_model, model_name, model_type }));
|
||||||
// uh oh, we have no models that are valid for canvas
|
} catch {
|
||||||
|
// network request failed, bail
|
||||||
dispatch(modelChanged(null));
|
dispatch(modelChanged(null));
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// only store the model name and base model in redux
|
|
||||||
const { base_model, model_name, model_type } = firstValidCanvasModel;
|
|
||||||
|
|
||||||
dispatch(modelChanged({ base_model, model_name, model_type }));
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user