mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): single getModelConfigs query
Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend.
This commit is contained in:
parent
ed20255abf
commit
19d66d5ec7
@ -1,10 +1,10 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import type { JSONObject } from 'common/types';
|
||||
import {
|
||||
controlAdapterModelCleared,
|
||||
selectAllControlNets,
|
||||
selectAllIPAdapters,
|
||||
selectAllT2IAdapters,
|
||||
selectControlAdapterAll,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||
@ -12,212 +12,162 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
|
||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
|
||||
import type { TypeGuardFor } from 'services/api/types';
|
||||
import { forEach } from 'lodash-es';
|
||||
import type { Logger } from 'roarr';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
|
||||
|
||||
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
|
||||
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
|
||||
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
||||
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
|
||||
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const defaultModel = state.config.sd.defaultModel;
|
||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
||||
|
||||
if (defaultModelInList) {
|
||||
const result = zParameterModel.safeParse(defaultModelInList);
|
||||
if (result.success) {
|
||||
dispatch(modelChanged(defaultModelInList, currentModel));
|
||||
|
||||
const optimalDimension = getOptimalDimension(defaultModelInList);
|
||||
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = calculateNewSize(
|
||||
state.generation.aspectRatio.value,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
|
||||
dispatch(widthChanged(width));
|
||||
dispatch(heightChanged(height));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelChanged(result.data, currentModel));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
|
||||
|
||||
const currentModel = getState().sdxl.refinerModel;
|
||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
||||
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||
|
||||
if (!isCurrentModelAvailable) {
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// VAEs loaded, need to reset the VAE is it's no longer available
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
|
||||
|
||||
const currentVae = getState().generation.vae;
|
||||
|
||||
if (currentVae === null) {
|
||||
// null is a valid VAE! it means "use the default with the main model"
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
|
||||
|
||||
if (isCurrentVAEAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
|
||||
|
||||
if (!firstModel) {
|
||||
// No custom VAEs loaded at all; use the default
|
||||
dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zParameterVAEModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(result.data));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// LoRA models loaded - need to remove missing LoRAs from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
|
||||
|
||||
const loras = getState().lora.loras;
|
||||
|
||||
forEach(loras, (lora, id) => {
|
||||
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
|
||||
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRemoved(id));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
|
||||
effect: async (action) => {
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
|
||||
handleMainModels(models, state, dispatch, log);
|
||||
handleRefinerModels(models, state, dispatch, log);
|
||||
handleVAEModels(models, state, dispatch, log);
|
||||
handleLoRAModels(models, state, dispatch, log);
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
type ModelHandler = (
|
||||
models: AnyModelConfig[],
|
||||
state: RootState,
|
||||
dispatch: AppDispatch,
|
||||
log: Logger<JSONObject>
|
||||
) => undefined;
|
||||
|
||||
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const currentModel = state.generation.model;
|
||||
const mainModels = models.filter(isNonRefinerMainModelConfig);
|
||||
if (mainModels.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||
|
||||
if (isCurrentMainModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const defaultModel = state.config.sd.defaultModel;
|
||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
||||
|
||||
if (defaultModelInList) {
|
||||
const result = zParameterModel.safeParse(defaultModelInList);
|
||||
if (result.success) {
|
||||
dispatch(modelChanged(defaultModelInList, currentModel));
|
||||
|
||||
const optimalDimension = getOptimalDimension(defaultModelInList);
|
||||
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = calculateNewSize(
|
||||
state.generation.aspectRatio.value,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
|
||||
dispatch(widthChanged(width));
|
||||
dispatch(heightChanged(height));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelChanged(result.data, currentModel));
|
||||
};
|
||||
|
||||
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const currentRefinerModel = state.sdxl.refinerModel;
|
||||
const refinerModels = models.filter(isRefinerMainModelModelConfig);
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentRefinerModelAvailable = currentRefinerModel
|
||||
? refinerModels.some((m) => m.key === currentRefinerModel.key)
|
||||
: false;
|
||||
|
||||
if (!isCurrentRefinerModelAvailable) {
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const currentVae = state.generation.vae;
|
||||
|
||||
if (currentVae === null) {
|
||||
// null is a valid VAE! it means "use the default with the main model"
|
||||
return;
|
||||
}
|
||||
const vaeModels = models.filter(isVAEModelConfig);
|
||||
|
||||
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
|
||||
|
||||
if (isCurrentVAEAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = vaeModels[0];
|
||||
|
||||
if (!firstModel) {
|
||||
// No custom VAEs loaded at all; use the default
|
||||
dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zParameterVAEModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(result.data));
|
||||
};
|
||||
|
||||
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const loras = state.lora.loras;
|
||||
|
||||
forEach(loras, (lora, id) => {
|
||||
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
|
||||
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRemoved(id));
|
||||
});
|
||||
};
|
||||
|
||||
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
};
|
||||
|
@ -23,8 +23,7 @@ import {
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
@ -39,7 +38,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
return;
|
||||
}
|
||||
|
||||
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
||||
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
|
||||
const data = await request.unwrap();
|
||||
request.unsubscribe();
|
||||
const models = modelConfigsAdapterSelectors.selectAll(data);
|
||||
|
||||
const modelConfig = models.find((model) => model.key === currentModel.key);
|
||||
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
@ -55,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
if (vae === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
||||
const vaeArray = map(data?.entities);
|
||||
const validVae = vaeArray.find((model) => model.key === vae);
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
const vaeModel = models.find((model) => model.key === vae);
|
||||
const result = zParameterVAEModel.safeParse(vaeModel);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
|
@ -30,11 +30,10 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
|
||||
|
||||
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
|
||||
if ($isFirstConnection.get()) {
|
||||
// The TI models are used in a component that is not always rendered, so when users open the prompt triggers
|
||||
// box has a delay while it does the initial fetch. We need to both pre-fetch the data and maintain an RTK
|
||||
// Query subscription to it, so the cache doesn't clear itself when the user closes the prompt triggers box.
|
||||
// So, we explicitly do not unsubscribe from this query!
|
||||
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
|
||||
// Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
|
||||
// unsubscribe.
|
||||
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
|
||||
request.unsubscribe();
|
||||
|
||||
$isFirstConnection.set(false);
|
||||
return;
|
||||
|
@ -1,15 +1,14 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { groupBy, map, reduce } from 'lodash-es';
|
||||
import { groupBy, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
modelConfigs: T[];
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
): UseGroupedModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||
if (!modelEntities) {
|
||||
if (!modelConfigs) {
|
||||
return [];
|
||||
}
|
||||
const modelEntitiesArray = map(modelEntities.entities);
|
||||
const groupedModels = groupBy(modelEntitiesArray, 'base');
|
||||
const groupedModels = groupBy(modelConfigs, 'base');
|
||||
const _options = reduce(
|
||||
groupedModels,
|
||||
(acc, val, label) => {
|
||||
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
);
|
||||
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
||||
return _options;
|
||||
}, [getIsDisabled, modelEntities, base_model]);
|
||||
}, [getIsDisabled, modelConfigs, base_model]);
|
||||
|
||||
const value = useMemo(
|
||||
() =>
|
||||
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
const model = modelEntities?.entities[v.value];
|
||||
const model = modelConfigs.find((m) => m.key === v.value);
|
||||
if (!model) {
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[modelEntities?.entities, onChange]
|
||||
[modelConfigs, onChange]
|
||||
);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
|
@ -1,13 +1,11 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
modelConfigs: T[];
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
|
||||
|
||||
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||
const options = useMemo<ComboboxOption[]>(() => {
|
||||
if (!modelEntities) {
|
||||
return [];
|
||||
}
|
||||
return map(modelEntities.entities)
|
||||
.filter(optionsFilter)
|
||||
.map((model) => ({
|
||||
label: model.name,
|
||||
value: model.key,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
}));
|
||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
||||
return modelConfigs.filter(optionsFilter).map((model) => ({
|
||||
label: model.name,
|
||||
value: model.key,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
}));
|
||||
}, [optionsFilter, getIsDisabled, modelConfigs]);
|
||||
|
||||
const value = useMemo(
|
||||
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
||||
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
const model = modelEntities?.entities[v.value];
|
||||
const model = modelConfigs.find((m) => m.key === v.value);
|
||||
if (!model) {
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[modelEntities?.entities, onChange]
|
||||
[modelConfigs, onChange]
|
||||
);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
|
@ -1,15 +1,12 @@
|
||||
import type { Item } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { filter } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||
data: EntityState<T, string> | undefined;
|
||||
modelConfigs: T[];
|
||||
isLoading: boolean;
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
|
||||
const isModelDisabledDefault = () => false;
|
||||
|
||||
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||
data,
|
||||
modelConfigs,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange,
|
||||
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||
|
||||
const items: Item[] = useMemo(
|
||||
() =>
|
||||
data
|
||||
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
||||
label: m.name,
|
||||
value: m.key,
|
||||
description: m.description,
|
||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||
isDisabled: isModelDisabled(m),
|
||||
}))
|
||||
: EMPTY_ARRAY,
|
||||
[data, isModelDisabled, modelFilter]
|
||||
modelConfigs.filter(modelFilter).map<Item>((m) => ({
|
||||
label: m.name,
|
||||
value: m.key,
|
||||
description: m.description,
|
||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||
isDisabled: isModelDisabled(m),
|
||||
})),
|
||||
[modelConfigs, isModelDisabled, modelFilter]
|
||||
);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(item: Item | null) => {
|
||||
if (!item || !data) {
|
||||
if (!item || !modelConfigs) {
|
||||
return;
|
||||
}
|
||||
const model = data.entities[item.value];
|
||||
const model = modelConfigs.find((m) => m.key === item.value);
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[data, onChange]
|
||||
[modelConfigs, onChange]
|
||||
);
|
||||
|
||||
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||
|
@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||
@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
);
|
||||
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
data,
|
||||
modelConfigs,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange: _onChange,
|
||||
|
@ -1,17 +1,16 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
||||
|
||||
export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
const baseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const models = useControlAdapterModels(type);
|
||||
const [models] = useControlAdapterModels(type);
|
||||
|
||||
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
||||
// prefer to use a model that matches the base model
|
||||
|
@ -1,26 +0,0 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
|
||||
const controlNetModelsQuery = useGetControlNetModelsQuery();
|
||||
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
|
||||
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
|
||||
|
||||
if (type === 'controlnet') {
|
||||
return controlNetModelsQuery;
|
||||
}
|
||||
if (type === 't2i_adapter') {
|
||||
return t2iAdapterModelsQuery;
|
||||
}
|
||||
if (type === 'ip_adapter') {
|
||||
return ipAdapterModelsQuery;
|
||||
}
|
||||
|
||||
// Assert that the end of the function is not reachable.
|
||||
const exhaustiveCheck: never = type;
|
||||
return exhaustiveCheck;
|
||||
};
|
@ -1,31 +1,10 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
controlNetModelsAdapterSelectors,
|
||||
ipAdapterModelsAdapterSelectors,
|
||||
t2iAdapterModelsAdapterSelectors,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
|
||||
export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
||||
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
|
||||
const controlNetModels = useMemo(
|
||||
() => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []),
|
||||
[controlNetModelsData]
|
||||
);
|
||||
|
||||
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
|
||||
const t2iAdapterModels = useMemo(
|
||||
() => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
|
||||
[t2iAdapterModelsData]
|
||||
);
|
||||
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
|
||||
const ipAdapterModels = useMemo(
|
||||
() => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
|
||||
[ipAdapterModelsData]
|
||||
);
|
||||
export const useControlAdapterModels = (type: ControlAdapterType) => {
|
||||
const controlNetModels = useControlNetModels();
|
||||
const t2iAdapterModels = useT2IAdapterModels();
|
||||
const ipAdapterModels = useIPAdapterModels();
|
||||
|
||||
if (type === 'controlnet') {
|
||||
return controlNetModels;
|
||||
@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
||||
if (type === 'ip_adapter') {
|
||||
return ipAdapterModels;
|
||||
}
|
||||
return [];
|
||||
|
||||
// Assert that the end of the function is not reachable.
|
||||
const exhaustiveCheck: never = type;
|
||||
return exhaustiveCheck;
|
||||
};
|
||||
|
@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||
|
||||
const LoRASelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const { t } = useTranslation();
|
||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
@ -37,7 +37,7 @@ const LoRASelect = () => {
|
||||
);
|
||||
|
||||
const { options, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
@ -1,13 +1,16 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
||||
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
selectedModelKey: string | null;
|
||||
selectedModelMode: 'edit' | 'view';
|
||||
searchTerm: string;
|
||||
filteredModelType: string | null;
|
||||
filteredModelType: FilterableModelType | null;
|
||||
scanPath: string | undefined;
|
||||
};
|
||||
|
||||
@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({
|
||||
state.searchTerm = action.payload;
|
||||
},
|
||||
|
||||
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
|
||||
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
|
||||
state.filteredModelType = action.payload;
|
||||
},
|
||||
setScanPath: (state, action: PayloadAction<string | undefined>) => {
|
||||
|
@ -1,122 +1,105 @@
|
||||
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import { memo, useMemo } from 'react';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
useControlNetModels,
|
||||
useEmbeddingModels,
|
||||
useIPAdapterModels,
|
||||
useLoRAModels,
|
||||
useMainModels,
|
||||
useT2IAdapterModels,
|
||||
useVAEModels,
|
||||
} from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
|
||||
const ModelList = () => {
|
||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||
|
||||
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingMainModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingLoraModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
|
||||
undefined,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingTextualInversionModels: isLoading,
|
||||
}),
|
||||
}
|
||||
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||
const filteredMainModels = useMemo(
|
||||
() => modelsFilter(mainModels, searchTerm, filteredModelType),
|
||||
[mainModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingControlnetModels: isLoading,
|
||||
}),
|
||||
});
|
||||
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||
const filteredLoRAModels = useMemo(
|
||||
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||
[loraModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingT2IAdapterModels: isLoading,
|
||||
}),
|
||||
});
|
||||
const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels();
|
||||
const filteredEmbeddingModels = useMemo(
|
||||
() => modelsFilter(embeddingModels, searchTerm, filteredModelType),
|
||||
[embeddingModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingIpAdapterModels: isLoading,
|
||||
}),
|
||||
});
|
||||
const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels();
|
||||
const filteredControlNetModels = useMemo(
|
||||
() => modelsFilter(controlNetModels, searchTerm, filteredModelType),
|
||||
[controlNetModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingVaeModels: isLoading,
|
||||
}),
|
||||
});
|
||||
const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels();
|
||||
const filteredT2IAdapterModels = useMemo(
|
||||
() => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType),
|
||||
[t2iAdapterModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels();
|
||||
const filteredIPAdapterModels = useMemo(
|
||||
() => modelsFilter(ipAdapterModels, searchTerm, filteredModelType),
|
||||
[ipAdapterModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
|
||||
const filteredVAEModels = useMemo(
|
||||
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
|
||||
[vaeModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||
{/* Main Model List */}
|
||||
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
|
||||
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
|
||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
||||
)}
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
||||
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
|
||||
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||
<ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
|
||||
)}
|
||||
|
||||
{/* TI List */}
|
||||
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
|
||||
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Textual Inversions"
|
||||
modelList={filteredTextualInversionModels}
|
||||
key="textual-inversions"
|
||||
/>
|
||||
{isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
|
||||
{!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
|
||||
<ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
|
||||
)}
|
||||
|
||||
{/* VAE List */}
|
||||
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
||||
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
|
||||
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
|
||||
{isLoadingVAEModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
||||
{!isLoadingVAEModels && filteredVAEModels.length > 0 && (
|
||||
<ModelListWrapper title="VAE" modelList={filteredVAEModels} key="vae" />
|
||||
)}
|
||||
|
||||
{/* Controlnet List */}
|
||||
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
|
||||
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
|
||||
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
|
||||
{isLoadingControlNetModels && <FetchingModelsLoader loadingMessage="Loading ControlNets..." />}
|
||||
{!isLoadingControlNetModels && filteredControlNetModels.length > 0 && (
|
||||
<ModelListWrapper title="ControlNet" modelList={filteredControlNetModels} key="controlnets" />
|
||||
)}
|
||||
{/* IP Adapter List */}
|
||||
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
||||
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
|
||||
{isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
||||
{!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
|
||||
)}
|
||||
{/* T2I Adapters List */}
|
||||
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
||||
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
|
||||
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
@ -126,25 +109,16 @@ const ModelList = () => {
|
||||
export default memo(ModelList);
|
||||
|
||||
const modelsFilter = <T extends AnyModelConfig>(
|
||||
data: EntityState<T, string> | undefined,
|
||||
data: T[],
|
||||
nameFilter: string,
|
||||
filteredModelType: string | null
|
||||
filteredModelType: ModelType | null
|
||||
): T[] => {
|
||||
const filteredModels: T[] = [];
|
||||
|
||||
forEach(data?.entities, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
return data.filter((model) => {
|
||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||
|
||||
if (matchesFilter && matchesType) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
return matchesFilter && matchesType;
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
||||
|
||||
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
||||
|
@ -1,11 +1,13 @@
|
||||
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoFilter } from 'react-icons/io5';
|
||||
import { PiFunnelBold } from 'react-icons/pi';
|
||||
import { objectKeys } from 'tsafe';
|
||||
|
||||
const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
|
||||
main: 'Main',
|
||||
lora: 'LoRA',
|
||||
embedding: 'Textual Inversion',
|
||||
@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||
vae: 'VAE',
|
||||
t2i_adapter: 'T2I Adapter',
|
||||
ip_adapter: 'IP Adapter',
|
||||
clip_vision: 'Clip Vision',
|
||||
};
|
||||
|
||||
export const ModelTypeFilter = () => {
|
||||
@ -22,7 +23,7 @@ export const ModelTypeFilter = () => {
|
||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||
|
||||
const selectModelType = useCallback(
|
||||
(option: string) => {
|
||||
(option: FilterableModelType) => {
|
||||
dispatch(setFilteredModelType(option));
|
||||
},
|
||||
[dispatch]
|
||||
@ -34,12 +35,12 @@ export const ModelTypeFilter = () => {
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
|
||||
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
|
||||
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
||||
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
|
||||
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
|
||||
<MenuItem
|
||||
key={option}
|
||||
bg={filteredModelType === option ? 'base.700' : 'transparent'}
|
||||
|
@ -4,12 +4,12 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
|
||||
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
||||
|
||||
@ -21,18 +21,16 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const modelArray = map(data?.entities);
|
||||
const compatibleOptions = modelArray
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
const [vaeModels] = useVAEModels();
|
||||
const compatibleOptions = useMemo(() => {
|
||||
const compatibleOptions = vaeModels
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
|
||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
||||
},
|
||||
});
|
||||
return [defaultOption, ...compatibleOptions];
|
||||
}, [modelData?.base, vaeModels]);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -14,7 +14,7 @@ type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetMo
|
||||
const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetControlNetModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: ControlNetModelConfig | null) => {
|
||||
@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = (
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: IPAdapterModelConfig | null) => {
|
||||
@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = (
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: ipAdapterModels,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -14,7 +14,7 @@ type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInpu
|
||||
const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const _onChange = useCallback(
|
||||
(value: LoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
|
@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -16,7 +15,7 @@ type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInpu
|
||||
const MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
|
||||
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
|
@ -8,8 +8,7 @@ import type {
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -19,7 +18,7 @@ type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefiner
|
||||
const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
|
@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { SDXL_MAIN_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useSDXLModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -16,7 +15,7 @@ type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelF
|
||||
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
|
||||
const [modelConfigs, { isLoading }] = useSDXLModels();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: T2IAdapterModelConfig | null) => {
|
||||
@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: t2iAdapterModels,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
@ -15,7 +15,7 @@ type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputT
|
||||
const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
|
@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
@ -18,7 +17,7 @@ const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModel = useAppSelector(selectModel);
|
||||
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
||||
const [modelConfigs, { isLoading }] = useMainModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: MainModelConfig | null) => {
|
||||
@ -35,7 +34,7 @@ const ParamMainModelSelect = () => {
|
||||
);
|
||||
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
data,
|
||||
modelConfigs,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange: _onChange,
|
||||
|
@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { model, vae } = useAppSelector(selector);
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const getIsDisabled = useCallback(
|
||||
(vae: VAEModelConfig): boolean => {
|
||||
const isCompatible = model?.base === vae.base;
|
||||
@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => {
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: vae,
|
||||
isLoading,
|
||||
|
@ -11,13 +11,8 @@ import { t } from 'i18next';
|
||||
import { flatten, map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
loraModelsAdapterSelectors,
|
||||
textualInversionModelsAdapterSelectors,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetModelConfigQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
|
||||
mainModel?.key ?? skipToken
|
||||
);
|
||||
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
|
||||
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
|
||||
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
|
||||
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
const _options: GroupBase<ComboboxOption>[] = [];
|
||||
|
||||
if (tiModels) {
|
||||
const embeddingOptions = textualInversionModelsAdapterSelectors
|
||||
.selectAll(tiModels)
|
||||
const embeddingOptions = tiModels
|
||||
.filter((ti) => ti.base === mainModelConfig?.base)
|
||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||
|
||||
@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
}
|
||||
|
||||
if (loraModels) {
|
||||
const triggerPhraseOptions = loraModelsAdapterSelectors
|
||||
.selectAll(loraModels)
|
||||
const triggerPhraseOptions = loraModels
|
||||
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
||||
.map((lora) => {
|
||||
if (lora.trigger_phrases) {
|
||||
|
@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
|
||||
@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const model = useAppSelector(selectModel);
|
||||
const { t } = useTranslation();
|
||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||
const _onChange = useCallback(
|
||||
(model: MainModelConfig | null) => {
|
||||
if (!model) {
|
||||
@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
|
||||
modelEntities: data,
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: model,
|
||||
isLoading,
|
||||
|
@ -1,28 +1,11 @@
|
||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import queryString from 'query-string';
|
||||
import {
|
||||
ALL_BASE_MODELS,
|
||||
NON_REFINER_BASE_MODELS,
|
||||
NON_SDXL_MAIN_MODELS,
|
||||
REFINER_BASE_MODELS,
|
||||
SDXL_MAIN_MODELS,
|
||||
} from 'services/api/constants';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlNetModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
import type { ApiTagDescription } from '..';
|
||||
import { api, buildV2Url, LIST_TAG } from '..';
|
||||
|
||||
export type UpdateModelArg = {
|
||||
@ -40,8 +23,9 @@ type UpdateModelImageResponse =
|
||||
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
type GetModelConfigsResponse = NonNullable<
|
||||
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
|
||||
>;
|
||||
|
||||
type DeleteModelArg = {
|
||||
key: string;
|
||||
@ -76,72 +60,11 @@ type GetHuggingFaceModelsResponse =
|
||||
|
||||
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
||||
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
const buildProvidesTags =
|
||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
};
|
||||
|
||||
const buildTransformResponse =
|
||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the models router
|
||||
@ -162,9 +85,27 @@ export const modelsApi = api.injectEndpoints({
|
||||
};
|
||||
},
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertSingleModelConfig(data, dispatch);
|
||||
});
|
||||
try {
|
||||
const { data } = await queryFulfilled;
|
||||
|
||||
// Update the individual model query caches
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
|
||||
|
||||
const { base, name, type } = data;
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
|
||||
|
||||
// Update the list query cache
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
|
||||
modelConfigsAdapter.updateOne(draft, {
|
||||
id: data.key,
|
||||
changes: data,
|
||||
});
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
}),
|
||||
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
||||
@ -294,80 +235,27 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['ModelInstalls'],
|
||||
}),
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params: ListModelsArg = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return buildModelsUrl(`?${query}`);
|
||||
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl() }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
|
||||
if (result) {
|
||||
const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const);
|
||||
tags.push(...modelTags);
|
||||
}
|
||||
return tags;
|
||||
},
|
||||
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
|
||||
transformResponse: (response: GetModelConfigsResponse) => {
|
||||
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
||||
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
||||
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
||||
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => {
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
||||
const { base, name, type } = modelConfig;
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
||||
});
|
||||
});
|
||||
},
|
||||
}),
|
||||
@ -375,14 +263,8 @@ export const modelsApi = api.injectEndpoints({
|
||||
});
|
||||
|
||||
export const {
|
||||
useGetModelConfigsQuery,
|
||||
useGetModelConfigQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useDeleteModelsMutation,
|
||||
useDeleteModelImageMutation,
|
||||
useUpdateModelMutation,
|
||||
@ -396,127 +278,3 @@ export const {
|
||||
useCancelModelInstallMutation,
|
||||
usePruneCompletedModelInstallsMutation,
|
||||
} = modelsApi;
|
||||
|
||||
const upsertModelConfigs = (
|
||||
modelConfigs: EntityState<AnyModelConfig, string>,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
dispatch: ThunkDispatch<any, any, UnknownAction>
|
||||
) => {
|
||||
/**
|
||||
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
|
||||
* network request. This function takes the received list of models and upserts them into the individual query caches
|
||||
* for each model type.
|
||||
*/
|
||||
|
||||
// Iterate over all the models and upsert them into the individual query caches for each model type.
|
||||
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
||||
const { base, name, type } = modelConfig;
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
||||
});
|
||||
};
|
||||
|
||||
const upsertSingleModelConfig = (
|
||||
modelConfig: AnyModelConfig,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
dispatch: ThunkDispatch<any, any, UnknownAction>
|
||||
) => {
|
||||
/**
|
||||
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
|
||||
* query caches of models of that type.
|
||||
*/
|
||||
|
||||
// Update the individual model query caches.
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
||||
const { base, name, type } = modelConfig;
|
||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
||||
|
||||
// Update the list query caches for each model type.
|
||||
if (modelConfig.type === 'main') {
|
||||
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
|
||||
(queryArg) => {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
|
||||
mainModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
}
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelConfig.type === 'controlnet') {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
|
||||
controlNetModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelConfig.type === 'embedding') {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
|
||||
textualInversionModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelConfig.type === 'ip_adapter') {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
|
||||
ipAdapterModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelConfig.type === 'lora') {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
|
||||
loraModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelConfig.type === 't2i_adapter') {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
|
||||
t2iAdapterModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelConfig.type === 'vae') {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
|
||||
vaeModelsAdapter.updateOne(draft, {
|
||||
id: modelConfig.key,
|
||||
changes: modelConfig,
|
||||
});
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
42
invokeai/frontend/web/src/services/api/hooks/modelsByType.ts
Normal file
42
invokeai/frontend/web/src/services/api/hooks/modelsByType.ts
Normal file
@ -0,0 +1,42 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useMemo } from 'react';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isNonSDXLMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isTIModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const buildModelsHook =
|
||||
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
|
||||
() => {
|
||||
const result = useGetModelConfigsQuery(undefined);
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!result.data) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
|
||||
}, [result]);
|
||||
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
@ -1,12 +1,7 @@
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
const [refinerModels] = useRefinerModels();
|
||||
|
||||
return isRefinerAvailable;
|
||||
return Boolean(refinerModels.length);
|
||||
};
|
||||
|
@ -48,7 +48,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
@ -103,6 +103,18 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
|
||||
return config.type === 'main' && config.base === 'sdxl-refiner';
|
||||
};
|
||||
|
||||
export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'sdxl';
|
||||
};
|
||||
|
||||
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
||||
};
|
||||
|
||||
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'embedding';
|
||||
};
|
||||
|
||||
export type ModelInstallJob = S['ModelInstallJob'];
|
||||
export type ModelInstallStatus = S['InstallStatus'];
|
||||
|
||||
@ -200,10 +212,3 @@ export type PostUploadAction =
|
||||
| CanvasInitialImageAction
|
||||
| ToastAction
|
||||
| AddToBatchAction;
|
||||
|
||||
type TypeGuard<T> = {
|
||||
(input: unknown): input is T;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export type TypeGuardFor<T extends TypeGuard<any>> = T extends TypeGuard<infer U> ? U : never;
|
||||
|
Loading…
Reference in New Issue
Block a user