diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index 605aa8b162..5fd23f4c1d 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -3,13 +3,8 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { validateSeedWeights } from 'common/util/seedWeightPairs'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { systemSelector } from 'features/system/store/systemSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { - modelsApi, - useGetMainModelsQuery, -} from '../../services/api/endpoints/models'; +import { modelsApi } from '../../services/api/endpoints/models'; const readinessSelector = createSelector( [stateSelector, activeTabNameSelector], @@ -38,7 +33,10 @@ const readinessSelector = createSelector( } const { isSuccess: mainModelsSuccessfullyLoaded } = - modelsApi.endpoints.getMainModels.select()(state); + modelsApi.endpoints.getMainModels.select({ + model_type: 'main', + base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + })(state); if (!mainModelsSuccessfullyLoaded) { isReady = false; reasonsWhyNotReady.push('Models are not loaded'); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index ee739e1002..3f03c76d50 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -22,7 +22,10 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: mainModels } = useGetMainModelsQuery(); + const { data: mainModels } = useGetMainModelsQuery({ + model_type: 'main', + base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + }); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 40a6a1203b..3c3349469d 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -25,7 +25,10 @@ const ModelSelect = () => { (state: RootState) => state.generation.model ); - const { data: mainModels, isLoading } = useGetMainModelsQuery(); + const { data: mainModels, isLoading } = useGetMainModelsQuery({ + model_type: 'main', + base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + }); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index b71b5636b4..989821dbce 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -16,7 +16,10 @@ export default function MergeModelsPanel() { const dispatch = useAppDispatch(); - const { data } = useGetMainModelsQuery(); + const { data } = useGetMainModelsQuery({ + model_type: 'main', + base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + }); const diffusersModels = pickBy( data?.entities, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index b22a303571..c2d6740941 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -8,7 +8,10 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; export default function ModelManagerPanel() { - const { data: mainModels } = useGetMainModelsQuery(); + const { data: mainModels } = useGetMainModelsQuery({ + model_type: 'main', + base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + }); const openModel = useAppSelector( (state: RootState) => state.system.openModel diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index eb05e70357..6f0a6ab659 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -36,7 +36,10 @@ function ModelFilterButton({ } const ModelList = () => { - const { data: mainModels } = useGetMainModelsQuery(); + const { data: mainModels } = useGetMainModelsQuery({ + model_type: 'main', + base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + }); const [renderModelList, setRenderModelList] = React.useState(false); diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index f358c97e6b..233424a2d8 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -2,9 +2,11 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; import { cloneDeep } from 'lodash-es'; import { AnyModelConfig, + BaseModelType, ControlNetModelConfig, LoRAModelConfig, MainModelConfig, + ModelType, TextualInversionModelConfig, VaeModelConfig, } from 'services/api/types'; @@ -68,21 +70,19 @@ const createModelEntities = ( return entityArray; }; +type MainModelQueryArg = { + model_type: ModelType; + base_models: BaseModelType[]; +}; + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - getMainModels: build.query, void>({ - query: () => { - const baseModels = { - base_models: ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], - }; - const baseModelsQueryStr = queryString.stringify(baseModels, {}); - return { - url: `models/?${baseModelsQueryStr}`, - params: { - model_type: 'main', - }, - }; - }, + getMainModels: build.query< + EntityState, + MainModelQueryArg + >({ + query: (arg: MainModelQueryArg) => + `models/?${queryString.stringify(arg)}`, providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ { id: 'MainModel', type: LIST_TAG },