fix(ui): fix canvas model switching (#4221)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

## Description

There was no check at all to see if the canvas had a valid model already
selected. The first model in the list was selected every time.

Now, we check if its valid. If not, we go through the logic to try and
pick the first valid model.

If there are no valid models, or there was a problem listing models, the
model selection is cleared.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->


- Closes #4125

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

- Go to Canvas tab
- Select a model other than the first one in the list
- Go to a different tab
- Go back to Canvas tab
- The model should be the same as you selected
This commit is contained in:
blessedcoolant 2023-08-10 17:29:41 +12:00 committed by GitHub
commit 2564301aeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,55 +1,58 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import {
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addTabChangedListener = () => {
startAppListening({
actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch }) => {
const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
const currentBaseModel = getState().generation.model?.base_model;
if (!data) {
// no models yet, so we can't do anything
dispatch(modelChanged(null));
if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
// if we're already on a valid model, no change needed
return;
}
// need to filter out all the invalid canvas models (currently, this is just sdxl)
const validCanvasModels: MainModelConfigEntity[] = [];
try {
// just grab fresh models
const modelsRequest = dispatch(
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
);
const models = await modelsRequest.unwrap();
// cancel this cache subscription
modelsRequest.unsubscribe();
forEach(data.entities, (entity) => {
if (!entity) {
if (!models.ids.length) {
// no valid canvas models
dispatch(modelChanged(null));
return;
}
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity);
// need to filter out all the invalid canvas models (currently sdxl & refiner)
const validCanvasModels = mainModelsAdapter
.getSelectors()
.selectAll(models)
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) {
// no valid canvas models
dispatch(modelChanged(null));
return;
}
});
// this could still be undefined even tho TS doesn't say so
const firstValidCanvasModel = validCanvasModels[0];
const { base_model, model_name, model_type } = firstValidCanvasModel;
if (!firstValidCanvasModel) {
// uh oh, we have no models that are valid for canvas
dispatch(modelChanged({ base_model, model_name, model_type }));
} catch {
// network request failed, bail
dispatch(modelChanged(null));
return;
}
// only store the model name and base model in redux
const { base_model, model_name, model_type } = firstValidCanvasModel;
dispatch(modelChanged({ base_model, model_name, model_type }));
}
},
});