Merge branch 'main' into fix/ui/fix-lora-sort

This commit is contained in:
psychedelicious 2023-08-10 15:32:40 +10:00
commit d4812bbc8d
2 changed files with 33 additions and 35 deletions

View File

@ -1,55 +1,58 @@
import { modelChanged } from 'features/parameters/store/generationSlice'; import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addTabChangedListener = () => { export const addTabChangedListener = () => {
startAppListening({ startAppListening({
actionCreator: setActiveTab, actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const activeTabName = action.payload; const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') { if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache const currentBaseModel = getState().generation.model?.base_model;
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
if (!data) { if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
// no models yet, so we can't do anything // if we're already on a valid model, no change needed
dispatch(modelChanged(null));
return; return;
} }
// need to filter out all the invalid canvas models (currently, this is just sdxl) try {
const validCanvasModels: MainModelConfigEntity[] = []; // just grab fresh models
const modelsRequest = dispatch(
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
);
const models = await modelsRequest.unwrap();
// cancel this cache subscription
modelsRequest.unsubscribe();
forEach(data.entities, (entity) => { if (!models.ids.length) {
if (!entity) { // no valid canvas models
dispatch(modelChanged(null));
return; return;
} }
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity); // need to filter out all the invalid canvas models (currently sdxl & refiner)
const validCanvasModels = mainModelsAdapter
.getSelectors()
.selectAll(models)
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) {
// no valid canvas models
dispatch(modelChanged(null));
return;
} }
});
// this could still be undefined even tho TS doesn't say so const { base_model, model_name, model_type } = firstValidCanvasModel;
const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) { dispatch(modelChanged({ base_model, model_name, model_type }));
// uh oh, we have no models that are valid for canvas } catch {
// network request failed, bail
dispatch(modelChanged(null)); dispatch(modelChanged(null));
return;
} }
// only store the model name and base model in redux
const { base_model, model_name, model_type } = firstValidCanvasModel;
dispatch(modelChanged({ base_model, model_name, model_type }));
} }
}, },
}); });

View File

@ -54,11 +54,6 @@ const ParamLoRASelect = () => {
}); });
}); });
// Sort Alphabetically
data.sort((a, b) =>
a.label && b.label ? (a.label?.localeCompare(b.label) ? 1 : -1) : -1
);
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loras, loraModels, currentMainModel?.base_model]); }, [loras, loraModels, currentMainModel?.base_model]);