feat(ui): move initial IP adapter model selection to listener

This commit is contained in:
psychedelicious 2023-10-04 08:37:51 +11:00
parent 24d73d484a
commit 069d8b5812
3 changed files with 59 additions and 27 deletions

View File

@ -697,7 +697,7 @@
"noLoRAsAvailable": "No LoRAs available",
"noMatchingLoRAs": "No matching LoRAs",
"noMatchingModels": "No matching Models",
"noModelsAvailable": "No Modelss available",
"noModelsAvailable": "No models available",
"selectLoRA": "Select a LoRA",
"selectModel": "Select a Model"
},

View File

@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
import {
controlNetRemoved,
ipAdapterModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import {
modelChanged,
@ -16,12 +19,14 @@ import {
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import {
ipAdapterModelsAdapter,
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
import { zIPAdapterModel } from 'features/nodes/types/types';
export const addModelsLoadedListener = () => {
startAppListening({
@ -234,6 +239,50 @@ export const addModelsLoadedListener = () => {
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info(
{ models: action.payload.entities },
`IP Adapter models loaded (${action.payload.ids.length})`
);
const { model } = getState().controlNet.ipAdapterInfo;
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === model?.model_name &&
m?.base_model === model?.base_model
);
if (isModelAvailable) {
return;
}
const firstModel = ipAdapterModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
if (!firstModel) {
dispatch(ipAdapterModelChanged(null));
}
const result = zIPAdapterModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse IP Adapter model'
);
return;
}
dispatch(ipAdapterModelChanged(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
effect: async (action) => {

View File

@ -6,7 +6,7 @@ import { ipAdapterModelChanged } from 'features/controlNet/store/controlNetSlice
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
@ -24,27 +24,6 @@ const ParamIPAdapterModelSelect = () => {
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const firstModel = useMemo(() => {
if (!ipAdapterModels || !Object.keys(ipAdapterModels.entities).length) {
return undefined;
}
const firstModelId = Object.keys(ipAdapterModels.entities)[0];
if (!firstModelId) {
return undefined;
}
const firstModel = ipAdapterModels.entities[firstModelId];
return firstModel ? firstModel : undefined;
}, [ipAdapterModels]);
useEffect(() => {
if (firstModel) {
dispatch(ipAdapterModelChanged(firstModel));
}
}, [firstModel, dispatch]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
@ -109,12 +88,16 @@ const ParamIPAdapterModelSelect = () => {
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
placeholder={
data.length > 0
? t('models.selectModel')
: t('models.noModelsAvailable')
}
error={!selectedModel && data.length > 0}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
disabled={!isEnabled}
disabled={!isEnabled || data.length === 0}
/>
);
};