feat(ui): single getModelConfigs query

Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend.
This commit is contained in:
psychedelicious 2024-03-14 23:37:40 +11:00
parent ed20255abf
commit 19d66d5ec7
31 changed files with 447 additions and 790 deletions

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { JSONObject } from 'common/types';
import {
controlAdapterModelCleared,
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
@ -12,212 +12,162 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
import type { TypeGuardFor } from 'services/api/types';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
const state = getState();
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
},
});
startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (!isCurrentModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
},
});
startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// VAEs loaded, need to reset the VAE is it's no longer available
const log = logger('models');
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
const currentVae = getState().generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// LoRA models loaded - need to remove missing LoRAs from state
const log = logger('models');
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
effect: async (action) => {
const log = logger('models');
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
handleMainModels(models, state, dispatch, log);
handleRefinerModels(models, state, dispatch, log);
handleVAEModels(models, state, dispatch, log);
handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log);
},
});
};
type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<JSONObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const currentModel = state.generation.model;
const mainModels = models.filter(isNonRefinerMainModelConfig);
if (mainModels.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentMainModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
const currentRefinerModel = state.sdxl.refinerModel;
const refinerModels = models.filter(isRefinerMainModelModelConfig);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentRefinerModelAvailable = currentRefinerModel
? refinerModels.some((m) => m.key === currentRefinerModel.key)
: false;
if (!isCurrentRefinerModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
};
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const currentVae = state.generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const vaeModels = models.filter(isVAEModelConfig);
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModels[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
};

View File

@ -23,8 +23,7 @@ import {
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
@ -39,7 +38,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return;
}
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
const data = await request.unwrap();
request.unsubscribe();
const models = modelConfigsAdapterSelectors.selectAll(data);
const modelConfig = models.find((model) => model.key === currentModel.key);
if (!modelConfig) {
return;
@ -55,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
const vaeModel = models.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(vaeModel);
if (!result.success) {
return;
}

View File

@ -30,11 +30,10 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
if ($isFirstConnection.get()) {
// The TI models are used in a component that is not always rendered, so when users open the prompt triggers
// box has a delay while it does the initial fetch. We need to both pre-fetch the data and maintain an RTK
// Query subscription to it, so the cache doesn't clear itself when the user closes the prompt triggers box.
// So, we explicitly do not unsubscribe from this query!
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
// Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
// unsubscribe.
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
request.unsubscribe();
$isFirstConnection.set(false);
return;

View File

@ -1,15 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelEntities) {
if (!modelConfigs) {
return [];
}
const modelEntitiesArray = map(modelEntities.entities);
const groupedModels = groupBy(modelEntitiesArray, 'base');
const groupedModels = groupBy(modelConfigs, 'base');
const _options = reduce(
groupedModels,
(acc, val, label) => {
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
);
_options.sort((a) => (a.label === base_model ? -1 : 1));
return _options;
}, [getIsDisabled, modelEntities, base_model]);
}, [getIsDisabled, modelConfigs, base_model]);
const value = useMemo(
() =>
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
const model = modelConfigs.find((m) => m.key === v.value);
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
[modelConfigs, onChange]
);
const placeholder = useMemo(() => {

View File

@ -1,13 +1,11 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation();
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const options = useMemo<ComboboxOption[]>(() => {
if (!modelEntities) {
return [];
}
return map(modelEntities.entities)
.filter(optionsFilter)
.map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
return modelConfigs.filter(optionsFilter).map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelConfigs]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
const model = modelConfigs.find((m) => m.key === v.value);
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
[modelConfigs, onChange]
);
const placeholder = useMemo(() => {

View File

@ -1,15 +1,12 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
modelConfigs: T[];
isLoading: boolean;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false;
export const useModelCustomSelect = <T extends AnyModelConfig>({
data,
modelConfigs,
isLoading,
selectedModel,
onChange,
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
const items: Item[] = useMemo(
() =>
data
? filter(data.entities, modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
}))
: EMPTY_ARRAY,
[data, isModelDisabled, modelFilter]
modelConfigs.filter(modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
})),
[modelConfigs, isModelDisabled, modelFilter]
);
const _onChange = useCallback(
(item: Item | null) => {
if (!item || !data) {
if (!item || !modelConfigs) {
return;
}
const model = data.entities[item.value];
const model = modelConfigs.find((m) => m.key === item.value);
if (!model) {
return;
}
onChange(model);
},
[data, onChange]
[modelConfigs, onChange]
);
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);

View File

@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback, useMemo } from 'react';
@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
);
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
modelConfigs,
isLoading,
selectedModel,
onChange: _onChange,

View File

@ -1,17 +1,16 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlAdapterModels } from './useControlAdapterModels';
export const useAddControlAdapter = (type: ControlAdapterType) => {
const baseModel = useAppSelector((s) => s.generation.model?.base);
const dispatch = useAppDispatch();
const models = useControlAdapterModels(type);
const [models] = useControlAdapterModels(type);
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
// prefer to use a model that matches the base model

View File

@ -1,26 +0,0 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
const controlNetModelsQuery = useGetControlNetModelsQuery();
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsQuery;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsQuery;
}
if (type === 'ip_adapter') {
return ipAdapterModelsQuery;
}
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};

View File

@ -1,31 +1,10 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { useMemo } from 'react';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';
export const useControlAdapterModels = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const controlNetModels = useMemo(
() => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []),
[controlNetModelsData]
);
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const t2iAdapterModels = useMemo(
() => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
[t2iAdapterModelsData]
);
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
const ipAdapterModels = useMemo(
() => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
[ipAdapterModelsData]
);
export const useControlAdapterModels = (type: ControlAdapterType) => {
const controlNetModels = useControlNetModels();
const t2iAdapterModels = useT2IAdapterModels();
const ipAdapterModels = useIPAdapterModels();
if (type === 'controlnet') {
return controlNetModels;
@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
if (type === 'ip_adapter') {
return ipAdapterModels;
}
return [];
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};

View File

@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
const LoRASelect = () => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const [modelConfigs, { isLoading }] = useLoRAModels();
const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
@ -37,7 +37,7 @@ const LoRASelect = () => {
);
const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
getIsDisabled,
onChange: _onChange,
});

View File

@ -1,13 +1,16 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import type { ModelType } from 'services/api/types';
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: string | null;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
};
@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
state.filteredModelType = action.payload;
},
setScanPath: (state, action: PayloadAction<string | undefined>) => {

View File

@ -1,122 +1,105 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { forEach } from 'lodash-es';
import { memo } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { memo, useMemo } from 'react';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
useControlNetModels,
useEmbeddingModels,
useIPAdapterModels,
useLoRAModels,
useMainModels,
useT2IAdapterModels,
useVAEModels,
} from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ModelType } from 'services/api/types';
import { ModelListWrapper } from './ModelListWrapper';
const ModelList = () => {
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingMainModels: isLoading,
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingLoraModels: isLoading,
}),
});
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingTextualInversionModels: isLoading,
}),
}
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
const filteredMainModels = useMemo(
() => modelsFilter(mainModels, searchTerm, filteredModelType),
[mainModels, searchTerm, filteredModelType]
);
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingControlnetModels: isLoading,
}),
});
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
const filteredLoRAModels = useMemo(
() => modelsFilter(loraModels, searchTerm, filteredModelType),
[loraModels, searchTerm, filteredModelType]
);
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingT2IAdapterModels: isLoading,
}),
});
const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels();
const filteredEmbeddingModels = useMemo(
() => modelsFilter(embeddingModels, searchTerm, filteredModelType),
[embeddingModels, searchTerm, filteredModelType]
);
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingIpAdapterModels: isLoading,
}),
});
const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels();
const filteredControlNetModels = useMemo(
() => modelsFilter(controlNetModels, searchTerm, filteredModelType),
[controlNetModels, searchTerm, filteredModelType]
);
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingVaeModels: isLoading,
}),
});
const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels();
const filteredT2IAdapterModels = useMemo(
() => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType),
[t2iAdapterModels, searchTerm, filteredModelType]
);
const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels();
const filteredIPAdapterModels = useMemo(
() => modelsFilter(ipAdapterModels, searchTerm, filteredModelType),
[ipAdapterModels, searchTerm, filteredModelType]
);
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
const filteredVAEModels = useMemo(
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
[vaeModels, searchTerm, filteredModelType]
);
return (
<ScrollableContent>
<Flex flexDirection="column" w="full" h="full" gap={4}>
{/* Main Model List */}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
{!isLoadingMainModels && filteredMainModels.length > 0 && (
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
)}
{/* LoRAs List */}
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
<ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
)}
{/* TI List */}
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
<ModelListWrapper
title="Textual Inversions"
modelList={filteredTextualInversionModels}
key="textual-inversions"
/>
{isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
<ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
)}
{/* VAE List */}
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
{isLoadingVAEModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
{!isLoadingVAEModels && filteredVAEModels.length > 0 && (
<ModelListWrapper title="VAE" modelList={filteredVAEModels} key="vae" />
)}
{/* Controlnet List */}
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
{isLoadingControlNetModels && <FetchingModelsLoader loadingMessage="Loading ControlNets..." />}
{!isLoadingControlNetModels && filteredControlNetModels.length > 0 && (
<ModelListWrapper title="ControlNet" modelList={filteredControlNetModels} key="controlnets" />
)}
{/* IP Adapter List */}
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
{isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
{!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
<ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
)}
{/* T2I Adapters List */}
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)}
</Flex>
</ScrollableContent>
@ -126,25 +109,16 @@ const ModelList = () => {
export default memo(ModelList);
const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined,
data: T[],
nameFilter: string,
filteredModelType: string | null
filteredModelType: ModelType | null
): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
return data.filter((model) => {
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesType = filteredModelType ? model.type === filteredModelType : true;
if (matchesFilter && matchesType) {
filteredModels.push(model);
}
return matchesFilter && matchesType;
});
return filteredModels;
};
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {

View File

@ -1,11 +1,13 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';
import { PiFunnelBold } from 'react-icons/pi';
import { objectKeys } from 'tsafe';
const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',
@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = {
vae: 'VAE',
t2i_adapter: 'T2I Adapter',
ip_adapter: 'IP Adapter',
clip_vision: 'Clip Vision',
};
export const ModelTypeFilter = () => {
@ -22,7 +23,7 @@ export const ModelTypeFilter = () => {
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
const selectModelType = useCallback(
(option: string) => {
(option: FilterableModelType) => {
dispatch(setFilteredModelType(option));
},
[dispatch]
@ -34,12 +35,12 @@ export const ModelTypeFilter = () => {
return (
<Menu>
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
</MenuButton>
<MenuList>
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
<MenuItem
key={option}
bg={filteredModelType === option ? 'base.700' : 'transparent'}

View File

@ -4,12 +4,12 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
@ -21,18 +21,16 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data }) => {
const modelArray = map(data?.entities);
const compatibleOptions = modelArray
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const [vaeModels] = useVAEModels();
const compatibleOptions = useMemo(() => {
const compatibleOptions = vaeModels
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const defaultOption = { label: 'Default VAE', value: 'default' };
const defaultOption = { label: 'Default VAE', value: 'default' };
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
},
});
return [defaultOption, ...compatibleOptions];
}, [modelData?.base, vaeModels]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetMo
const ControlNetModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetControlNetModelsQuery();
const [modelConfigs, { isLoading }] = useControlNetModels();
const _onChange = useCallback(
(value: ControlNetModelConfig | null) => {
@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = (
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const _onChange = useCallback(
(value: IPAdapterModelConfig | null) => {
@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = (
);
const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: ipAdapterModels,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInpu
const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const [modelConfigs, { isLoading }] = useLoRAModels();
const _onChange = useCallback(
(value: LoRAModelConfig | null) => {
if (!value) {
@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,

View File

@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -16,7 +15,7 @@ type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInpu
const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,

View File

@ -8,8 +8,7 @@ import type {
SDXLRefinerModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -19,7 +18,7 @@ type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefiner
const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,

View File

@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { SDXL_MAIN_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useSDXLModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -16,7 +15,7 @@ type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelF
const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
const [modelConfigs, { isLoading }] = useSDXLModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = (
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const _onChange = useCallback(
(value: T2IAdapterModelConfig | null) => {
@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = (
);
const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: t2iAdapterModels,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (

View File

@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -15,7 +15,7 @@ type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputT
const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,

View File

@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@ -18,7 +17,7 @@ const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModel = useAppSelector(selectModel);
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
const [modelConfigs, { isLoading }] = useMainModels();
const _onChange = useCallback(
(model: MainModelConfig | null) => {
@ -35,7 +34,7 @@ const ParamMainModelSelect = () => {
);
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
modelConfigs,
isLoading,
selectedModel,
onChange: _onChange,

View File

@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery();
const [modelConfigs, { isLoading }] = useVAEModels();
const getIsDisabled = useCallback(
(vae: VAEModelConfig): boolean => {
const isCompatible = model?.base === vae.base;
@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => {
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: vae,
isLoading,

View File

@ -11,13 +11,8 @@ import { t } from 'i18next';
import { flatten, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
loraModelsAdapterSelectors,
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
useGetModelConfigQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
import { isNonRefinerMainModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
mainModel?.key ?? skipToken
);
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const _options: GroupBase<ComboboxOption>[] = [];
if (tiModels) {
const embeddingOptions = textualInversionModelsAdapterSelectors
.selectAll(tiModels)
const embeddingOptions = tiModels
.filter((ti) => ti.base === mainModelConfig?.base)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
}
if (loraModels) {
const triggerPhraseOptions = loraModelsAdapterSelectors
.selectAll(loraModels)
const triggerPhraseOptions = loraModels
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
.map((lora) => {
if (lora.trigger_phrases) {

View File

@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
const model = useAppSelector(selectModel);
const { t } = useTranslation();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback(
(model: MainModelConfig | null) => {
if (!model) {
@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => {
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: model,
isLoading,

View File

@ -1,28 +1,11 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import type { EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string';
import {
ALL_BASE_MODELS,
NON_REFINER_BASE_MODELS,
NON_SDXL_MAIN_MODELS,
REFINER_BASE_MODELS,
SDXL_MAIN_MODELS,
} from 'services/api/constants';
import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
ControlNetModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';
import type { ApiTagDescription, tagTypes } from '..';
import type { ApiTagDescription } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
export type UpdateModelArg = {
@ -40,8 +23,9 @@ type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type GetModelConfigsResponse = NonNullable<
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
>;
type DeleteModelArg = {
key: string;
@ -76,72 +60,11 @@ type GetHuggingFaceModelsResponse =
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
/**
* Builds an endpoint URL for the models router
@ -162,9 +85,27 @@ export const modelsApi = api.injectEndpoints({
};
},
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertSingleModelConfig(data, dispatch);
});
try {
const { data } = await queryFulfilled;
// Update the individual model query caches
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
const { base, name, type } = data;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
// Update the list query cache
dispatch(
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
modelConfigsAdapter.updateOne(draft, {
id: data.key,
changes: data,
});
})
);
} catch {
// no-op
}
},
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
@ -294,80 +235,27 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['ModelInstalls'],
}),
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl() }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
if (result) {
const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const);
tags.push(...modelTags);
}
return tags;
},
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
transformResponse: (response: GetModelConfigsResponse) => {
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
});
},
}),
@ -375,14 +263,8 @@ export const modelsApi = api.injectEndpoints({
});
export const {
useGetModelConfigsQuery,
useGetModelConfigQuery,
useGetMainModelsQuery,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useDeleteModelImageMutation,
useUpdateModelMutation,
@ -396,127 +278,3 @@ export const {
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,
} = modelsApi;
const upsertModelConfigs = (
modelConfigs: EntityState<AnyModelConfig, string>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
* network request. This function takes the received list of models and upserts them into the individual query caches
* for each model type.
*/
// Iterate over all the models and upsert them into the individual query caches for each model type.
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
};
const upsertSingleModelConfig = (
modelConfig: AnyModelConfig,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
* query caches of models of that type.
*/
// Update the individual model query caches.
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
// Update the list query caches for each model type.
if (modelConfig.type === 'main') {
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
(queryArg) => {
dispatch(
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
mainModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
}
);
return;
}
if (modelConfig.type === 'controlnet') {
dispatch(
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
controlNetModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'embedding') {
dispatch(
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
textualInversionModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'ip_adapter') {
dispatch(
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
ipAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'lora') {
dispatch(
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
loraModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 't2i_adapter') {
dispatch(
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
t2iAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'vae') {
dispatch(
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
vaeModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
};

View File

@ -0,0 +1,42 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSDXLMainModelModelConfig,
isT2IAdapterModelConfig,
isTIModelConfig,
isVAEModelConfig,
} from 'services/api/types';
const buildModelsHook =
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
}, [result]);
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);

View File

@ -1,12 +1,7 @@
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
const [refinerModels] = useRefinerModels();
return isRefinerAvailable;
return Boolean(refinerModels.length);
};

View File

@ -48,7 +48,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig'];
type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
@ -103,6 +103,18 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
return config.type === 'main' && config.base === 'sdxl-refiner';
};
export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sdxl';
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
};
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'embedding';
};
export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S['InstallStatus'];
@ -200,10 +212,3 @@ export type PostUploadAction =
| CanvasInitialImageAction
| ToastAction
| AddToBatchAction;
type TypeGuard<T> = {
(input: unknown): input is T;
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type TypeGuardFor<T extends TypeGuard<any>> = T extends TypeGuard<infer U> ? U : never;