From cbcd416b70f6d3ce45ece0538339b2a5365edea7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 26 Jul 2023 11:04:02 +1000 Subject: [PATCH] fix(ui): fix refiner missing from model manager Rolled back the earlier split of the refiner model query. Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types. They are provided as constants for simplicity: - ALL_BASE_MODELS - NON_REFINER_BASE_MODELS - REFINER_BASE_MODELS Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired. --- .../listeners/modelsLoaded.ts | 8 ++- .../fields/ModelInputFieldComponent.tsx | 5 +- .../RefinerModelInputFieldComponent.tsx | 6 ++- .../MainModel/ParamMainModelSelect.tsx | 5 +- .../ParamSDXLRefinerAestheticScore.tsx | 2 +- .../SDXLRefiner/ParamSDXLRefinerCFGScale.tsx | 2 +- .../ParamSDXLRefinerModelSelect.tsx | 6 ++- .../SDXLRefiner/ParamSDXLRefinerScheduler.tsx | 2 +- .../SDXLRefiner/ParamSDXLRefinerStart.tsx | 2 +- .../SDXLRefiner/ParamSDXLRefinerSteps.tsx | 2 +- .../SDXLRefiner/ParamUseSDXLRefiner.tsx | 2 +- .../sdxl/hooks/useIsRefinerAvailable.ts | 11 ---- .../AddModelsPanel/FoundModelsList.tsx | 3 +- .../subpanels/MergeModelsPanel.tsx | 5 +- .../subpanels/ModelManagerPanel.tsx | 3 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 5 +- .../web/src/services/api/constants.ts | 16 ++++++ .../web/src/services/api/endpoints/models.ts | 50 +++---------------- .../api/hooks/useIsRefinerAvailable.ts | 12 +++++ 19 files changed, 72 insertions(+), 75 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/sdxl/hooks/useIsRefinerAvailable.ts create mode 100644 invokeai/frontend/web/src/services/api/constants.ts create mode 100644 invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index a2c622ee63..1e0b3dbc61 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -19,7 +19,9 @@ import { startAppListening } from '..'; export const addModelsLoadedListener = () => { startAppListening({ - matcher: modelsApi.endpoints.getMainModels.matchFulfilled, + predicate: (state, action) => + modelsApi.endpoints.getMainModels.matchFulfilled(action) && + !action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { // models loaded, we need to ensure the selected model is available and if not, select the first one const log = logger('models'); @@ -64,7 +66,9 @@ export const addModelsLoadedListener = () => { }, }); startAppListening({ - matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled, + predicate: (state, action) => + modelsApi.endpoints.getMainModels.matchFulfilled(action) && + action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { // models loaded, we need to ensure the selected model is available and if not, select the first one const log = logger('models'); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 625ce0d5ca..273ba3be51 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus'; @@ -27,7 +28,9 @@ const ModelInputFieldComponent = ( const { t } = useTranslation(); const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; - const { data: mainModels, isLoading } = useGetMainModelsQuery(); + const { data: mainModels, isLoading } = useGetMainModelsQuery( + NON_REFINER_BASE_MODELS + ); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx index 1c791c8704..2a7531b59d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx @@ -13,7 +13,8 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models'; +import { REFINER_BASE_MODELS } from 'services/api/constants'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const RefinerModelInputFieldComponent = ( @@ -27,7 +28,8 @@ const RefinerModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery(); + const { data: refinerModels, isLoading } = + useGetMainModelsQuery(REFINER_BASE_MODELS); const data = useMemo(() => { if (!refinerModels) { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx index 372778b2fa..4f799dc330 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx @@ -14,6 +14,7 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import { forEach } from 'lodash-es'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; @@ -29,8 +30,10 @@ const ParamMainModelSelect = () => { const { model } = useAppSelector(selector); - const { data: mainModels, isLoading } = useGetMainModelsQuery(); const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; + const { data: mainModels, isLoading } = useGetMainModelsQuery( + NON_REFINER_BASE_MODELS + ); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx index 8a3f655261..9c9c4b2f89 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx @@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; -import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable'; import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; const selector = createSelector( [stateSelector], diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx index 4d0a8cef7f..dd678ac0f7 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx @@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; -import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable'; import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; const selector = createSelector( [stateSelector], diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index 63f72ecacf..4984e74964 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -11,7 +11,8 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; -import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models'; +import { REFINER_BASE_MODELS } from 'services/api/constants'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; const selector = createSelector( stateSelector, @@ -24,7 +25,8 @@ const ParamSDXLRefinerModelSelect = () => { const { model } = useAppSelector(selector); - const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery(); + const { data: refinerModels, isLoading } = + useGetMainModelsQuery(REFINER_BASE_MODELS); const data = useMemo(() => { if (!refinerModels) { diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx index b92dcf03b2..e14eb0b5f8 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx @@ -7,11 +7,11 @@ import { SCHEDULER_LABEL_MAP, SchedulerParam, } from 'features/parameters/types/parameterSchemas'; -import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable'; import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; const selector = createSelector( stateSelector, diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx index 719dfda949..987fb0aed7 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx @@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; -import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable'; import { setRefinerStart } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; const selector = createSelector( [stateSelector], diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx index 2def8aabe7..456cbb5d3a 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx @@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; -import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable'; import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; const selector = createSelector( [stateSelector], diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx index b35221f5cf..1649f95e9a 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx @@ -1,9 +1,9 @@ import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAISwitch from 'common/components/IAISwitch'; -import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable'; import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice'; import { ChangeEvent } from 'react'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; export default function ParamUseSDXLRefiner() { const shouldUseSDXLRefiner = useAppSelector( diff --git a/invokeai/frontend/web/src/features/sdxl/hooks/useIsRefinerAvailable.ts b/invokeai/frontend/web/src/features/sdxl/hooks/useIsRefinerAvailable.ts deleted file mode 100644 index 3a8cd269a4..0000000000 --- a/invokeai/frontend/web/src/features/sdxl/hooks/useIsRefinerAvailable.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models'; - -export const useIsRefinerAvailable = () => { - const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, { - selectFromResult: ({ data }) => ({ - isRefinerAvailable: data ? data.ids.length > 0 : false, - }), - }); - - return isRefinerAvailable; -}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx index 10f297ce07..a44a747438 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx @@ -16,6 +16,7 @@ import { useImportMainModelsMutation, } from 'services/api/endpoints/models'; import { setAdvancedAddScanModel } from '../../store/modelManagerSlice'; +import { ALL_BASE_MODELS } from 'services/api/constants'; export default function FoundModelsList() { const searchFolder = useAppSelector( @@ -24,7 +25,7 @@ export default function FoundModelsList() { const [nameFilter, setNameFilter] = useState(''); // Get paths of models that are already installed - const { data: installedModels } = useGetMainModelsQuery(); + const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS); // Get all model paths from a given directory const { foundModels, alreadyInstalled, filteredModels } = diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index 19ca10e240..4ad8fbaba6 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -1,5 +1,4 @@ import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; -import { makeToast } from 'features/system/util/makeToast'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; @@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISlider from 'common/components/IAISlider'; import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; import { pickBy } from 'lodash-es'; import { useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { ALL_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery, useMergeMainModelsMutation, @@ -32,7 +33,7 @@ export default function MergeModelsPanel() { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const { data } = useGetMainModelsQuery(); + const { data } = useGetMainModelsQuery(ALL_BASE_MODELS); const [mergeModels, { isLoading }] = useMergeMainModelsMutation(); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index f49294cfb0..87eb918564 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -8,10 +8,11 @@ import { import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; +import { ALL_BASE_MODELS } from 'services/api/constants'; export default function ModelManagerPanel() { const [selectedModelId, setSelectedModelId] = useState(); - const { model } = useGetMainModelsQuery(undefined, { + const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ model: selectedModelId ? data?.entities[selectedModelId] : undefined, }), diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 722bd83b6e..f3d0eae495 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -11,6 +11,7 @@ import { useGetMainModelsQuery, } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; +import { ALL_BASE_MODELS } from 'services/api/constants'; type ModelListProps = { selectedModelId: string | undefined; @@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => { const [modelFormatFilter, setModelFormatFilter] = useState('images'); - const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, { + const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), }), }); - const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, { + const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), }), diff --git a/invokeai/frontend/web/src/services/api/constants.ts b/invokeai/frontend/web/src/services/api/constants.ts new file mode 100644 index 0000000000..8bf35d0198 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/constants.ts @@ -0,0 +1,16 @@ +import { BaseModelType } from './types'; + +export const ALL_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sdxl', + 'sdxl-refiner', +]; + +export const NON_REFINER_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sdxl', +]; + +export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner']; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index d4490eb7b5..3d0013a62c 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -107,9 +107,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query']; const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); -const sdxlRefinerModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), -}); const loraModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); @@ -147,11 +144,14 @@ const createModelEntities = ( export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - getMainModels: build.query, void>({ - query: () => { + getMainModels: build.query< + EntityState, + BaseModelType[] + >({ + query: (base_models) => { const params = { model_type: 'main', - base_models: ['sd-1', 'sd-2', 'sdxl'], + base_models, }; const query = queryString.stringify(params, { arrayFormat: 'none' }); @@ -187,43 +187,6 @@ export const modelsApi = api.injectEndpoints({ ); }, }), - getSDXLRefinerModels: build.query, void>( - { - query: () => ({ - url: 'models/', - params: { model_type: 'main', base_models: ['sdxl-refiner'] }, - }), - providesTags: (result, error, arg) => { - const tags: ApiFullTagDescription[] = [ - { type: 'SDXLRefinerModel', id: LIST_TAG }, - ]; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'SDXLRefinerModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: ( - response: { models: MainModelConfig[] }, - meta, - arg - ) => { - const entities = createModelEntities( - response.models - ); - return sdxlRefinerModelsAdapter.setAll( - sdxlRefinerModelsAdapter.getInitialState(), - entities - ); - }, - } - ), updateMainModels: build.mutation< UpdateMainModelResponse, UpdateMainModelArg @@ -494,7 +457,6 @@ export const modelsApi = api.injectEndpoints({ export const { useGetMainModelsQuery, - useGetSDXLRefinerModelsQuery, useGetControlNetModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, diff --git a/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts b/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts new file mode 100644 index 0000000000..4cb4891be4 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts @@ -0,0 +1,12 @@ +import { REFINER_BASE_MODELS } from 'services/api/constants'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; + +export const useIsRefinerAvailable = () => { + const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, { + selectFromResult: ({ data }) => ({ + isRefinerAvailable: data ? data.ids.length > 0 : false, + }), + }); + + return isRefinerAvailable; +};