From 19d66d5ec7187583c4df778053778179b96e2b23 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 14 Mar 2024 23:37:40 +1100 Subject: [PATCH] feat(ui): single getModelConfigs query Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend. --- .../listeners/modelsLoaded.ts | 348 ++++++++---------- .../listeners/setDefaultSettings.ts | 17 +- .../listeners/socketio/socketConnected.ts | 9 +- .../common/hooks/useGroupedModelCombobox.ts | 18 +- .../web/src/common/hooks/useModelCombobox.ts | 27 +- .../src/common/hooks/useModelCustomSelect.ts | 31 +- .../parameters/ParamControlAdapterModel.tsx | 6 +- .../hooks/useAddControlAdapter.ts | 5 +- .../hooks/useControlAdapterModelQuery.ts | 26 -- .../hooks/useControlAdapterModels.ts | 36 +- .../features/lora/components/LoRASelect.tsx | 6 +- .../store/modelManagerV2Slice.ts | 7 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 162 ++++---- .../ModelManagerPanel/ModelTypeFilter.tsx | 13 +- .../MainModelDefaultSettings/DefaultVae.tsx | 22 +- .../ControlNetModelFieldInputComponent.tsx | 6 +- .../IPAdapterModelFieldInputComponent.tsx | 7 +- .../inputs/LoRAModelFieldInputComponent.tsx | 6 +- .../inputs/MainModelFieldInputComponent.tsx | 7 +- .../RefinerModelFieldInputComponent.tsx | 7 +- .../SDXLMainModelFieldInputComponent.tsx | 7 +- .../T2IAdapterModelFieldInputComponent.tsx | 7 +- .../inputs/VAEModelFieldInputComponent.tsx | 6 +- .../MainModel/ParamMainModelSelect.tsx | 7 +- .../VAEModel/ParamVAEModelSelect.tsx | 6 +- .../features/prompt/PromptTriggerSelect.tsx | 19 +- .../ParamSDXLRefinerModelSelect.tsx | 7 +- .../web/src/services/api/endpoints/models.ts | 338 +++-------------- .../src/services/api/hooks/modelsByType.ts | 42 +++ .../api/hooks/useIsRefinerAvailable.ts | 11 +- .../frontend/web/src/services/api/types.ts | 21 +- 31 files changed, 447 insertions(+), 790 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts create mode 100644 invokeai/frontend/web/src/services/api/hooks/modelsByType.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index c3903ed317..52fc2ede4a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import type { AppDispatch, RootState } from 'app/store/store'; +import type { JSONObject } from 'common/types'; import { controlAdapterModelCleared, - selectAllControlNets, - selectAllIPAdapters, - selectAllT2IAdapters, + selectControlAdapterAll, } from 'features/controlAdapters/store/controlAdaptersSlice'; import { loraRemoved } from 'features/lora/store/loraSlice'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; @@ -12,212 +12,162 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; -import { forEach, some } from 'lodash-es'; -import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models'; -import type { TypeGuardFor } from 'services/api/types'; +import { forEach } from 'lodash-es'; +import type { Logger } from 'roarr'; +import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; +import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types'; export const addModelsLoadedListener = (startAppListening: AppStartListening) => { startAppListening({ - predicate: (action): action is TypeGuardFor => - modelsApi.endpoints.getMainModels.matchFulfilled(action) && - !action.meta.arg.originalArgs.includes('sdxl-refiner'), + predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled, effect: async (action, { getState, dispatch }) => { // models loaded, we need to ensure the selected model is available and if not, select the first one const log = logger('models'); - log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`); + log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`); const state = getState(); - const currentModel = state.generation.model; - const models = mainModelsAdapterSelectors.selectAll(action.payload); + const models = modelConfigsAdapterSelectors.selectAll(action.payload); - if (models.length === 0) { - // No models loaded at all - dispatch(modelChanged(null)); - return; - } - - const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; - - if (isCurrentModelAvailable) { - return; - } - - const defaultModel = state.config.sd.defaultModel; - const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false; - - if (defaultModelInList) { - const result = zParameterModel.safeParse(defaultModelInList); - if (result.success) { - dispatch(modelChanged(defaultModelInList, currentModel)); - - const optimalDimension = getOptimalDimension(defaultModelInList); - if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) { - return; - } - const { width, height } = calculateNewSize( - state.generation.aspectRatio.value, - optimalDimension * optimalDimension - ); - - dispatch(widthChanged(width)); - dispatch(heightChanged(height)); - return; - } - } - - const result = zParameterModel.safeParse(models[0]); - - if (!result.success) { - log.error({ error: result.error.format() }, 'Failed to parse main model'); - return; - } - - dispatch(modelChanged(result.data, currentModel)); - }, - }); - startAppListening({ - predicate: (action): action is TypeGuardFor => - modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'), - effect: async (action, { getState, dispatch }) => { - // models loaded, we need to ensure the selected model is available and if not, select the first one - const log = logger('models'); - log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`); - - const currentModel = getState().sdxl.refinerModel; - const models = mainModelsAdapterSelectors.selectAll(action.payload); - - if (models.length === 0) { - // No models loaded at all - dispatch(refinerModelChanged(null)); - return; - } - - const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; - - if (!isCurrentModelAvailable) { - dispatch(refinerModelChanged(null)); - return; - } - }, - }); - startAppListening({ - matcher: modelsApi.endpoints.getVaeModels.matchFulfilled, - effect: async (action, { getState, dispatch }) => { - // VAEs loaded, need to reset the VAE is it's no longer available - const log = logger('models'); - log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`); - - const currentVae = getState().generation.vae; - - if (currentVae === null) { - // null is a valid VAE! it means "use the default with the main model" - return; - } - - const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key); - - if (isCurrentVAEAvailable) { - return; - } - - const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0]; - - if (!firstModel) { - // No custom VAEs loaded at all; use the default - dispatch(vaeSelected(null)); - return; - } - - const result = zParameterVAEModel.safeParse(firstModel); - - if (!result.success) { - log.error({ error: result.error.format() }, 'Failed to parse VAE model'); - return; - } - - dispatch(vaeSelected(result.data)); - }, - }); - startAppListening({ - matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled, - effect: async (action, { getState, dispatch }) => { - // LoRA models loaded - need to remove missing LoRAs from state - const log = logger('models'); - log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`); - - const loras = getState().lora.loras; - - forEach(loras, (lora, id) => { - const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key); - - if (isLoRAAvailable) { - return; - } - - dispatch(loraRemoved(id)); - }); - }, - }); - startAppListening({ - matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, - effect: async (action, { getState, dispatch }) => { - // ControlNet models loaded - need to remove missing ControlNets from state - const log = logger('models'); - log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`); - - selectAllControlNets(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); - - if (isModelAvailable) { - return; - } - - dispatch(controlAdapterModelCleared({ id: ca.id })); - }); - }, - }); - startAppListening({ - matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled, - effect: async (action, { getState, dispatch }) => { - // ControlNet models loaded - need to remove missing ControlNets from state - const log = logger('models'); - log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`); - - selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); - - if (isModelAvailable) { - return; - } - - dispatch(controlAdapterModelCleared({ id: ca.id })); - }); - }, - }); - startAppListening({ - matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled, - effect: async (action, { getState, dispatch }) => { - // ControlNet models loaded - need to remove missing ControlNets from state - const log = logger('models'); - log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`); - - selectAllIPAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); - - if (isModelAvailable) { - return; - } - - dispatch(controlAdapterModelCleared({ id: ca.id })); - }); - }, - }); - startAppListening({ - matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled, - effect: async (action) => { - const log = logger('models'); - log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`); + handleMainModels(models, state, dispatch, log); + handleRefinerModels(models, state, dispatch, log); + handleVAEModels(models, state, dispatch, log); + handleLoRAModels(models, state, dispatch, log); + handleControlAdapterModels(models, state, dispatch, log); }, }); }; + +type ModelHandler = ( + models: AnyModelConfig[], + state: RootState, + dispatch: AppDispatch, + log: Logger +) => undefined; + +const handleMainModels: ModelHandler = (models, state, dispatch, log) => { + const currentModel = state.generation.model; + const mainModels = models.filter(isNonRefinerMainModelConfig); + if (mainModels.length === 0) { + // No models loaded at all + dispatch(modelChanged(null)); + return; + } + + const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; + + if (isCurrentMainModelAvailable) { + return; + } + + const defaultModel = state.config.sd.defaultModel; + const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false; + + if (defaultModelInList) { + const result = zParameterModel.safeParse(defaultModelInList); + if (result.success) { + dispatch(modelChanged(defaultModelInList, currentModel)); + + const optimalDimension = getOptimalDimension(defaultModelInList); + if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) { + return; + } + const { width, height } = calculateNewSize( + state.generation.aspectRatio.value, + optimalDimension * optimalDimension + ); + + dispatch(widthChanged(width)); + dispatch(heightChanged(height)); + return; + } + } + + const result = zParameterModel.safeParse(models[0]); + + if (!result.success) { + log.error({ error: result.error.format() }, 'Failed to parse main model'); + return; + } + + dispatch(modelChanged(result.data, currentModel)); +}; + +const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => { + const currentRefinerModel = state.sdxl.refinerModel; + const refinerModels = models.filter(isRefinerMainModelModelConfig); + if (models.length === 0) { + // No models loaded at all + dispatch(refinerModelChanged(null)); + return; + } + + const isCurrentRefinerModelAvailable = currentRefinerModel + ? refinerModels.some((m) => m.key === currentRefinerModel.key) + : false; + + if (!isCurrentRefinerModelAvailable) { + dispatch(refinerModelChanged(null)); + return; + } +}; + +const handleVAEModels: ModelHandler = (models, state, dispatch, log) => { + const currentVae = state.generation.vae; + + if (currentVae === null) { + // null is a valid VAE! it means "use the default with the main model" + return; + } + const vaeModels = models.filter(isVAEModelConfig); + + const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key); + + if (isCurrentVAEAvailable) { + return; + } + + const firstModel = vaeModels[0]; + + if (!firstModel) { + // No custom VAEs loaded at all; use the default + dispatch(vaeSelected(null)); + return; + } + + const result = zParameterVAEModel.safeParse(firstModel); + + if (!result.success) { + log.error({ error: result.error.format() }, 'Failed to parse VAE model'); + return; + } + + dispatch(vaeSelected(result.data)); +}; + +const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => { + const loras = state.lora.loras; + + forEach(loras, (lora, id) => { + const isLoRAAvailable = models.some((m) => m.key === lora.model.key); + + if (isLoRAAvailable) { + return; + } + + dispatch(loraRemoved(id)); + }); +}; + +const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => { + selectControlAdapterAll(state.controlAdapters).forEach((ca) => { + const isModelAvailable = models.some((m) => m.key === ca.model?.key); + + if (isModelAvailable) { + return; + } + + dispatch(controlAdapterModelCleared({ id: ca.id })); + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts index 06deb7c02f..dad01d3bdc 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts @@ -23,8 +23,7 @@ import { import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; -import { map } from 'lodash-es'; -import { modelsApi } from 'services/api/endpoints/models'; +import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; import { isNonRefinerMainModelConfig } from 'services/api/types'; export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => { @@ -39,7 +38,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni return; } - const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap(); + const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate()); + const data = await request.unwrap(); + request.unsubscribe(); + const models = modelConfigsAdapterSelectors.selectAll(data); + + const modelConfig = models.find((model) => model.key === currentModel.key); if (!modelConfig) { return; @@ -55,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni if (vae === 'default') { dispatch(vaeSelected(null)); } else { - const { data } = modelsApi.endpoints.getVaeModels.select()(state); - const vaeArray = map(data?.entities); - const validVae = vaeArray.find((model) => model.key === vae); - - const result = zParameterVAEModel.safeParse(validVae); + const vaeModel = models.find((model) => model.key === vae); + const result = zParameterVAEModel.safeParse(vaeModel); if (!result.success) { return; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 2540362201..0b2644f124 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -30,11 +30,10 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe // Bail on the recovery logic if this is the first connection - we don't need to recover anything if ($isFirstConnection.get()) { - // The TI models are used in a component that is not always rendered, so when users open the prompt triggers - // box has a delay while it does the initial fetch. We need to both pre-fetch the data and maintain an RTK - // Query subscription to it, so the cache doesn't clear itself when the user closes the prompt triggers box. - // So, we explicitly do not unsubscribe from this query! - dispatch(modelsApi.endpoints.getTextualInversionModels.initiate()); + // Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately + // unsubscribe. + const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate()); + request.unsubscribe(); $isFirstConnection.set(false); return; diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index ee9da8ea66..55887eb3be 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -1,15 +1,14 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; -import type { EntityState } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import type { GroupBase } from 'chakra-react-select'; import type { ModelIdentifierField } from 'features/nodes/types/common'; -import { groupBy, map, reduce } from 'lodash-es'; +import { groupBy, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import type { AnyModelConfig } from 'services/api/types'; type UseGroupedModelComboboxArg = { - modelEntities: EntityState | undefined; + modelConfigs: T[]; selectedModel?: ModelIdentifierField | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; @@ -29,13 +28,12 @@ export const useGroupedModelCombobox = ( ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); - const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg; + const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg; const options = useMemo[]>(() => { - if (!modelEntities) { + if (!modelConfigs) { return []; } - const modelEntitiesArray = map(modelEntities.entities); - const groupedModels = groupBy(modelEntitiesArray, 'base'); + const groupedModels = groupBy(modelConfigs, 'base'); const _options = reduce( groupedModels, (acc, val, label) => { @@ -53,7 +51,7 @@ export const useGroupedModelCombobox = ( ); _options.sort((a) => (a.label === base_model ? -1 : 1)); return _options; - }, [getIsDisabled, modelEntities, base_model]); + }, [getIsDisabled, modelConfigs, base_model]); const value = useMemo( () => @@ -67,14 +65,14 @@ export const useGroupedModelCombobox = ( onChange(null); return; } - const model = modelEntities?.entities[v.value]; + const model = modelConfigs.find((m) => m.key === v.value); if (!model) { onChange(null); return; } onChange(model); }, - [modelEntities?.entities, onChange] + [modelConfigs, onChange] ); const placeholder = useMemo(() => { diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 3d9109a5ef..d57ef48337 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -1,13 +1,11 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; -import type { EntityState } from '@reduxjs/toolkit'; import type { ModelIdentifierField } from 'features/nodes/types/common'; -import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import type { AnyModelConfig } from 'services/api/types'; type UseModelComboboxArg = { - modelEntities: EntityState | undefined; + modelConfigs: T[]; selectedModel?: ModelIdentifierField | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; @@ -25,19 +23,14 @@ type UseModelComboboxReturn = { export const useModelCombobox = (arg: UseModelComboboxArg): UseModelComboboxReturn => { const { t } = useTranslation(); - const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg; + const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg; const options = useMemo(() => { - if (!modelEntities) { - return []; - } - return map(modelEntities.entities) - .filter(optionsFilter) - .map((model) => ({ - label: model.name, - value: model.key, - isDisabled: getIsDisabled ? getIsDisabled(model) : false, - })); - }, [optionsFilter, getIsDisabled, modelEntities]); + return modelConfigs.filter(optionsFilter).map((model) => ({ + label: model.name, + value: model.key, + isDisabled: getIsDisabled ? getIsDisabled(model) : false, + })); + }, [optionsFilter, getIsDisabled, modelConfigs]); const value = useMemo( () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)), @@ -50,14 +43,14 @@ export const useModelCombobox = (arg: UseModelCombobox onChange(null); return; } - const model = modelEntities?.entities[v.value]; + const model = modelConfigs.find((m) => m.key === v.value); if (!model) { onChange(null); return; } onChange(model); }, - [modelEntities?.entities, onChange] + [modelConfigs, onChange] ); const placeholder = useMemo(() => { diff --git a/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts index 60de28468c..d8e5d07f3a 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts @@ -1,15 +1,12 @@ import type { Item } from '@invoke-ai/ui-library'; -import type { EntityState } from '@reduxjs/toolkit'; -import { EMPTY_ARRAY } from 'app/store/constants'; import type { ModelIdentifierField } from 'features/nodes/types/common'; import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; -import { filter } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import type { AnyModelConfig } from 'services/api/types'; type UseModelCustomSelectArg = { - data: EntityState | undefined; + modelConfigs: T[]; isLoading: boolean; selectedModel?: ModelIdentifierField | null; onChange: (value: T | null) => void; @@ -28,7 +25,7 @@ const modelFilterDefault = () => true; const isModelDisabledDefault = () => false; export const useModelCustomSelect = ({ - data, + modelConfigs, isLoading, selectedModel, onChange, @@ -39,30 +36,28 @@ export const useModelCustomSelect = ({ const items: Item[] = useMemo( () => - data - ? filter(data.entities, modelFilter).map((m) => ({ - label: m.name, - value: m.key, - description: m.description, - group: MODEL_TYPE_SHORT_MAP[m.base], - isDisabled: isModelDisabled(m), - })) - : EMPTY_ARRAY, - [data, isModelDisabled, modelFilter] + modelConfigs.filter(modelFilter).map((m) => ({ + label: m.name, + value: m.key, + description: m.description, + group: MODEL_TYPE_SHORT_MAP[m.base], + isDisabled: isModelDisabled(m), + })), + [modelConfigs, isModelDisabled, modelFilter] ); const _onChange = useCallback( (item: Item | null) => { - if (!item || !data) { + if (!item || !modelConfigs) { return; } - const model = data.entities[item.value]; + const model = modelConfigs.find((m) => m.key === item.value); if (!model) { return; } onChange(model); }, - [data, onChange] + [modelConfigs, onChange] ); const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]); diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index b8dddf56dc..ef0bb74736 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; -import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; +import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { memo, useCallback, useMemo } from 'react'; @@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const dispatch = useAppDispatch(); const currentBaseModel = useAppSelector((s) => s.generation.model?.base); - const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); + const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType); const _onChange = useCallback( (modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { @@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { ); const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ - data, + modelConfigs, isLoading, selectedModel, onChange: _onChange, diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts index 82d6e8c5d6..43c567b319 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts @@ -1,17 +1,16 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels'; import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types'; import { useCallback, useMemo } from 'react'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; -import { useControlAdapterModels } from './useControlAdapterModels'; - export const useAddControlAdapter = (type: ControlAdapterType) => { const baseModel = useAppSelector((s) => s.generation.model?.base); const dispatch = useAppDispatch(); - const models = useControlAdapterModels(type); + const [models] = useControlAdapterModels(type); const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => { // prefer to use a model that matches the base model diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts deleted file mode 100644 index 1d092497af..0000000000 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts +++ /dev/null @@ -1,26 +0,0 @@ -import type { ControlAdapterType } from 'features/controlAdapters/store/types'; -import { - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetT2IAdapterModelsQuery, -} from 'services/api/endpoints/models'; - -export const useControlAdapterModelQuery = (type: ControlAdapterType) => { - const controlNetModelsQuery = useGetControlNetModelsQuery(); - const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery(); - const ipAdapterModelsQuery = useGetIPAdapterModelsQuery(); - - if (type === 'controlnet') { - return controlNetModelsQuery; - } - if (type === 't2i_adapter') { - return t2iAdapterModelsQuery; - } - if (type === 'ip_adapter') { - return ipAdapterModelsQuery; - } - - // Assert that the end of the function is not reachable. - const exhaustiveCheck: never = type; - return exhaustiveCheck; -}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModels.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModels.ts index dd23211b9b..4fe5ae7811 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModels.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModels.ts @@ -1,31 +1,10 @@ import type { ControlAdapterType } from 'features/controlAdapters/store/types'; -import { useMemo } from 'react'; -import { - controlNetModelsAdapterSelectors, - ipAdapterModelsAdapterSelectors, - t2iAdapterModelsAdapterSelectors, - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetT2IAdapterModelsQuery, -} from 'services/api/endpoints/models'; +import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType'; -export const useControlAdapterModels = (type?: ControlAdapterType) => { - const { data: controlNetModelsData } = useGetControlNetModelsQuery(); - const controlNetModels = useMemo( - () => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []), - [controlNetModelsData] - ); - - const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery(); - const t2iAdapterModels = useMemo( - () => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []), - [t2iAdapterModelsData] - ); - const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery(); - const ipAdapterModels = useMemo( - () => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []), - [ipAdapterModelsData] - ); +export const useControlAdapterModels = (type: ControlAdapterType) => { + const controlNetModels = useControlNetModels(); + const t2iAdapterModels = useT2IAdapterModels(); + const ipAdapterModels = useIPAdapterModels(); if (type === 'controlnet') { return controlNetModels; @@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => { if (type === 'ip_adapter') { return ipAdapterModels; } - return []; + + // Assert that the end of the function is not reachable. + const exhaustiveCheck: never = type; + return exhaustiveCheck; }; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 851d098763..4209784725 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { useLoRAModels } from 'services/api/hooks/modelsByType'; import type { LoRAModelConfig } from 'services/api/types'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); const LoRASelect = () => { const dispatch = useAppDispatch(); - const { data, isLoading } = useGetLoRAModelsQuery(); + const [modelConfigs, { isLoading }] = useLoRAModels(); const { t } = useTranslation(); const addedLoRAs = useAppSelector(selectAddedLoRAs); const currentBaseModel = useAppSelector((s) => s.generation.model?.base); @@ -37,7 +37,7 @@ const LoRASelect = () => { ); const { options, onChange } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, getIsDisabled, onChange: _onChange, }); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 39a009216a..6bdd829bb1 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -1,13 +1,16 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig } from 'app/store/store'; +import type { ModelType } from 'services/api/types'; + +export type FilterableModelType = Exclude; type ModelManagerState = { _version: 1; selectedModelKey: string | null; selectedModelMode: 'edit' | 'view'; searchTerm: string; - filteredModelType: string | null; + filteredModelType: FilterableModelType | null; scanPath: string | undefined; }; @@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({ state.searchTerm = action.payload; }, - setFilteredModelType: (state, action: PayloadAction) => { + setFilteredModelType: (state, action: PayloadAction) => { state.filteredModelType = action.payload; }, setScanPath: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 2aceb484ee..d6c99e460d 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,122 +1,105 @@ import { Flex, Spinner, Text } from '@invoke-ai/ui-library'; -import type { EntityState } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { forEach } from 'lodash-es'; -import { memo } from 'react'; -import { ALL_BASE_MODELS } from 'services/api/constants'; +import { memo, useMemo } from 'react'; import { - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetLoRAModelsQuery, - useGetMainModelsQuery, - useGetT2IAdapterModelsQuery, - useGetTextualInversionModelsQuery, - useGetVaeModelsQuery, -} from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; + useControlNetModels, + useEmbeddingModels, + useIPAdapterModels, + useLoRAModels, + useMainModels, + useT2IAdapterModels, + useVAEModels, +} from 'services/api/hooks/modelsByType'; +import type { AnyModelConfig, ModelType } from 'services/api/types'; import { ModelListWrapper } from './ModelListWrapper'; const ModelList = () => { const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2); - const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data, isLoading }) => ({ - filteredMainModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingMainModels: isLoading, - }), - }); - - const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, { - selectFromResult: ({ data, isLoading }) => ({ - filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingLoraModels: isLoading, - }), - }); - - const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery( - undefined, - { - selectFromResult: ({ data, isLoading }) => ({ - filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingTextualInversionModels: isLoading, - }), - } + const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels(); + const filteredMainModels = useMemo( + () => modelsFilter(mainModels, searchTerm, filteredModelType), + [mainModels, searchTerm, filteredModelType] ); - const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, { - selectFromResult: ({ data, isLoading }) => ({ - filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingControlnetModels: isLoading, - }), - }); + const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels(); + const filteredLoRAModels = useMemo( + () => modelsFilter(loraModels, searchTerm, filteredModelType), + [loraModels, searchTerm, filteredModelType] + ); - const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, { - selectFromResult: ({ data, isLoading }) => ({ - filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingT2IAdapterModels: isLoading, - }), - }); + const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels(); + const filteredEmbeddingModels = useMemo( + () => modelsFilter(embeddingModels, searchTerm, filteredModelType), + [embeddingModels, searchTerm, filteredModelType] + ); - const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, { - selectFromResult: ({ data, isLoading }) => ({ - filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingIpAdapterModels: isLoading, - }), - }); + const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels(); + const filteredControlNetModels = useMemo( + () => modelsFilter(controlNetModels, searchTerm, filteredModelType), + [controlNetModels, searchTerm, filteredModelType] + ); - const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, { - selectFromResult: ({ data, isLoading }) => ({ - filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType), - isLoadingVaeModels: isLoading, - }), - }); + const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels(); + const filteredT2IAdapterModels = useMemo( + () => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType), + [t2iAdapterModels, searchTerm, filteredModelType] + ); + + const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels(); + const filteredIPAdapterModels = useMemo( + () => modelsFilter(ipAdapterModels, searchTerm, filteredModelType), + [ipAdapterModels, searchTerm, filteredModelType] + ); + + const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels(); + const filteredVAEModels = useMemo( + () => modelsFilter(vaeModels, searchTerm, filteredModelType), + [vaeModels, searchTerm, filteredModelType] + ); return ( {/* Main Model List */} - {isLoadingMainModels && } + {isLoadingMainModels && } {!isLoadingMainModels && filteredMainModels.length > 0 && ( )} {/* LoRAs List */} - {isLoadingLoraModels && } - {!isLoadingLoraModels && filteredLoraModels.length > 0 && ( - + {isLoadingLoRAModels && } + {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && ( + )} {/* TI List */} - {isLoadingTextualInversionModels && } - {!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && ( - + {isLoadingEmbeddingModels && } + {!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && ( + )} {/* VAE List */} - {isLoadingVaeModels && } - {!isLoadingVaeModels && filteredVaeModels.length > 0 && ( - + {isLoadingVAEModels && } + {!isLoadingVAEModels && filteredVAEModels.length > 0 && ( + )} {/* Controlnet List */} - {isLoadingControlnetModels && } - {!isLoadingControlnetModels && filteredControlnetModels.length > 0 && ( - + {isLoadingControlNetModels && } + {!isLoadingControlNetModels && filteredControlNetModels.length > 0 && ( + )} {/* IP Adapter List */} - {isLoadingIpAdapterModels && } - {!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && ( - + {isLoadingIPAdapterModels && } + {!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && ( + )} {/* T2I Adapters List */} {isLoadingT2IAdapterModels && } - {!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && ( - + {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( + )} @@ -126,25 +109,16 @@ const ModelList = () => { export default memo(ModelList); const modelsFilter = ( - data: EntityState | undefined, + data: T[], nameFilter: string, - filteredModelType: string | null + filteredModelType: ModelType | null ): T[] => { - const filteredModels: T[] = []; - - forEach(data?.entities, (model) => { - if (!model) { - return; - } - + return data.filter((model) => { const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); const matchesType = filteredModelType ? model.type === filteredModelType : true; - if (matchesFilter && matchesType) { - filteredModels.push(model); - } + return matchesFilter && matchesType; }); - return filteredModels; }; const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx index b5b4aadbe9..94a06bf5d9 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -1,11 +1,13 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { IoFilter } from 'react-icons/io5'; +import { PiFunnelBold } from 'react-icons/pi'; +import { objectKeys } from 'tsafe'; -const MODEL_TYPE_LABELS: { [key: string]: string } = { +const MODEL_TYPE_LABELS: Record = { main: 'Main', lora: 'LoRA', embedding: 'Textual Inversion', @@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = { vae: 'VAE', t2i_adapter: 'T2I Adapter', ip_adapter: 'IP Adapter', - clip_vision: 'Clip Vision', }; export const ModelTypeFilter = () => { @@ -22,7 +23,7 @@ export const ModelTypeFilter = () => { const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType); const selectModelType = useCallback( - (option: string) => { + (option: FilterableModelType) => { dispatch(setFilteredModelType(option)); }, [dispatch] @@ -34,12 +35,12 @@ export const ModelTypeFilter = () => { return ( - }> + }> {filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')} {t('modelManager.allModels')} - {Object.keys(MODEL_TYPE_LABELS).map((option) => ( + {objectKeys(MODEL_TYPE_LABELS).map((option) => ( s.modelmanagerV2.selectedModelKey); const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken); - const { compatibleOptions } = useGetVaeModelsQuery(undefined, { - selectFromResult: ({ data }) => { - const modelArray = map(data?.entities); - const compatibleOptions = modelArray - .filter((vae) => vae.base === modelData?.base) - .map((vae) => ({ label: vae.name, value: vae.key })); + const [vaeModels] = useVAEModels(); + const compatibleOptions = useMemo(() => { + const compatibleOptions = vaeModels + .filter((vae) => vae.base === modelData?.base) + .map((vae) => ({ label: vae.name, value: vae.key })); - const defaultOption = { label: 'Default VAE', value: 'default' }; + const defaultOption = { label: 'Default VAE', value: 'default' }; - return { compatibleOptions: [defaultOption, ...compatibleOptions] }; - }, - }); + return [defaultOption, ...compatibleOptions]; + }, [modelData?.base, vaeModels]); const onChange = useCallback( (v) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 367b32030d..6cc3773a8d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; +import { useControlNetModels } from 'services/api/hooks/modelsByType'; import type { ControlNetModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -14,7 +14,7 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data, isLoading } = useGetControlNetModelsQuery(); + const [modelConfigs, { isLoading }] = useControlNetModels(); const _onChange = useCallback( (value: ControlNetModelConfig | null) => { @@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { ); const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, selectedModel: field.value, isLoading, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 62ce9d65a2..001d2d19ae 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; +import { useIPAdapterModels } from 'services/api/hooks/modelsByType'; import type { IPAdapterModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = ( ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); + const [modelConfigs, { isLoading }] = useIPAdapterModels(); const _onChange = useCallback( (value: IPAdapterModelConfig | null) => { @@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = ( ); const { options, value, onChange } = useGroupedModelCombobox({ - modelEntities: ipAdapterModels, + modelConfigs, onChange: _onChange, selectedModel: field.value, + isLoading, }); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index 9caadfc451..dc4f6e27af 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { useLoRAModels } from 'services/api/hooks/modelsByType'; import type { LoRAModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -14,7 +14,7 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data, isLoading } = useGetLoRAModelsQuery(); + const [modelConfigs, { isLoading }] = useLoRAModels(); const _onChange = useCallback( (value: LoRAModelConfig | null) => { if (!value) { @@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { ); const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, selectedModel: field.value, isLoading, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index c2b3d3d69a..2a84347b7e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { NON_SDXL_MAIN_MODELS } from 'services/api/constants'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType'; import type { MainModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -16,7 +15,7 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS); + const [modelConfigs, { isLoading }] = useNonSDXLMainModels(); const _onChange = useCallback( (value: MainModelConfig | null) => { if (!value) { @@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => { [dispatch, field.name, nodeId] ); const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, isLoading, selectedModel: field.value, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index 0499eb262e..c9d42dad8e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -8,8 +8,7 @@ import type { SDXLRefinerModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { REFINER_BASE_MODELS } from 'services/api/constants'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useRefinerModels } from 'services/api/hooks/modelsByType'; import type { MainModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -19,7 +18,7 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); + const [modelConfigs, { isLoading }] = useRefinerModels(); const _onChange = useCallback( (value: MainModelConfig | null) => { if (!value) { @@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => { [dispatch, field.name, nodeId] ); const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, isLoading, selectedModel: field.value, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index 0a48b0c917..5d3c584d4b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { SDXL_MAIN_MODELS } from 'services/api/constants'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useSDXLModels } from 'services/api/hooks/modelsByType'; import type { MainModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -16,7 +15,7 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS); + const [modelConfigs, { isLoading }] = useSDXLModels(); const _onChange = useCallback( (value: MainModelConfig | null) => { if (!value) { @@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => { [dispatch, field.name, nodeId] ); const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, isLoading, selectedModel: field.value, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 9a374f1fd0..4016805675 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; +import { useT2IAdapterModels } from 'services/api/hooks/modelsByType'; import type { T2IAdapterModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = ( const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); + const [modelConfigs, { isLoading }] = useT2IAdapterModels(); const _onChange = useCallback( (value: T2IAdapterModelConfig | null) => { @@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = ( ); const { options, value, onChange } = useGroupedModelCombobox({ - modelEntities: t2iAdapterModels, + modelConfigs, onChange: _onChange, selectedModel: field.value, + isLoading, }); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index c1259486e6..d10712d15d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; +import { useVAEModels } from 'services/api/hooks/modelsByType'; import type { VAEModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -15,7 +15,7 @@ type Props = FieldComponentProps { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const { data, isLoading } = useGetVaeModelsQuery(); + const [modelConfigs, { isLoading }] = useVAEModels(); const _onChange = useCallback( (value: VAEModelConfig | null) => { if (!value) { @@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => { [dispatch, field.name, nodeId] ); const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, selectedModel: field.value, isLoading, diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index 8f8cc7ea5c..349a40a92f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useMainModels } from 'services/api/hooks/modelsByType'; import type { MainModelConfig } from 'services/api/types'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); @@ -18,7 +17,7 @@ const ParamMainModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const selectedModel = useAppSelector(selectModel); - const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS); + const [modelConfigs, { isLoading }] = useMainModels(); const _onChange = useCallback( (model: MainModelConfig | null) => { @@ -35,7 +34,7 @@ const ParamMainModelSelect = () => { ); const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ - data, + modelConfigs, isLoading, selectedModel, onChange: _onChange, diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index 282723b6bf..3552b09292 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common'; import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; +import { useVAEModels } from 'services/api/hooks/modelsByType'; import type { VAEModelConfig } from 'services/api/types'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { @@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const { model, vae } = useAppSelector(selector); - const { data, isLoading } = useGetVaeModelsQuery(); + const [modelConfigs, { isLoading }] = useVAEModels(); const getIsDisabled = useCallback( (vae: VAEModelConfig): boolean => { const isCompatible = model?.base === vae.base; @@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => { [dispatch] ); const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, selectedModel: vae, isLoading, diff --git a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx index 9da7876e52..9e79a4dcae 100644 --- a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx +++ b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx @@ -11,13 +11,8 @@ import { t } from 'i18next'; import { flatten, map } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { - loraModelsAdapterSelectors, - textualInversionModelsAdapterSelectors, - useGetLoRAModelsQuery, - useGetModelConfigQuery, - useGetTextualInversionModelsQuery, -} from 'services/api/endpoints/models'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; +import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType'; import { isNonRefinerMainModelConfig } from 'services/api/types'; const noOptionsMessage = () => t('prompt.noMatchingTriggers'); @@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery( mainModel?.key ?? skipToken ); - const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery(); - const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery(); + const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels(); + const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels(); const _onChange = useCallback( (v) => { @@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel const _options: GroupBase[] = []; if (tiModels) { - const embeddingOptions = textualInversionModelsAdapterSelectors - .selectAll(tiModels) + const embeddingOptions = tiModels .filter((ti) => ti.base === mainModelConfig?.base) .map((model) => ({ label: model.name, value: `<${model.name}>` })); @@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel } if (loraModels) { - const triggerPhraseOptions = loraModelsAdapterSelectors - .selectAll(loraModels) + const triggerPhraseOptions = loraModels .filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key)) .map((lora) => { if (lora.trigger_phrases) { diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index c515108795..619f55f5f8 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common'; import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { REFINER_BASE_MODELS } from 'services/api/constants'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useRefinerModels } from 'services/api/hooks/modelsByType'; import type { MainModelConfig } from 'services/api/types'; const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel); @@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => { const dispatch = useAppDispatch(); const model = useAppSelector(selectModel); const { t } = useTranslation(); - const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); + const [modelConfigs, { isLoading }] = useRefinerModels(); const _onChange = useCallback( (model: MainModelConfig | null) => { if (!model) { @@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => { [dispatch] ); const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({ - modelEntities: data, + modelConfigs, onChange: _onChange, selectedModel: model, isLoading, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 1f8bf96b31..603edb4c27 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,28 +1,11 @@ -import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; +import type { EntityState } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import queryString from 'query-string'; -import { - ALL_BASE_MODELS, - NON_REFINER_BASE_MODELS, - NON_SDXL_MAIN_MODELS, - REFINER_BASE_MODELS, - SDXL_MAIN_MODELS, -} from 'services/api/constants'; import type { operations, paths } from 'services/api/schema'; -import type { - AnyModelConfig, - BaseModelType, - ControlNetModelConfig, - IPAdapterModelConfig, - LoRAModelConfig, - MainModelConfig, - T2IAdapterModelConfig, - TextualInversionModelConfig, - VAEModelConfig, -} from 'services/api/types'; +import type { AnyModelConfig } from 'services/api/types'; -import type { ApiTagDescription, tagTypes } from '..'; +import type { ApiTagDescription } from '..'; import { api, buildV2Url, LIST_TAG } from '..'; export type UpdateModelArg = { @@ -40,8 +23,9 @@ type UpdateModelImageResponse = paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json']; type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; - -type ListModelsArg = NonNullable; +type GetModelConfigsResponse = NonNullable< + paths['/api/v2/models/']['get']['responses']['200']['content']['application/json'] +>; type DeleteModelArg = { key: string; @@ -76,72 +60,11 @@ type GetHuggingFaceModelsResponse = type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query']; -const mainModelsAdapter = createEntityAdapter({ +const modelConfigsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); -export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); -const loraModelsAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); -const controlNetModelsAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); -const ipAdapterModelsAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -const t2iAdapterModelsAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -const textualInversionModelsAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors( - undefined, - getSelectorsOptions -); -const vaeModelsAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); - -const anyModelConfigAdapter = createEntityAdapter({ - selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); -const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions); - -const buildProvidesTags = - (tagType: (typeof tagTypes)[number]) => - (result: EntityState | undefined) => { - const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: tagType, - id, - })) - ); - } - - return tags; - }; - -const buildTransformResponse = - (adapter: EntityAdapter) => - (response: { models: T[] }) => { - return adapter.setAll(adapter.getInitialState(), response.models); - }; +export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions); /** * Builds an endpoint URL for the models router @@ -162,9 +85,27 @@ export const modelsApi = api.injectEndpoints({ }; }, onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertSingleModelConfig(data, dispatch); - }); + try { + const { data } = await queryFulfilled; + + // Update the individual model query caches + dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data)); + + const { base, name, type } = data; + dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data)); + + // Update the list query cache + dispatch( + modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => { + modelConfigsAdapter.updateOne(draft, { + id: data.key, + changes: data, + }); + }) + ); + } catch { + // no-op + } }, }), updateModelImage: build.mutation({ @@ -294,80 +235,27 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['ModelInstalls'], }), - getMainModels: build.query, BaseModelType[]>({ - query: (base_models) => { - const params: ListModelsArg = { - model_type: 'main', - base_models, - }; - const query = queryString.stringify(params, { arrayFormat: 'none' }); - return buildModelsUrl(`?${query}`); + getModelConfigs: build.query, void>({ + query: () => ({ url: buildModelsUrl() }), + providesTags: (result) => { + const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }]; + if (result) { + const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const); + tags.push(...modelTags); + } + return tags; + }, + keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite) + transformResponse: (response: GetModelConfigsResponse) => { + return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models); }, - providesTags: buildProvidesTags('MainModel'), - transformResponse: buildTransformResponse(mainModelsAdapter), onQueryStarted: async (_, { dispatch, queryFulfilled }) => { queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); - }); - }, - }), - getLoRAModels: build.query, void>({ - query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), - providesTags: buildProvidesTags('LoRAModel'), - transformResponse: buildTransformResponse(loraModelsAdapter), - onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); - }); - }, - }), - getControlNetModels: build.query, void>({ - query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), - providesTags: buildProvidesTags('ControlNetModel'), - transformResponse: buildTransformResponse(controlNetModelsAdapter), - onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); - }); - }, - }), - getIPAdapterModels: build.query, void>({ - query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), - providesTags: buildProvidesTags('IPAdapterModel'), - transformResponse: buildTransformResponse(ipAdapterModelsAdapter), - onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); - }); - }, - }), - getT2IAdapterModels: build.query, void>({ - query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), - providesTags: buildProvidesTags('T2IAdapterModel'), - transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), - onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); - }); - }, - }), - getVaeModels: build.query, void>({ - query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), - providesTags: buildProvidesTags('VaeModel'), - transformResponse: buildTransformResponse(vaeModelsAdapter), - onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); - }); - }, - }), - getTextualInversionModels: build.query, void>({ - query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), - providesTags: buildProvidesTags('TextualInversionModel'), - transformResponse: buildTransformResponse(textualInversionModelsAdapter), - onQueryStarted: async (_, { dispatch, queryFulfilled }) => { - queryFulfilled.then(({ data }) => { - upsertModelConfigs(data, dispatch); + modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => { + dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig)); + const { base, name, type } = modelConfig; + dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig)); + }); }); }, }), @@ -375,14 +263,8 @@ export const modelsApi = api.injectEndpoints({ }); export const { + useGetModelConfigsQuery, useGetModelConfigQuery, - useGetMainModelsQuery, - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetT2IAdapterModelsQuery, - useGetLoRAModelsQuery, - useGetTextualInversionModelsQuery, - useGetVaeModelsQuery, useDeleteModelsMutation, useDeleteModelImageMutation, useUpdateModelMutation, @@ -396,127 +278,3 @@ export const { useCancelModelInstallMutation, usePruneCompletedModelInstallsMutation, } = modelsApi; - -const upsertModelConfigs = ( - modelConfigs: EntityState, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - dispatch: ThunkDispatch -) => { - /** - * Once a list of models of a specific type is received, fetching any of those models individually is a waste of a - * network request. This function takes the received list of models and upserts them into the individual query caches - * for each model type. - */ - - // Iterate over all the models and upsert them into the individual query caches for each model type. - anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => { - dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig)); - const { base, name, type } = modelConfig; - dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig)); - }); -}; - -const upsertSingleModelConfig = ( - modelConfig: AnyModelConfig, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - dispatch: ThunkDispatch -) => { - /** - * When a model is updated, the individual query caches for each model type need to be updated, as well as the list - * query caches of models of that type. - */ - - // Update the individual model query caches. - dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig)); - const { base, name, type } = modelConfig; - dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig)); - - // Update the list query caches for each model type. - if (modelConfig.type === 'main') { - [ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach( - (queryArg) => { - dispatch( - modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => { - mainModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - } - ); - return; - } - - if (modelConfig.type === 'controlnet') { - dispatch( - modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => { - controlNetModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - return; - } - - if (modelConfig.type === 'embedding') { - dispatch( - modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => { - textualInversionModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - return; - } - - if (modelConfig.type === 'ip_adapter') { - dispatch( - modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => { - ipAdapterModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - return; - } - - if (modelConfig.type === 'lora') { - dispatch( - modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => { - loraModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - return; - } - - if (modelConfig.type === 't2i_adapter') { - dispatch( - modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => { - t2iAdapterModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - return; - } - - if (modelConfig.type === 'vae') { - dispatch( - modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => { - vaeModelsAdapter.updateOne(draft, { - id: modelConfig.key, - changes: modelConfig, - }); - }) - ); - return; - } -}; diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts new file mode 100644 index 0000000000..2d04b9dc46 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -0,0 +1,42 @@ +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useMemo } from 'react'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; +import { + isControlNetModelConfig, + isIPAdapterModelConfig, + isLoRAModelConfig, + isNonRefinerMainModelConfig, + isNonSDXLMainModelConfig, + isRefinerMainModelModelConfig, + isSDXLMainModelModelConfig, + isT2IAdapterModelConfig, + isTIModelConfig, + isVAEModelConfig, +} from 'services/api/types'; + +const buildModelsHook = + (typeGuard: (config: AnyModelConfig) => config is T) => + () => { + const result = useGetModelConfigsQuery(undefined); + const modelConfigs = useMemo(() => { + if (!result.data) { + return EMPTY_ARRAY; + } + + return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard); + }, [result]); + + return [modelConfigs, result] as const; + }; + +export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig); +export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig); +export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig); +export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig); +export const useLoRAModels = buildModelsHook(isLoRAModelConfig); +export const useControlNetModels = buildModelsHook(isControlNetModelConfig); +export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig); +export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig); +export const useEmbeddingModels = buildModelsHook(isTIModelConfig); +export const useVAEModels = buildModelsHook(isVAEModelConfig); diff --git a/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts b/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts index 4cb4891be4..3ac69e8c87 100644 --- a/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts +++ b/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts @@ -1,12 +1,7 @@ -import { REFINER_BASE_MODELS } from 'services/api/constants'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useRefinerModels } from 'services/api/hooks/modelsByType'; export const useIsRefinerAvailable = () => { - const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - isRefinerAvailable: data ? data.ids.length > 0 : false, - }), - }); + const [refinerModels] = useRefinerModels(); - return isRefinerAvailable; + return Boolean(refinerModels.length); }; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index d42220f962..6a81b7b6dc 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -48,7 +48,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type IPAdapterModelConfig = S['IPAdapterConfig']; export type T2IAdapterModelConfig = S['T2IAdapterConfig']; -export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; +type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; type DiffusersModelConfig = S['MainDiffusersConfig']; type CheckpointModelConfig = S['MainCheckpointConfig']; type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig']; @@ -103,6 +103,18 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is return config.type === 'main' && config.base === 'sdxl-refiner'; }; +export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base === 'sdxl'; +}; + +export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2'); +}; + +export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'embedding'; +}; + export type ModelInstallJob = S['ModelInstallJob']; export type ModelInstallStatus = S['InstallStatus']; @@ -200,10 +212,3 @@ export type PostUploadAction = | CanvasInitialImageAction | ToastAction | AddToBatchAction; - -type TypeGuard = { - (input: unknown): input is T; -}; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type TypeGuardFor> = T extends TypeGuard ? U : never;