mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): single getModelConfigs query
Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend.
This commit is contained in:
parent
ed20255abf
commit
19d66d5ec7
@ -1,10 +1,10 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import type { AppDispatch, RootState } from 'app/store/store';
|
||||||
|
import type { JSONObject } from 'common/types';
|
||||||
import {
|
import {
|
||||||
controlAdapterModelCleared,
|
controlAdapterModelCleared,
|
||||||
selectAllControlNets,
|
selectControlAdapterAll,
|
||||||
selectAllIPAdapters,
|
|
||||||
selectAllT2IAdapters,
|
|
||||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||||
@ -12,34 +12,52 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
|
|||||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
|
import type { Logger } from 'roarr';
|
||||||
import type { TypeGuardFor } from 'services/api/types';
|
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
|
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
|
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
|
||||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
const log = logger('models');
|
const log = logger('models');
|
||||||
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
|
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
|
||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const currentModel = state.generation.model;
|
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
|
||||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
|
||||||
|
|
||||||
if (models.length === 0) {
|
handleMainModels(models, state, dispatch, log);
|
||||||
|
handleRefinerModels(models, state, dispatch, log);
|
||||||
|
handleVAEModels(models, state, dispatch, log);
|
||||||
|
handleLoRAModels(models, state, dispatch, log);
|
||||||
|
handleControlAdapterModels(models, state, dispatch, log);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
type ModelHandler = (
|
||||||
|
models: AnyModelConfig[],
|
||||||
|
state: RootState,
|
||||||
|
dispatch: AppDispatch,
|
||||||
|
log: Logger<JSONObject>
|
||||||
|
) => undefined;
|
||||||
|
|
||||||
|
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
||||||
|
const currentModel = state.generation.model;
|
||||||
|
const mainModels = models.filter(isNonRefinerMainModelConfig);
|
||||||
|
if (mainModels.length === 0) {
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(modelChanged(null));
|
dispatch(modelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
if (isCurrentMainModelAvailable) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,54 +92,43 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
}
|
}
|
||||||
|
|
||||||
dispatch(modelChanged(result.data, currentModel));
|
dispatch(modelChanged(result.data, currentModel));
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
|
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
const currentModel = getState().sdxl.refinerModel;
|
|
||||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
|
||||||
|
|
||||||
|
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||||
|
const currentRefinerModel = state.sdxl.refinerModel;
|
||||||
|
const refinerModels = models.filter(isRefinerMainModelModelConfig);
|
||||||
if (models.length === 0) {
|
if (models.length === 0) {
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
const isCurrentRefinerModelAvailable = currentRefinerModel
|
||||||
|
? refinerModels.some((m) => m.key === currentRefinerModel.key)
|
||||||
|
: false;
|
||||||
|
|
||||||
if (!isCurrentModelAvailable) {
|
if (!isCurrentRefinerModelAvailable) {
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// VAEs loaded, need to reset the VAE is it's no longer available
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
const currentVae = getState().generation.vae;
|
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||||
|
const currentVae = state.generation.vae;
|
||||||
|
|
||||||
if (currentVae === null) {
|
if (currentVae === null) {
|
||||||
// null is a valid VAE! it means "use the default with the main model"
|
// null is a valid VAE! it means "use the default with the main model"
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
const vaeModels = models.filter(isVAEModelConfig);
|
||||||
|
|
||||||
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
|
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
|
||||||
|
|
||||||
if (isCurrentVAEAvailable) {
|
if (isCurrentVAEAvailable) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
|
const firstModel = vaeModels[0];
|
||||||
|
|
||||||
if (!firstModel) {
|
if (!firstModel) {
|
||||||
// No custom VAEs loaded at all; use the default
|
// No custom VAEs loaded at all; use the default
|
||||||
@ -137,19 +144,13 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
}
|
}
|
||||||
|
|
||||||
dispatch(vaeSelected(result.data));
|
dispatch(vaeSelected(result.data));
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// LoRA models loaded - need to remove missing LoRAs from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
const loras = getState().lora.loras;
|
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||||
|
const loras = state.lora.loras;
|
||||||
|
|
||||||
forEach(loras, (lora, id) => {
|
forEach(loras, (lora, id) => {
|
||||||
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
|
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
|
||||||
|
|
||||||
if (isLoRAAvailable) {
|
if (isLoRAAvailable) {
|
||||||
return;
|
return;
|
||||||
@ -157,17 +158,11 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
|
|
||||||
dispatch(loraRemoved(id));
|
dispatch(loraRemoved(id));
|
||||||
});
|
});
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
|
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
|
||||||
|
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
|
||||||
|
|
||||||
if (isModelAvailable) {
|
if (isModelAvailable) {
|
||||||
return;
|
return;
|
||||||
@ -175,49 +170,4 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
|
|
||||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||||
});
|
});
|
||||||
},
|
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
|
|
||||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
|
||||||
|
|
||||||
if (isModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
|
|
||||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
|
||||||
|
|
||||||
if (isModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
|
|
||||||
effect: async (action) => {
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
@ -23,8 +23,7 @@ import {
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { map } from 'lodash-es';
|
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||||
@ -39,7 +38,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
|
||||||
|
const data = await request.unwrap();
|
||||||
|
request.unsubscribe();
|
||||||
|
const models = modelConfigsAdapterSelectors.selectAll(data);
|
||||||
|
|
||||||
|
const modelConfig = models.find((model) => model.key === currentModel.key);
|
||||||
|
|
||||||
if (!modelConfig) {
|
if (!modelConfig) {
|
||||||
return;
|
return;
|
||||||
@ -55,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
if (vae === 'default') {
|
if (vae === 'default') {
|
||||||
dispatch(vaeSelected(null));
|
dispatch(vaeSelected(null));
|
||||||
} else {
|
} else {
|
||||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
const vaeModel = models.find((model) => model.key === vae);
|
||||||
const vaeArray = map(data?.entities);
|
const result = zParameterVAEModel.safeParse(vaeModel);
|
||||||
const validVae = vaeArray.find((model) => model.key === vae);
|
|
||||||
|
|
||||||
const result = zParameterVAEModel.safeParse(validVae);
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -30,11 +30,10 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
|
|||||||
|
|
||||||
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
|
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
|
||||||
if ($isFirstConnection.get()) {
|
if ($isFirstConnection.get()) {
|
||||||
// The TI models are used in a component that is not always rendered, so when users open the prompt triggers
|
// Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
|
||||||
// box has a delay while it does the initial fetch. We need to both pre-fetch the data and maintain an RTK
|
// unsubscribe.
|
||||||
// Query subscription to it, so the cache doesn't clear itself when the user closes the prompt triggers box.
|
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
|
||||||
// So, we explicitly do not unsubscribe from this query!
|
request.unsubscribe();
|
||||||
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
|
|
||||||
|
|
||||||
$isFirstConnection.set(false);
|
$isFirstConnection.set(false);
|
||||||
return;
|
return;
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { GroupBase } from 'chakra-react-select';
|
import type { GroupBase } from 'chakra-react-select';
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { groupBy, map, reduce } from 'lodash-es';
|
import { groupBy, reduce } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelConfigs: T[];
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
): UseGroupedModelComboboxReturn => {
|
): UseGroupedModelComboboxReturn => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||||
if (!modelEntities) {
|
if (!modelConfigs) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
const modelEntitiesArray = map(modelEntities.entities);
|
const groupedModels = groupBy(modelConfigs, 'base');
|
||||||
const groupedModels = groupBy(modelEntitiesArray, 'base');
|
|
||||||
const _options = reduce(
|
const _options = reduce(
|
||||||
groupedModels,
|
groupedModels,
|
||||||
(acc, val, label) => {
|
(acc, val, label) => {
|
||||||
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
);
|
);
|
||||||
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
||||||
return _options;
|
return _options;
|
||||||
}, [getIsDisabled, modelEntities, base_model]);
|
}, [getIsDisabled, modelConfigs, base_model]);
|
||||||
|
|
||||||
const value = useMemo(
|
const value = useMemo(
|
||||||
() =>
|
() =>
|
||||||
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const model = modelEntities?.entities[v.value];
|
const model = modelConfigs.find((m) => m.key === v.value);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
onChange(model);
|
onChange(model);
|
||||||
},
|
},
|
||||||
[modelEntities?.entities, onChange]
|
[modelConfigs, onChange]
|
||||||
);
|
);
|
||||||
|
|
||||||
const placeholder = useMemo(() => {
|
const placeholder = useMemo(() => {
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelConfigs: T[];
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
|
|||||||
|
|
||||||
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||||
const options = useMemo<ComboboxOption[]>(() => {
|
const options = useMemo<ComboboxOption[]>(() => {
|
||||||
if (!modelEntities) {
|
return modelConfigs.filter(optionsFilter).map((model) => ({
|
||||||
return [];
|
|
||||||
}
|
|
||||||
return map(modelEntities.entities)
|
|
||||||
.filter(optionsFilter)
|
|
||||||
.map((model) => ({
|
|
||||||
label: model.name,
|
label: model.name,
|
||||||
value: model.key,
|
value: model.key,
|
||||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||||
}));
|
}));
|
||||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
}, [optionsFilter, getIsDisabled, modelConfigs]);
|
||||||
|
|
||||||
const value = useMemo(
|
const value = useMemo(
|
||||||
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
||||||
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
|
|||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const model = modelEntities?.entities[v.value];
|
const model = modelConfigs.find((m) => m.key === v.value);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
onChange(model);
|
onChange(model);
|
||||||
},
|
},
|
||||||
[modelEntities?.entities, onChange]
|
[modelConfigs, onChange]
|
||||||
);
|
);
|
||||||
|
|
||||||
const placeholder = useMemo(() => {
|
const placeholder = useMemo(() => {
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
import type { Item } from '@invoke-ai/ui-library';
|
import type { Item } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||||
import { filter } from 'lodash-es';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||||
data: EntityState<T, string> | undefined;
|
modelConfigs: T[];
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
|
|||||||
const isModelDisabledDefault = () => false;
|
const isModelDisabledDefault = () => false;
|
||||||
|
|
||||||
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||||
data,
|
modelConfigs,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onChange,
|
onChange,
|
||||||
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
|
|||||||
|
|
||||||
const items: Item[] = useMemo(
|
const items: Item[] = useMemo(
|
||||||
() =>
|
() =>
|
||||||
data
|
modelConfigs.filter(modelFilter).map<Item>((m) => ({
|
||||||
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
|
||||||
label: m.name,
|
label: m.name,
|
||||||
value: m.key,
|
value: m.key,
|
||||||
description: m.description,
|
description: m.description,
|
||||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||||
isDisabled: isModelDisabled(m),
|
isDisabled: isModelDisabled(m),
|
||||||
}))
|
})),
|
||||||
: EMPTY_ARRAY,
|
[modelConfigs, isModelDisabled, modelFilter]
|
||||||
[data, isModelDisabled, modelFilter]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(item: Item | null) => {
|
(item: Item | null) => {
|
||||||
if (!item || !data) {
|
if (!item || !modelConfigs) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const model = data.entities[item.value];
|
const model = modelConfigs.find((m) => m.key === item.value);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
onChange(model);
|
onChange(model);
|
||||||
},
|
},
|
||||||
[data, onChange]
|
[modelConfigs, onChange]
|
||||||
);
|
);
|
||||||
|
|
||||||
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||||
|
@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
|
||||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||||
@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
data,
|
modelConfigs,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
|
||||||
|
|
||||||
export const useAddControlAdapter = (type: ControlAdapterType) => {
|
export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||||
const baseModel = useAppSelector((s) => s.generation.model?.base);
|
const baseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const models = useControlAdapterModels(type);
|
const [models] = useControlAdapterModels(type);
|
||||||
|
|
||||||
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
||||||
// prefer to use a model that matches the base model
|
// prefer to use a model that matches the base model
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
|
||||||
import {
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
|
|
||||||
const controlNetModelsQuery = useGetControlNetModelsQuery();
|
|
||||||
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
|
|
||||||
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
|
|
||||||
|
|
||||||
if (type === 'controlnet') {
|
|
||||||
return controlNetModelsQuery;
|
|
||||||
}
|
|
||||||
if (type === 't2i_adapter') {
|
|
||||||
return t2iAdapterModelsQuery;
|
|
||||||
}
|
|
||||||
if (type === 'ip_adapter') {
|
|
||||||
return ipAdapterModelsQuery;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that the end of the function is not reachable.
|
|
||||||
const exhaustiveCheck: never = type;
|
|
||||||
return exhaustiveCheck;
|
|
||||||
};
|
|
@ -1,31 +1,10 @@
|
|||||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||||
import { useMemo } from 'react';
|
import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||||
import {
|
|
||||||
controlNetModelsAdapterSelectors,
|
|
||||||
ipAdapterModelsAdapterSelectors,
|
|
||||||
t2iAdapterModelsAdapterSelectors,
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
export const useControlAdapterModels = (type: ControlAdapterType) => {
|
||||||
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
|
const controlNetModels = useControlNetModels();
|
||||||
const controlNetModels = useMemo(
|
const t2iAdapterModels = useT2IAdapterModels();
|
||||||
() => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []),
|
const ipAdapterModels = useIPAdapterModels();
|
||||||
[controlNetModelsData]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
|
|
||||||
const t2iAdapterModels = useMemo(
|
|
||||||
() => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
|
|
||||||
[t2iAdapterModelsData]
|
|
||||||
);
|
|
||||||
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
|
|
||||||
const ipAdapterModels = useMemo(
|
|
||||||
() => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
|
|
||||||
[ipAdapterModelsData]
|
|
||||||
);
|
|
||||||
|
|
||||||
if (type === 'controlnet') {
|
if (type === 'controlnet') {
|
||||||
return controlNetModels;
|
return controlNetModels;
|
||||||
@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
|||||||
if (type === 'ip_adapter') {
|
if (type === 'ip_adapter') {
|
||||||
return ipAdapterModels;
|
return ipAdapterModels;
|
||||||
}
|
}
|
||||||
return [];
|
|
||||||
|
// Assert that the end of the function is not reachable.
|
||||||
|
const exhaustiveCheck: never = type;
|
||||||
|
return exhaustiveCheck;
|
||||||
};
|
};
|
||||||
|
@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||||
|
|
||||||
const LoRASelect = () => {
|
const LoRASelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
@ -37,7 +37,7 @@ const LoRASelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, onChange } = useGroupedModelCombobox({
|
const { options, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
getIsDisabled,
|
getIsDisabled,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
});
|
});
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig } from 'app/store/store';
|
import type { PersistConfig } from 'app/store/store';
|
||||||
|
import type { ModelType } from 'services/api/types';
|
||||||
|
|
||||||
|
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
||||||
|
|
||||||
type ModelManagerState = {
|
type ModelManagerState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
selectedModelKey: string | null;
|
selectedModelKey: string | null;
|
||||||
selectedModelMode: 'edit' | 'view';
|
selectedModelMode: 'edit' | 'view';
|
||||||
searchTerm: string;
|
searchTerm: string;
|
||||||
filteredModelType: string | null;
|
filteredModelType: FilterableModelType | null;
|
||||||
scanPath: string | undefined;
|
scanPath: string | undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({
|
|||||||
state.searchTerm = action.payload;
|
state.searchTerm = action.payload;
|
||||||
},
|
},
|
||||||
|
|
||||||
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
|
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
|
||||||
state.filteredModelType = action.payload;
|
state.filteredModelType = action.payload;
|
||||||
},
|
},
|
||||||
setScanPath: (state, action: PayloadAction<string | undefined>) => {
|
setScanPath: (state, action: PayloadAction<string | undefined>) => {
|
||||||
|
@ -1,122 +1,105 @@
|
|||||||
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { forEach } from 'lodash-es';
|
import { memo, useMemo } from 'react';
|
||||||
import { memo } from 'react';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import {
|
import {
|
||||||
useGetControlNetModelsQuery,
|
useControlNetModels,
|
||||||
useGetIPAdapterModelsQuery,
|
useEmbeddingModels,
|
||||||
useGetLoRAModelsQuery,
|
useIPAdapterModels,
|
||||||
useGetMainModelsQuery,
|
useLoRAModels,
|
||||||
useGetT2IAdapterModelsQuery,
|
useMainModels,
|
||||||
useGetTextualInversionModelsQuery,
|
useT2IAdapterModels,
|
||||||
useGetVaeModelsQuery,
|
useVAEModels,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/hooks/modelsByType';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||||
|
|
||||||
import { ModelListWrapper } from './ModelListWrapper';
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||||
|
|
||||||
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredMainModels = useMemo(
|
||||||
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(mainModels, searchTerm, filteredModelType),
|
||||||
isLoadingMainModels: isLoading,
|
[mainModels, searchTerm, filteredModelType]
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
|
|
||||||
isLoadingLoraModels: isLoading,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
|
|
||||||
undefined,
|
|
||||||
{
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
|
|
||||||
isLoadingTextualInversionModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
|
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredLoRAModels = useMemo(
|
||||||
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||||
isLoadingControlnetModels: isLoading,
|
[loraModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
|
const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredEmbeddingModels = useMemo(
|
||||||
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(embeddingModels, searchTerm, filteredModelType),
|
||||||
isLoadingT2IAdapterModels: isLoading,
|
[embeddingModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
|
const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredControlNetModels = useMemo(
|
||||||
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(controlNetModels, searchTerm, filteredModelType),
|
||||||
isLoadingIpAdapterModels: isLoading,
|
[controlNetModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
|
const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredT2IAdapterModels = useMemo(
|
||||||
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType),
|
||||||
isLoadingVaeModels: isLoading,
|
[t2iAdapterModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels();
|
||||||
|
const filteredIPAdapterModels = useMemo(
|
||||||
|
() => modelsFilter(ipAdapterModels, searchTerm, filteredModelType),
|
||||||
|
[ipAdapterModels, searchTerm, filteredModelType]
|
||||||
|
);
|
||||||
|
|
||||||
|
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
|
||||||
|
const filteredVAEModels = useMemo(
|
||||||
|
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
|
||||||
|
[vaeModels, searchTerm, filteredModelType]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||||
{/* Main Model List */}
|
{/* Main Model List */}
|
||||||
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
|
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
|
||||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||||
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
||||||
)}
|
)}
|
||||||
{/* LoRAs List */}
|
{/* LoRAs List */}
|
||||||
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||||
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||||
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
|
<ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* TI List */}
|
{/* TI List */}
|
||||||
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
|
{isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
|
||||||
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
|
{!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
|
||||||
<ModelListWrapper
|
<ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
|
||||||
title="Textual Inversions"
|
|
||||||
modelList={filteredTextualInversionModels}
|
|
||||||
key="textual-inversions"
|
|
||||||
/>
|
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* VAE List */}
|
{/* VAE List */}
|
||||||
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
{isLoadingVAEModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
||||||
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
|
{!isLoadingVAEModels && filteredVAEModels.length > 0 && (
|
||||||
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
|
<ModelListWrapper title="VAE" modelList={filteredVAEModels} key="vae" />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Controlnet List */}
|
{/* Controlnet List */}
|
||||||
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
|
{isLoadingControlNetModels && <FetchingModelsLoader loadingMessage="Loading ControlNets..." />}
|
||||||
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
|
{!isLoadingControlNetModels && filteredControlNetModels.length > 0 && (
|
||||||
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
|
<ModelListWrapper title="ControlNet" modelList={filteredControlNetModels} key="controlnets" />
|
||||||
)}
|
)}
|
||||||
{/* IP Adapter List */}
|
{/* IP Adapter List */}
|
||||||
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
{isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
||||||
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
|
{!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
|
||||||
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
|
<ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
|
||||||
)}
|
)}
|
||||||
{/* T2I Adapters List */}
|
{/* T2I Adapters List */}
|
||||||
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
||||||
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
|
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||||
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
|
<ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
@ -126,25 +109,16 @@ const ModelList = () => {
|
|||||||
export default memo(ModelList);
|
export default memo(ModelList);
|
||||||
|
|
||||||
const modelsFilter = <T extends AnyModelConfig>(
|
const modelsFilter = <T extends AnyModelConfig>(
|
||||||
data: EntityState<T, string> | undefined,
|
data: T[],
|
||||||
nameFilter: string,
|
nameFilter: string,
|
||||||
filteredModelType: string | null
|
filteredModelType: ModelType | null
|
||||||
): T[] => {
|
): T[] => {
|
||||||
const filteredModels: T[] = [];
|
return data.filter((model) => {
|
||||||
|
|
||||||
forEach(data?.entities, (model) => {
|
|
||||||
if (!model) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||||
|
|
||||||
if (matchesFilter && matchesType) {
|
return matchesFilter && matchesType;
|
||||||
filteredModels.push(model);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
return filteredModels;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoFilter } from 'react-icons/io5';
|
import { PiFunnelBold } from 'react-icons/pi';
|
||||||
|
import { objectKeys } from 'tsafe';
|
||||||
|
|
||||||
const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
|
||||||
main: 'Main',
|
main: 'Main',
|
||||||
lora: 'LoRA',
|
lora: 'LoRA',
|
||||||
embedding: 'Textual Inversion',
|
embedding: 'Textual Inversion',
|
||||||
@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
|||||||
vae: 'VAE',
|
vae: 'VAE',
|
||||||
t2i_adapter: 'T2I Adapter',
|
t2i_adapter: 'T2I Adapter',
|
||||||
ip_adapter: 'IP Adapter',
|
ip_adapter: 'IP Adapter',
|
||||||
clip_vision: 'Clip Vision',
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ModelTypeFilter = () => {
|
export const ModelTypeFilter = () => {
|
||||||
@ -22,7 +23,7 @@ export const ModelTypeFilter = () => {
|
|||||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||||
|
|
||||||
const selectModelType = useCallback(
|
const selectModelType = useCallback(
|
||||||
(option: string) => {
|
(option: FilterableModelType) => {
|
||||||
dispatch(setFilteredModelType(option));
|
dispatch(setFilteredModelType(option));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
@ -34,12 +35,12 @@ export const ModelTypeFilter = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Menu>
|
<Menu>
|
||||||
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
|
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
|
||||||
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
||||||
</MenuButton>
|
</MenuButton>
|
||||||
<MenuList>
|
<MenuList>
|
||||||
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
||||||
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
|
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
|
||||||
<MenuItem
|
<MenuItem
|
||||||
key={option}
|
key={option}
|
||||||
bg={filteredModelType === option ? 'base.700' : 'transparent'}
|
bg={filteredModelType === option ? 'base.700' : 'transparent'}
|
||||||
|
@ -4,12 +4,12 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||||
|
|
||||||
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
||||||
|
|
||||||
@ -21,18 +21,16 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
|
|||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
const [vaeModels] = useVAEModels();
|
||||||
selectFromResult: ({ data }) => {
|
const compatibleOptions = useMemo(() => {
|
||||||
const modelArray = map(data?.entities);
|
const compatibleOptions = vaeModels
|
||||||
const compatibleOptions = modelArray
|
|
||||||
.filter((vae) => vae.base === modelData?.base)
|
.filter((vae) => vae.base === modelData?.base)
|
||||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||||
|
|
||||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||||
|
|
||||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
return [defaultOption, ...compatibleOptions];
|
||||||
},
|
}, [modelData?.base, vaeModels]);
|
||||||
});
|
|
||||||
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { ControlNetModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -14,7 +14,7 @@ type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetMo
|
|||||||
const ControlNetModelFieldInputComponent = (props: Props) => {
|
const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetControlNetModelsQuery();
|
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: ControlNetModelConfig | null) => {
|
(value: ControlNetModelConfig | null) => {
|
||||||
@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = (
|
|||||||
) => {
|
) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: IPAdapterModelConfig | null) => {
|
(value: IPAdapterModelConfig | null) => {
|
||||||
@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange } = useGroupedModelCombobox({
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: ipAdapterModels,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -14,7 +14,7 @@ type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInpu
|
|||||||
const LoRAModelFieldInputComponent = (props: Props) => {
|
const LoRAModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: LoRAModelConfig | null) => {
|
(value: LoRAModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
|||||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
|
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -16,7 +15,7 @@ type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInpu
|
|||||||
const MainModelFieldInputComponent = (props: Props) => {
|
const MainModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
|
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: MainModelConfig | null) => {
|
(value: MainModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
@ -8,8 +8,7 @@ import type {
|
|||||||
SDXLRefinerModelFieldInputTemplate,
|
SDXLRefinerModelFieldInputTemplate,
|
||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -19,7 +18,7 @@ type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefiner
|
|||||||
const RefinerModelFieldInputComponent = (props: Props) => {
|
const RefinerModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: MainModelConfig | null) => {
|
(value: MainModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
|||||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { SDXL_MAIN_MODELS } from 'services/api/constants';
|
import { useSDXLModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -16,7 +15,7 @@ type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelF
|
|||||||
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
|
const [modelConfigs, { isLoading }] = useSDXLModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: MainModelConfig | null) => {
|
(value: MainModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = (
|
|||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: T2IAdapterModelConfig | null) => {
|
(value: T2IAdapterModelConfig | null) => {
|
||||||
@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange } = useGroupedModelCombobox({
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: t2iAdapterModels,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
|||||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { VAEModelConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -15,7 +15,7 @@ type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputT
|
|||||||
const VAEModelFieldInputComponent = (props: Props) => {
|
const VAEModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: VAEModelConfig | null) => {
|
(value: VAEModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
|
|||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||||
@ -18,7 +17,7 @@ const ParamMainModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModel = useAppSelector(selectModel);
|
const selectedModel = useAppSelector(selectModel);
|
||||||
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
const [modelConfigs, { isLoading }] = useMainModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: MainModelConfig | null) => {
|
(model: MainModelConfig | null) => {
|
||||||
@ -35,7 +34,7 @@ const ParamMainModelSelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
data,
|
modelConfigs,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
|
@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
|||||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { VAEModelConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||||
@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { model, vae } = useAppSelector(selector);
|
const { model, vae } = useAppSelector(selector);
|
||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||||
const getIsDisabled = useCallback(
|
const getIsDisabled = useCallback(
|
||||||
(vae: VAEModelConfig): boolean => {
|
(vae: VAEModelConfig): boolean => {
|
||||||
const isCompatible = model?.base === vae.base;
|
const isCompatible = model?.base === vae.base;
|
||||||
@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: vae,
|
selectedModel: vae,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -11,13 +11,8 @@ import { t } from 'i18next';
|
|||||||
import { flatten, map } from 'lodash-es';
|
import { flatten, map } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
loraModelsAdapterSelectors,
|
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||||
textualInversionModelsAdapterSelectors,
|
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetModelConfigQuery,
|
|
||||||
useGetTextualInversionModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||||
@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
|
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
|
||||||
mainModel?.key ?? skipToken
|
mainModel?.key ?? skipToken
|
||||||
);
|
);
|
||||||
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
|
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
|
||||||
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
|
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
|
||||||
|
|
||||||
const _onChange = useCallback<ComboboxOnChange>(
|
const _onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
const _options: GroupBase<ComboboxOption>[] = [];
|
const _options: GroupBase<ComboboxOption>[] = [];
|
||||||
|
|
||||||
if (tiModels) {
|
if (tiModels) {
|
||||||
const embeddingOptions = textualInversionModelsAdapterSelectors
|
const embeddingOptions = tiModels
|
||||||
.selectAll(tiModels)
|
|
||||||
.filter((ti) => ti.base === mainModelConfig?.base)
|
.filter((ti) => ti.base === mainModelConfig?.base)
|
||||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||||
|
|
||||||
@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (loraModels) {
|
if (loraModels) {
|
||||||
const triggerPhraseOptions = loraModelsAdapterSelectors
|
const triggerPhraseOptions = loraModels
|
||||||
.selectAll(loraModels)
|
|
||||||
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
||||||
.map((lora) => {
|
.map((lora) => {
|
||||||
if (lora.trigger_phrases) {
|
if (lora.trigger_phrases) {
|
||||||
|
@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
|||||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
|
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
|
||||||
@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const model = useAppSelector(selectModel);
|
const model = useAppSelector(selectModel);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: MainModelConfig | null) => {
|
(model: MainModelConfig | null) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: model,
|
selectedModel: model,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -1,28 +1,11 @@
|
|||||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||||
import queryString from 'query-string';
|
import queryString from 'query-string';
|
||||||
import {
|
|
||||||
ALL_BASE_MODELS,
|
|
||||||
NON_REFINER_BASE_MODELS,
|
|
||||||
NON_SDXL_MAIN_MODELS,
|
|
||||||
REFINER_BASE_MODELS,
|
|
||||||
SDXL_MAIN_MODELS,
|
|
||||||
} from 'services/api/constants';
|
|
||||||
import type { operations, paths } from 'services/api/schema';
|
import type { operations, paths } from 'services/api/schema';
|
||||||
import type {
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
AnyModelConfig,
|
|
||||||
BaseModelType,
|
|
||||||
ControlNetModelConfig,
|
|
||||||
IPAdapterModelConfig,
|
|
||||||
LoRAModelConfig,
|
|
||||||
MainModelConfig,
|
|
||||||
T2IAdapterModelConfig,
|
|
||||||
TextualInversionModelConfig,
|
|
||||||
VAEModelConfig,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
import type { ApiTagDescription, tagTypes } from '..';
|
import type { ApiTagDescription } from '..';
|
||||||
import { api, buildV2Url, LIST_TAG } from '..';
|
import { api, buildV2Url, LIST_TAG } from '..';
|
||||||
|
|
||||||
export type UpdateModelArg = {
|
export type UpdateModelArg = {
|
||||||
@ -40,8 +23,9 @@ type UpdateModelImageResponse =
|
|||||||
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||||
|
type GetModelConfigsResponse = NonNullable<
|
||||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
|
||||||
|
>;
|
||||||
|
|
||||||
type DeleteModelArg = {
|
type DeleteModelArg = {
|
||||||
key: string;
|
key: string;
|
||||||
@ -76,72 +60,11 @@ type GetHuggingFaceModelsResponse =
|
|||||||
|
|
||||||
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
||||||
|
|
||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
|
||||||
undefined,
|
|
||||||
getSelectorsOptions
|
|
||||||
);
|
|
||||||
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
|
|
||||||
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
|
|
||||||
const buildProvidesTags =
|
|
||||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
|
||||||
(result: EntityState<TEntity, string> | undefined) => {
|
|
||||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
|
||||||
if (result) {
|
|
||||||
tags.push(
|
|
||||||
...result.ids.map((id) => ({
|
|
||||||
type: tagType,
|
|
||||||
id,
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags;
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildTransformResponse =
|
|
||||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
|
||||||
(response: { models: T[] }) => {
|
|
||||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an endpoint URL for the models router
|
* Builds an endpoint URL for the models router
|
||||||
@ -162,9 +85,27 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
};
|
};
|
||||||
},
|
},
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
queryFulfilled.then(({ data }) => {
|
try {
|
||||||
upsertSingleModelConfig(data, dispatch);
|
const { data } = await queryFulfilled;
|
||||||
|
|
||||||
|
// Update the individual model query caches
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
|
||||||
|
|
||||||
|
const { base, name, type } = data;
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
|
||||||
|
|
||||||
|
// Update the list query cache
|
||||||
|
dispatch(
|
||||||
|
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
|
||||||
|
modelConfigsAdapter.updateOne(draft, {
|
||||||
|
id: data.key,
|
||||||
|
changes: data,
|
||||||
});
|
});
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
||||||
@ -294,80 +235,27 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['ModelInstalls'],
|
invalidatesTags: ['ModelInstalls'],
|
||||||
}),
|
}),
|
||||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
|
||||||
query: (base_models) => {
|
query: () => ({ url: buildModelsUrl() }),
|
||||||
const params: ListModelsArg = {
|
providesTags: (result) => {
|
||||||
model_type: 'main',
|
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
|
||||||
base_models,
|
if (result) {
|
||||||
};
|
const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const);
|
||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
tags.push(...modelTags);
|
||||||
return buildModelsUrl(`?${query}`);
|
}
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
|
||||||
|
transformResponse: (response: GetModelConfigsResponse) => {
|
||||||
|
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
||||||
},
|
},
|
||||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
|
||||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
queryFulfilled.then(({ data }) => {
|
queryFulfilled.then(({ data }) => {
|
||||||
upsertModelConfigs(data, dispatch);
|
modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => {
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
||||||
|
const { base, name, type } = modelConfig;
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
||||||
});
|
});
|
||||||
},
|
|
||||||
}),
|
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
|
||||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
|
||||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
|
||||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
|
||||||
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
|
||||||
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
|
||||||
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
|
||||||
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
|
||||||
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
|
||||||
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
|
||||||
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
|
||||||
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
|
||||||
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
@ -375,14 +263,8 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
|
useGetModelConfigsQuery,
|
||||||
useGetModelConfigQuery,
|
useGetModelConfigQuery,
|
||||||
useGetMainModelsQuery,
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetTextualInversionModelsQuery,
|
|
||||||
useGetVaeModelsQuery,
|
|
||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useDeleteModelImageMutation,
|
useDeleteModelImageMutation,
|
||||||
useUpdateModelMutation,
|
useUpdateModelMutation,
|
||||||
@ -396,127 +278,3 @@ export const {
|
|||||||
useCancelModelInstallMutation,
|
useCancelModelInstallMutation,
|
||||||
usePruneCompletedModelInstallsMutation,
|
usePruneCompletedModelInstallsMutation,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
|
||||||
const upsertModelConfigs = (
|
|
||||||
modelConfigs: EntityState<AnyModelConfig, string>,
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
dispatch: ThunkDispatch<any, any, UnknownAction>
|
|
||||||
) => {
|
|
||||||
/**
|
|
||||||
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
|
|
||||||
* network request. This function takes the received list of models and upserts them into the individual query caches
|
|
||||||
* for each model type.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Iterate over all the models and upsert them into the individual query caches for each model type.
|
|
||||||
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
|
||||||
const { base, name, type } = modelConfig;
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
const upsertSingleModelConfig = (
|
|
||||||
modelConfig: AnyModelConfig,
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
dispatch: ThunkDispatch<any, any, UnknownAction>
|
|
||||||
) => {
|
|
||||||
/**
|
|
||||||
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
|
|
||||||
* query caches of models of that type.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Update the individual model query caches.
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
|
||||||
const { base, name, type } = modelConfig;
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
|
||||||
|
|
||||||
// Update the list query caches for each model type.
|
|
||||||
if (modelConfig.type === 'main') {
|
|
||||||
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
|
|
||||||
(queryArg) => {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
|
|
||||||
mainModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'controlnet') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
|
|
||||||
controlNetModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'embedding') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
|
|
||||||
textualInversionModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'ip_adapter') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
|
|
||||||
ipAdapterModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'lora') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
|
|
||||||
loraModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 't2i_adapter') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
|
|
||||||
t2iAdapterModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'vae') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
|
|
||||||
vaeModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
42
invokeai/frontend/web/src/services/api/hooks/modelsByType.ts
Normal file
42
invokeai/frontend/web/src/services/api/hooks/modelsByType.ts
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
import {
|
||||||
|
isControlNetModelConfig,
|
||||||
|
isIPAdapterModelConfig,
|
||||||
|
isLoRAModelConfig,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
isNonSDXLMainModelConfig,
|
||||||
|
isRefinerMainModelModelConfig,
|
||||||
|
isSDXLMainModelModelConfig,
|
||||||
|
isT2IAdapterModelConfig,
|
||||||
|
isTIModelConfig,
|
||||||
|
isVAEModelConfig,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
|
const buildModelsHook =
|
||||||
|
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
|
||||||
|
() => {
|
||||||
|
const result = useGetModelConfigsQuery(undefined);
|
||||||
|
const modelConfigs = useMemo(() => {
|
||||||
|
if (!result.data) {
|
||||||
|
return EMPTY_ARRAY;
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
|
||||||
|
}, [result]);
|
||||||
|
|
||||||
|
return [modelConfigs, result] as const;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||||
|
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||||
|
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||||
|
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||||
|
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||||
|
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||||
|
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||||
|
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||||
|
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||||
|
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
@ -1,12 +1,7 @@
|
|||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useIsRefinerAvailable = () => {
|
export const useIsRefinerAvailable = () => {
|
||||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
const [refinerModels] = useRefinerModels();
|
||||||
selectFromResult: ({ data }) => ({
|
|
||||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
return isRefinerAvailable;
|
return Boolean(refinerModels.length);
|
||||||
};
|
};
|
||||||
|
@ -48,7 +48,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
|||||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||||
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
||||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||||
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||||
@ -103,6 +103,18 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
|
|||||||
return config.type === 'main' && config.base === 'sdxl-refiner';
|
return config.type === 'main' && config.base === 'sdxl-refiner';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
return config.type === 'main' && config.base === 'sdxl';
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
return config.type === 'embedding';
|
||||||
|
};
|
||||||
|
|
||||||
export type ModelInstallJob = S['ModelInstallJob'];
|
export type ModelInstallJob = S['ModelInstallJob'];
|
||||||
export type ModelInstallStatus = S['InstallStatus'];
|
export type ModelInstallStatus = S['InstallStatus'];
|
||||||
|
|
||||||
@ -200,10 +212,3 @@ export type PostUploadAction =
|
|||||||
| CanvasInitialImageAction
|
| CanvasInitialImageAction
|
||||||
| ToastAction
|
| ToastAction
|
||||||
| AddToBatchAction;
|
| AddToBatchAction;
|
||||||
|
|
||||||
type TypeGuard<T> = {
|
|
||||||
(input: unknown): input is T;
|
|
||||||
};
|
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
export type TypeGuardFor<T extends TypeGuard<any>> = T extends TypeGuard<infer U> ? U : never;
|
|
||||||
|
Loading…
Reference in New Issue
Block a user