feat(ui): single getModelConfigs query (#5962)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [z] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Single query, with simple wrapper hooks (type-safe). Updated everywhere
in frontend.

## QA Instructions, Screenshots, Recordings

Things that use models should work. All of this code is strictly
typechecked, so we can be confident in this change.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
This commit is contained in:
blessedcoolant 2024-03-14 18:20:38 +05:30 committed by GitHub
commit b07b7af710
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 447 additions and 790 deletions

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { JSONObject } from 'common/types';
import { import {
controlAdapterModelCleared, controlAdapterModelCleared,
selectAllControlNets, selectControlAdapterAll,
selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice'; import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
@ -12,212 +12,162 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models'; import type { Logger } from 'roarr';
import type { TypeGuardFor } from 'services/api/types'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
export const addModelsLoadedListener = (startAppListening: AppStartListening) => { export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> => predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one // models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models'); const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`); log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
const state = getState(); const state = getState();
const currentModel = state.generation.model; const models = modelConfigsAdapterSelectors.selectAll(action.payload);
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) { handleMainModels(models, state, dispatch, log);
// No models loaded at all handleRefinerModels(models, state, dispatch, log);
dispatch(modelChanged(null)); handleVAEModels(models, state, dispatch, log);
return; handleLoRAModels(models, state, dispatch, log);
} handleControlAdapterModels(models, state, dispatch, log);
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
},
});
startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (!isCurrentModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
},
});
startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// VAEs loaded, need to reset the VAE is it's no longer available
const log = logger('models');
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
const currentVae = getState().generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// LoRA models loaded - need to remove missing LoRAs from state
const log = logger('models');
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
effect: async (action) => {
const log = logger('models');
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
}, },
}); });
}; };
type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<JSONObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const currentModel = state.generation.model;
const mainModels = models.filter(isNonRefinerMainModelConfig);
if (mainModels.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentMainModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
const currentRefinerModel = state.sdxl.refinerModel;
const refinerModels = models.filter(isRefinerMainModelModelConfig);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentRefinerModelAvailable = currentRefinerModel
? refinerModels.some((m) => m.key === currentRefinerModel.key)
: false;
if (!isCurrentRefinerModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
};
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const currentVae = state.generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const vaeModels = models.filter(isVAEModelConfig);
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModels[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
};

View File

@ -23,8 +23,7 @@ import {
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { map } from 'lodash-es'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => { export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
@ -39,7 +38,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return; return;
} }
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap(); const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
const data = await request.unwrap();
request.unsubscribe();
const models = modelConfigsAdapterSelectors.selectAll(data);
const modelConfig = models.find((model) => model.key === currentModel.key);
if (!modelConfig) { if (!modelConfig) {
return; return;
@ -55,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
if (vae === 'default') { if (vae === 'default') {
dispatch(vaeSelected(null)); dispatch(vaeSelected(null));
} else { } else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state); const vaeModel = models.find((model) => model.key === vae);
const vaeArray = map(data?.entities); const result = zParameterVAEModel.safeParse(vaeModel);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) { if (!result.success) {
return; return;
} }

View File

@ -30,11 +30,10 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
// Bail on the recovery logic if this is the first connection - we don't need to recover anything // Bail on the recovery logic if this is the first connection - we don't need to recover anything
if ($isFirstConnection.get()) { if ($isFirstConnection.get()) {
// The TI models are used in a component that is not always rendered, so when users open the prompt triggers // Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
// box has a delay while it does the initial fetch. We need to both pre-fetch the data and maintain an RTK // unsubscribe.
// Query subscription to it, so the cache doesn't clear itself when the user closes the prompt triggers box. const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
// So, we explicitly do not unsubscribe from this query! request.unsubscribe();
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
$isFirstConnection.set(false); $isFirstConnection.set(false);
return; return;

View File

@ -1,15 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select'; import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es'; import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = { type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined; modelConfigs: T[];
selectedModel?: ModelIdentifierField | null; selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => { ): UseGroupedModelComboboxReturn => {
const { t } = useTranslation(); const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg; const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => { const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelEntities) { if (!modelConfigs) {
return []; return [];
} }
const modelEntitiesArray = map(modelEntities.entities); const groupedModels = groupBy(modelConfigs, 'base');
const groupedModels = groupBy(modelEntitiesArray, 'base');
const _options = reduce( const _options = reduce(
groupedModels, groupedModels,
(acc, val, label) => { (acc, val, label) => {
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
); );
_options.sort((a) => (a.label === base_model ? -1 : 1)); _options.sort((a) => (a.label === base_model ? -1 : 1));
return _options; return _options;
}, [getIsDisabled, modelEntities, base_model]); }, [getIsDisabled, modelConfigs, base_model]);
const value = useMemo( const value = useMemo(
() => () =>
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
onChange(null); onChange(null);
return; return;
} }
const model = modelEntities?.entities[v.value]; const model = modelConfigs.find((m) => m.key === v.value);
if (!model) { if (!model) {
onChange(null); onChange(null);
return; return;
} }
onChange(model); onChange(model);
}, },
[modelEntities?.entities, onChange] [modelConfigs, onChange]
); );
const placeholder = useMemo(() => { const placeholder = useMemo(() => {

View File

@ -1,13 +1,11 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = { type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined; modelConfigs: T[];
selectedModel?: ModelIdentifierField | null; selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => { export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation(); const { t } = useTranslation();
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg; const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const options = useMemo<ComboboxOption[]>(() => { const options = useMemo<ComboboxOption[]>(() => {
if (!modelEntities) { return modelConfigs.filter(optionsFilter).map((model) => ({
return []; label: model.name,
} value: model.key,
return map(modelEntities.entities) isDisabled: getIsDisabled ? getIsDisabled(model) : false,
.filter(optionsFilter) }));
.map((model) => ({ }, [optionsFilter, getIsDisabled, modelConfigs]);
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
const value = useMemo( const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)), () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
onChange(null); onChange(null);
return; return;
} }
const model = modelEntities?.entities[v.value]; const model = modelConfigs.find((m) => m.key === v.value);
if (!model) { if (!model) {
onChange(null); onChange(null);
return; return;
} }
onChange(model); onChange(model);
}, },
[modelEntities?.entities, onChange] [modelConfigs, onChange]
); );
const placeholder = useMemo(() => { const placeholder = useMemo(() => {

View File

@ -1,15 +1,12 @@
import type { Item } from '@invoke-ai/ui-library'; import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = { type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined; modelConfigs: T[];
isLoading: boolean; isLoading: boolean;
selectedModel?: ModelIdentifierField | null; selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false; const isModelDisabledDefault = () => false;
export const useModelCustomSelect = <T extends AnyModelConfig>({ export const useModelCustomSelect = <T extends AnyModelConfig>({
data, modelConfigs,
isLoading, isLoading,
selectedModel, selectedModel,
onChange, onChange,
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
const items: Item[] = useMemo( const items: Item[] = useMemo(
() => () =>
data modelConfigs.filter(modelFilter).map<Item>((m) => ({
? filter(data.entities, modelFilter).map<Item>((m) => ({ label: m.name,
label: m.name, value: m.key,
value: m.key, description: m.description,
description: m.description, group: MODEL_TYPE_SHORT_MAP[m.base],
group: MODEL_TYPE_SHORT_MAP[m.base], isDisabled: isModelDisabled(m),
isDisabled: isModelDisabled(m), })),
})) [modelConfigs, isModelDisabled, modelFilter]
: EMPTY_ARRAY,
[data, isModelDisabled, modelFilter]
); );
const _onChange = useCallback( const _onChange = useCallback(
(item: Item | null) => { (item: Item | null) => {
if (!item || !data) { if (!item || !modelConfigs) {
return; return;
} }
const model = data.entities[item.value]; const model = modelConfigs.find((m) => m.key === item.value);
if (!model) { if (!model) {
return; return;
} }
onChange(model); onChange(model);
}, },
[data, onChange] [modelConfigs, onChange]
); );
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]); const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);

View File

@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
const _onChange = useCallback( const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { (modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
); );
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data, modelConfigs,
isLoading, isLoading,
selectedModel, selectedModel,
onChange: _onChange, onChange: _onChange,

View File

@ -1,17 +1,16 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types'; import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlAdapterModels } from './useControlAdapterModels';
export const useAddControlAdapter = (type: ControlAdapterType) => { export const useAddControlAdapter = (type: ControlAdapterType) => {
const baseModel = useAppSelector((s) => s.generation.model?.base); const baseModel = useAppSelector((s) => s.generation.model?.base);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const models = useControlAdapterModels(type); const [models] = useControlAdapterModels(type);
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => { const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
// prefer to use a model that matches the base model // prefer to use a model that matches the base model

View File

@ -1,26 +0,0 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
const controlNetModelsQuery = useGetControlNetModelsQuery();
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsQuery;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsQuery;
}
if (type === 'ip_adapter') {
return ipAdapterModelsQuery;
}
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};

View File

@ -1,31 +1,10 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types'; import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { useMemo } from 'react'; import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModels = (type?: ControlAdapterType) => { export const useControlAdapterModels = (type: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery(); const controlNetModels = useControlNetModels();
const controlNetModels = useMemo( const t2iAdapterModels = useT2IAdapterModels();
() => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []), const ipAdapterModels = useIPAdapterModels();
[controlNetModelsData]
);
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const t2iAdapterModels = useMemo(
() => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
[t2iAdapterModelsData]
);
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
const ipAdapterModels = useMemo(
() => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
[ipAdapterModelsData]
);
if (type === 'controlnet') { if (type === 'controlnet') {
return controlNetModels; return controlNetModels;
@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
if (type === 'ip_adapter') { if (type === 'ip_adapter') {
return ipAdapterModels; return ipAdapterModels;
} }
return [];
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
}; };

View File

@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
const LoRASelect = () => { const LoRASelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery(); const [modelConfigs, { isLoading }] = useLoRAModels();
const { t } = useTranslation(); const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectAddedLoRAs); const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
@ -37,7 +37,7 @@ const LoRASelect = () => {
); );
const { options, onChange } = useGroupedModelCombobox({ const { options, onChange } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
getIsDisabled, getIsDisabled,
onChange: _onChange, onChange: _onChange,
}); });

View File

@ -1,13 +1,16 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store'; import type { PersistConfig } from 'app/store/store';
import type { ModelType } from 'services/api/types';
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
type ModelManagerState = { type ModelManagerState = {
_version: 1; _version: 1;
selectedModelKey: string | null; selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view'; selectedModelMode: 'edit' | 'view';
searchTerm: string; searchTerm: string;
filteredModelType: string | null; filteredModelType: FilterableModelType | null;
scanPath: string | undefined; scanPath: string | undefined;
}; };
@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({
state.searchTerm = action.payload; state.searchTerm = action.payload;
}, },
setFilteredModelType: (state, action: PayloadAction<string | null>) => { setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
state.filteredModelType = action.payload; state.filteredModelType = action.payload;
}, },
setScanPath: (state, action: PayloadAction<string | undefined>) => { setScanPath: (state, action: PayloadAction<string | undefined>) => {

View File

@ -1,122 +1,105 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library'; import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { forEach } from 'lodash-es'; import { memo, useMemo } from 'react';
import { memo } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { import {
useGetControlNetModelsQuery, useControlNetModels,
useGetIPAdapterModelsQuery, useEmbeddingModels,
useGetLoRAModelsQuery, useIPAdapterModels,
useGetMainModelsQuery, useLoRAModels,
useGetT2IAdapterModelsQuery, useMainModels,
useGetTextualInversionModelsQuery, useT2IAdapterModels,
useGetVaeModelsQuery, useVAEModels,
} from 'services/api/endpoints/models'; } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig, ModelType } from 'services/api/types';
import { ModelListWrapper } from './ModelListWrapper'; import { ModelListWrapper } from './ModelListWrapper';
const ModelList = () => { const ModelList = () => {
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2); const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
selectFromResult: ({ data, isLoading }) => ({ const filteredMainModels = useMemo(
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType), () => modelsFilter(mainModels, searchTerm, filteredModelType),
isLoadingMainModels: isLoading, [mainModels, searchTerm, filteredModelType]
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingLoraModels: isLoading,
}),
});
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingTextualInversionModels: isLoading,
}),
}
); );
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, { const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
selectFromResult: ({ data, isLoading }) => ({ const filteredLoRAModels = useMemo(
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType), () => modelsFilter(loraModels, searchTerm, filteredModelType),
isLoadingControlnetModels: isLoading, [loraModels, searchTerm, filteredModelType]
}), );
});
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, { const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels();
selectFromResult: ({ data, isLoading }) => ({ const filteredEmbeddingModels = useMemo(
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType), () => modelsFilter(embeddingModels, searchTerm, filteredModelType),
isLoadingT2IAdapterModels: isLoading, [embeddingModels, searchTerm, filteredModelType]
}), );
});
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, { const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels();
selectFromResult: ({ data, isLoading }) => ({ const filteredControlNetModels = useMemo(
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType), () => modelsFilter(controlNetModels, searchTerm, filteredModelType),
isLoadingIpAdapterModels: isLoading, [controlNetModels, searchTerm, filteredModelType]
}), );
});
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, { const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels();
selectFromResult: ({ data, isLoading }) => ({ const filteredT2IAdapterModels = useMemo(
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType), () => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType),
isLoadingVaeModels: isLoading, [t2iAdapterModels, searchTerm, filteredModelType]
}), );
});
const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels();
const filteredIPAdapterModels = useMemo(
() => modelsFilter(ipAdapterModels, searchTerm, filteredModelType),
[ipAdapterModels, searchTerm, filteredModelType]
);
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
const filteredVAEModels = useMemo(
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
[vaeModels, searchTerm, filteredModelType]
);
return ( return (
<ScrollableContent> <ScrollableContent>
<Flex flexDirection="column" w="full" h="full" gap={4}> <Flex flexDirection="column" w="full" h="full" gap={4}>
{/* Main Model List */} {/* Main Model List */}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />} {isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
{!isLoadingMainModels && filteredMainModels.length > 0 && ( {!isLoadingMainModels && filteredMainModels.length > 0 && (
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" /> <ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
)} )}
{/* LoRAs List */} {/* LoRAs List */}
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />} {isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoraModels && filteredLoraModels.length > 0 && ( {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" /> <ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
)} )}
{/* TI List */} {/* TI List */}
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />} {isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && ( {!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
<ModelListWrapper <ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
title="Textual Inversions"
modelList={filteredTextualInversionModels}
key="textual-inversions"
/>
)} )}
{/* VAE List */} {/* VAE List */}
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />} {isLoadingVAEModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
{!isLoadingVaeModels && filteredVaeModels.length > 0 && ( {!isLoadingVAEModels && filteredVAEModels.length > 0 && (
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" /> <ModelListWrapper title="VAE" modelList={filteredVAEModels} key="vae" />
)} )}
{/* Controlnet List */} {/* Controlnet List */}
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />} {isLoadingControlNetModels && <FetchingModelsLoader loadingMessage="Loading ControlNets..." />}
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && ( {!isLoadingControlNetModels && filteredControlNetModels.length > 0 && (
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" /> <ModelListWrapper title="ControlNet" modelList={filteredControlNetModels} key="controlnets" />
)} )}
{/* IP Adapter List */} {/* IP Adapter List */}
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />} {isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && ( {!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" /> <ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
)} )}
{/* T2I Adapters List */} {/* T2I Adapters List */}
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />} {isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && ( {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" /> <ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)} )}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
@ -126,25 +109,16 @@ const ModelList = () => {
export default memo(ModelList); export default memo(ModelList);
const modelsFilter = <T extends AnyModelConfig>( const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined, data: T[],
nameFilter: string, nameFilter: string,
filteredModelType: string | null filteredModelType: ModelType | null
): T[] => { ): T[] => {
const filteredModels: T[] = []; return data.filter((model) => {
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesType = filteredModelType ? model.type === filteredModelType : true; const matchesType = filteredModelType ? model.type === filteredModelType : true;
if (matchesFilter && matchesType) { return matchesFilter && matchesType;
filteredModels.push(model);
}
}); });
return filteredModels;
}; };
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => { const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {

View File

@ -1,11 +1,13 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5'; import { PiFunnelBold } from 'react-icons/pi';
import { objectKeys } from 'tsafe';
const MODEL_TYPE_LABELS: { [key: string]: string } = { const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
main: 'Main', main: 'Main',
lora: 'LoRA', lora: 'LoRA',
embedding: 'Textual Inversion', embedding: 'Textual Inversion',
@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = {
vae: 'VAE', vae: 'VAE',
t2i_adapter: 'T2I Adapter', t2i_adapter: 'T2I Adapter',
ip_adapter: 'IP Adapter', ip_adapter: 'IP Adapter',
clip_vision: 'Clip Vision',
}; };
export const ModelTypeFilter = () => { export const ModelTypeFilter = () => {
@ -22,7 +23,7 @@ export const ModelTypeFilter = () => {
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType); const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
const selectModelType = useCallback( const selectModelType = useCallback(
(option: string) => { (option: FilterableModelType) => {
dispatch(setFilteredModelType(option)); dispatch(setFilteredModelType(option));
}, },
[dispatch] [dispatch]
@ -34,12 +35,12 @@ export const ModelTypeFilter = () => {
return ( return (
<Menu> <Menu>
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}> <MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')} {filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
</MenuButton> </MenuButton>
<MenuList> <MenuList>
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem> <MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
{Object.keys(MODEL_TYPE_LABELS).map((option) => ( {objectKeys(MODEL_TYPE_LABELS).map((option) => (
<MenuItem <MenuItem
key={option} key={option}
bg={filteredModelType === option ? 'base.700' : 'transparent'} bg={filteredModelType === option ? 'base.700' : 'transparent'}

View File

@ -4,12 +4,12 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form'; import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form'; import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
@ -21,18 +21,16 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken); const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { compatibleOptions } = useGetVaeModelsQuery(undefined, { const [vaeModels] = useVAEModels();
selectFromResult: ({ data }) => { const compatibleOptions = useMemo(() => {
const modelArray = map(data?.entities); const compatibleOptions = vaeModels
const compatibleOptions = modelArray .filter((vae) => vae.base === modelData?.base)
.filter((vae) => vae.base === modelData?.base) .map((vae) => ({ label: vae.name, value: vae.key }));
.map((vae) => ({ label: vae.name, value: vae.key }));
const defaultOption = { label: 'Default VAE', value: 'default' }; const defaultOption = { label: 'Default VAE', value: 'default' };
return { compatibleOptions: [defaultOption, ...compatibleOptions] }; return [defaultOption, ...compatibleOptions];
}, }, [modelData?.base, vaeModels]);
});
const onChange = useCallback<ComboboxOnChange>( const onChange = useCallback<ComboboxOnChange>(
(v) => { (v) => {

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field'; import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types'; import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetMo
const ControlNetModelFieldInputComponent = (props: Props) => { const ControlNetModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetControlNetModelsQuery(); const [modelConfigs, { isLoading }] = useControlNetModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: ControlNetModelConfig | null) => { (value: ControlNetModelConfig | null) => {
@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: field.value, selectedModel: field.value,
isLoading, isLoading,

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { IPAdapterModelConfig } from 'services/api/types'; import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = (
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const [modelConfigs, { isLoading }] = useIPAdapterModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: IPAdapterModelConfig | null) => { (value: IPAdapterModelConfig | null) => {
@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = (
); );
const { options, value, onChange } = useGroupedModelCombobox({ const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: ipAdapterModels, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: field.value, selectedModel: field.value,
isLoading,
}); });
return ( return (

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field'; import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInpu
const LoRAModelFieldInputComponent = (props: Props) => { const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery(); const [modelConfigs, { isLoading }] = useLoRAModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: LoRAModelConfig | null) => { (value: LoRAModelConfig | null) => {
if (!value) { if (!value) {
@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: field.value, selectedModel: field.value,
isLoading, isLoading,

View File

@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field'; import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants'; import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types'; import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -16,7 +15,7 @@ type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInpu
const MainModelFieldInputComponent = (props: Props) => { const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS); const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: MainModelConfig | null) => { (value: MainModelConfig | null) => {
if (!value) { if (!value) {
@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
isLoading, isLoading,
selectedModel: field.value, selectedModel: field.value,

View File

@ -8,8 +8,7 @@ import type {
SDXLRefinerModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants'; import { useRefinerModels } from 'services/api/hooks/modelsByType';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types'; import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -19,7 +18,7 @@ type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefiner
const RefinerModelFieldInputComponent = (props: Props) => { const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: MainModelConfig | null) => { (value: MainModelConfig | null) => {
if (!value) { if (!value) {
@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
isLoading, isLoading,
selectedModel: field.value, selectedModel: field.value,

View File

@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field'; import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { SDXL_MAIN_MODELS } from 'services/api/constants'; import { useSDXLModels } from 'services/api/hooks/modelsByType';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types'; import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -16,7 +15,7 @@ type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelF
const SDXLMainModelFieldInputComponent = (props: Props) => { const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS); const [modelConfigs, { isLoading }] = useSDXLModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: MainModelConfig | null) => { (value: MainModelConfig | null) => {
if (!value) { if (!value) {
@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
isLoading, isLoading,
selectedModel: field.value, selectedModel: field.value,

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { T2IAdapterModelConfig } from 'services/api/types'; import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = (
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: T2IAdapterModelConfig | null) => { (value: T2IAdapterModelConfig | null) => {
@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = (
); );
const { options, value, onChange } = useGroupedModelCombobox({ const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: t2iAdapterModels, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: field.value, selectedModel: field.value,
isLoading,
}); });
return ( return (

View File

@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field'; import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types'; import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types'; import type { FieldComponentProps } from './types';
@ -15,7 +15,7 @@ type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputT
const VAEModelFieldInputComponent = (props: Props) => { const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery(); const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback( const _onChange = useCallback(
(value: VAEModelConfig | null) => { (value: VAEModelConfig | null) => {
if (!value) { if (!value) {
@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: field.value, selectedModel: field.value,
isLoading, isLoading,

View File

@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useMainModels } from 'services/api/hooks/modelsByType';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types'; import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@ -18,7 +17,7 @@ const ParamMainModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const selectedModel = useAppSelector(selectModel); const selectedModel = useAppSelector(selectModel);
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS); const [modelConfigs, { isLoading }] = useMainModels();
const _onChange = useCallback( const _onChange = useCallback(
(model: MainModelConfig | null) => { (model: MainModelConfig | null) => {
@ -35,7 +34,7 @@ const ParamMainModelSelect = () => {
); );
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data, modelConfigs,
isLoading, isLoading,
selectedModel, selectedModel,
onChange: _onChange, onChange: _onChange,

View File

@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types'; import type { VAEModelConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { model, vae } = useAppSelector(selector); const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery(); const [modelConfigs, { isLoading }] = useVAEModels();
const getIsDisabled = useCallback( const getIsDisabled = useCallback(
(vae: VAEModelConfig): boolean => { (vae: VAEModelConfig): boolean => {
const isCompatible = model?.base === vae.base; const isCompatible = model?.base === vae.base;
@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => {
[dispatch] [dispatch]
); );
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: vae, selectedModel: vae,
isLoading, isLoading,

View File

@ -11,13 +11,8 @@ import { t } from 'i18next';
import { flatten, map } from 'lodash-es'; import { flatten, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import { useGetModelConfigQuery } from 'services/api/endpoints/models';
loraModelsAdapterSelectors, import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
useGetModelConfigQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('prompt.noMatchingTriggers'); const noOptionsMessage = () => t('prompt.noMatchingTriggers');
@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery( const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
mainModel?.key ?? skipToken mainModel?.key ?? skipToken
); );
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery(); const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery(); const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
const _onChange = useCallback<ComboboxOnChange>( const _onChange = useCallback<ComboboxOnChange>(
(v) => { (v) => {
@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const _options: GroupBase<ComboboxOption>[] = []; const _options: GroupBase<ComboboxOption>[] = [];
if (tiModels) { if (tiModels) {
const embeddingOptions = textualInversionModelsAdapterSelectors const embeddingOptions = tiModels
.selectAll(tiModels)
.filter((ti) => ti.base === mainModelConfig?.base) .filter((ti) => ti.base === mainModelConfig?.base)
.map((model) => ({ label: model.name, value: `<${model.name}>` })); .map((model) => ({ label: model.name, value: `<${model.name}>` }));
@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
} }
if (loraModels) { if (loraModels) {
const triggerPhraseOptions = loraModelsAdapterSelectors const triggerPhraseOptions = loraModels
.selectAll(loraModels)
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key)) .filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
.map((lora) => { .map((lora) => {
if (lora.trigger_phrases) { if (lora.trigger_phrases) {

View File

@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice'; import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants'; import { useRefinerModels } from 'services/api/hooks/modelsByType';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types'; import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel); const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const model = useAppSelector(selectModel); const model = useAppSelector(selectModel);
const { t } = useTranslation(); const { t } = useTranslation();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback( const _onChange = useCallback(
(model: MainModelConfig | null) => { (model: MainModelConfig | null) => {
if (!model) { if (!model) {
@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => {
[dispatch] [dispatch]
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({ const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
modelEntities: data, modelConfigs,
onChange: _onChange, onChange: _onChange,
selectedModel: model, selectedModel: model,
isLoading, isLoading,

View File

@ -1,28 +1,11 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string'; import queryString from 'query-string';
import {
ALL_BASE_MODELS,
NON_REFINER_BASE_MODELS,
NON_SDXL_MAIN_MODELS,
REFINER_BASE_MODELS,
SDXL_MAIN_MODELS,
} from 'services/api/constants';
import type { operations, paths } from 'services/api/schema'; import type { operations, paths } from 'services/api/schema';
import type { import type { AnyModelConfig } from 'services/api/types';
AnyModelConfig,
BaseModelType,
ControlNetModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import type { ApiTagDescription, tagTypes } from '..'; import type { ApiTagDescription } from '..';
import { api, buildV2Url, LIST_TAG } from '..'; import { api, buildV2Url, LIST_TAG } from '..';
export type UpdateModelArg = { export type UpdateModelArg = {
@ -40,8 +23,9 @@ type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json']; paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelConfigsResponse = NonNullable<
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>; paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
>;
type DeleteModelArg = { type DeleteModelArg = {
key: string; key: string;
@ -76,72 +60,11 @@ type GetHuggingFaceModelsResponse =
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query']; type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({ const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
/** /**
* Builds an endpoint URL for the models router * Builds an endpoint URL for the models router
@ -162,9 +85,27 @@ export const modelsApi = api.injectEndpoints({
}; };
}, },
onQueryStarted: async (_, { dispatch, queryFulfilled }) => { onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => { try {
upsertSingleModelConfig(data, dispatch); const { data } = await queryFulfilled;
});
// Update the individual model query caches
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
const { base, name, type } = data;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
// Update the list query cache
dispatch(
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
modelConfigsAdapter.updateOne(draft, {
id: data.key,
changes: data,
});
})
);
} catch {
// no-op
}
}, },
}), }),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({ updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
@ -294,80 +235,27 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['ModelInstalls'], invalidatesTags: ['ModelInstalls'],
}), }),
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({ getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
query: (base_models) => { query: () => ({ url: buildModelsUrl() }),
const params: ListModelsArg = { providesTags: (result) => {
model_type: 'main', const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
base_models, if (result) {
}; const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const);
const query = queryString.stringify(params, { arrayFormat: 'none' }); tags.push(...modelTags);
return buildModelsUrl(`?${query}`); }
return tags;
},
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
transformResponse: (response: GetModelConfigsResponse) => {
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
}, },
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => { onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => { queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch); modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => {
}); dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
}, const { base, name, type } = modelConfig;
}), dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({ });
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
}); });
}, },
}), }),
@ -375,14 +263,8 @@ export const modelsApi = api.injectEndpoints({
}); });
export const { export const {
useGetModelConfigsQuery,
useGetModelConfigQuery, useGetModelConfigQuery,
useGetMainModelsQuery,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation, useDeleteModelsMutation,
useDeleteModelImageMutation, useDeleteModelImageMutation,
useUpdateModelMutation, useUpdateModelMutation,
@ -396,127 +278,3 @@ export const {
useCancelModelInstallMutation, useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation, usePruneCompletedModelInstallsMutation,
} = modelsApi; } = modelsApi;
const upsertModelConfigs = (
modelConfigs: EntityState<AnyModelConfig, string>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
* network request. This function takes the received list of models and upserts them into the individual query caches
* for each model type.
*/
// Iterate over all the models and upsert them into the individual query caches for each model type.
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
};
const upsertSingleModelConfig = (
modelConfig: AnyModelConfig,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
* query caches of models of that type.
*/
// Update the individual model query caches.
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
// Update the list query caches for each model type.
if (modelConfig.type === 'main') {
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
(queryArg) => {
dispatch(
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
mainModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
}
);
return;
}
if (modelConfig.type === 'controlnet') {
dispatch(
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
controlNetModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'embedding') {
dispatch(
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
textualInversionModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'ip_adapter') {
dispatch(
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
ipAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'lora') {
dispatch(
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
loraModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 't2i_adapter') {
dispatch(
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
t2iAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'vae') {
dispatch(
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
vaeModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
};

View File

@ -0,0 +1,42 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSDXLMainModelModelConfig,
isT2IAdapterModelConfig,
isTIModelConfig,
isVAEModelConfig,
} from 'services/api/types';
const buildModelsHook =
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
}, [result]);
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);

View File

@ -1,12 +1,7 @@
import { REFINER_BASE_MODELS } from 'services/api/constants'; import { useRefinerModels } from 'services/api/hooks/modelsByType';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export const useIsRefinerAvailable = () => { export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, { const [refinerModels] = useRefinerModels();
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
return isRefinerAvailable; return Boolean(refinerModels.length);
}; };

View File

@ -48,7 +48,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig']; export type IPAdapterModelConfig = S['IPAdapterConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig']; export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig']; type DiffusersModelConfig = S['MainDiffusersConfig'];
type CheckpointModelConfig = S['MainCheckpointConfig']; type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig']; type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
@ -103,6 +103,18 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
return config.type === 'main' && config.base === 'sdxl-refiner'; return config.type === 'main' && config.base === 'sdxl-refiner';
}; };
export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sdxl';
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
};
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'embedding';
};
export type ModelInstallJob = S['ModelInstallJob']; export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S['InstallStatus']; export type ModelInstallStatus = S['InstallStatus'];
@ -200,10 +212,3 @@ export type PostUploadAction =
| CanvasInitialImageAction | CanvasInitialImageAction
| ToastAction | ToastAction
| AddToBatchAction; | AddToBatchAction;
type TypeGuard<T> = {
(input: unknown): input is T;
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type TypeGuardFor<T extends TypeGuard<any>> = T extends TypeGuard<infer U> ? U : never;