diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 3f100d9072..44626c04c3 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,4 +1,4 @@ -import { ButtonGroup, Flex, Text } from '@chakra-ui/react'; +import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react'; import { EntityState } from '@reduxjs/toolkit'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; @@ -6,23 +6,23 @@ import { forEach } from 'lodash-es'; import type { ChangeEvent, PropsWithChildren } from 'react'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { ALL_BASE_MODELS } from 'services/api/constants'; import { + LoRAModelConfigEntity, MainModelConfigEntity, OnnxModelConfigEntity, + useGetLoRAModelsQuery, useGetMainModelsQuery, useGetOnnxModelsQuery, - useGetLoRAModelsQuery, - LoRAModelConfigEntity, } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; -import { ALL_BASE_MODELS } from 'services/api/constants'; type ModelListProps = { selectedModelId: string | undefined; setSelectedModelId: (name: string | undefined) => void; }; -type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; +type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelType = 'main' | 'lora' | 'onnx'; @@ -33,47 +33,63 @@ const ModelList = (props: ModelListProps) => { const { t } = useTranslation(); const [nameFilter, setNameFilter] = useState(''); const [modelFormatFilter, setModelFormatFilter] = - useState('images'); + useState('all'); - const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - filteredDiffusersModels: modelsFilter( - data, - 'main', - 'diffusers', - nameFilter - ), - }), - }); + const { filteredDiffusersModels, isLoadingDiffusersModels } = + useGetMainModelsQuery(ALL_BASE_MODELS, { + selectFromResult: ({ data, isLoading }) => ({ + filteredDiffusersModels: modelsFilter( + data, + 'main', + 'diffusers', + nameFilter + ), + isLoadingDiffusersModels: isLoading, + }), + }); - const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - filteredCheckpointModels: modelsFilter( - data, - 'main', - 'checkpoint', - nameFilter - ), - }), - }); + const { filteredCheckpointModels, isLoadingCheckpointModels } = + useGetMainModelsQuery(ALL_BASE_MODELS, { + selectFromResult: ({ data, isLoading }) => ({ + filteredCheckpointModels: modelsFilter( + data, + 'main', + 'checkpoint', + nameFilter + ), + isLoadingCheckpointModels: isLoading, + }), + }); - const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, { - selectFromResult: ({ data }) => ({ - filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), - }), - }); + const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery( + undefined, + { + selectFromResult: ({ data, isLoading }) => ({ + filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), + isLoadingLoraModels: isLoading, + }), + } + ); - const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter), - }), - }); + const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery( + ALL_BASE_MODELS, + { + selectFromResult: ({ data, isLoading }) => ({ + filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter), + isLoadingOnnxModels: isLoading, + }), + } + ); - const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { - selectFromResult: ({ data }) => ({ - filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter), - }), - }); + const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery( + ALL_BASE_MODELS, + { + selectFromResult: ({ data, isLoading }) => ({ + filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter), + isLoadingOliveModels: isLoading, + }), + } + ); const handleSearchFilter = useCallback((e: ChangeEvent) => { setNameFilter(e.target.value); @@ -84,8 +100,8 @@ const ModelList = (props: ModelListProps) => { setModelFormatFilter('images')} - isChecked={modelFormatFilter === 'images'} + onClick={() => setModelFormatFilter('all')} + isChecked={modelFormatFilter === 'all'} size="sm" > {t('modelManager.allModels')} @@ -139,95 +155,76 @@ const ModelList = (props: ModelListProps) => { maxHeight={window.innerHeight - 280} overflow="scroll" > - {['images', 'diffusers'].includes(modelFormatFilter) && + {/* Diffusers List */} + {isLoadingDiffusersModels && ( + + )} + {['all', 'diffusers'].includes(modelFormatFilter) && + !isLoadingDiffusersModels && filteredDiffusersModels.length > 0 && ( - - - - Diffusers - - {filteredDiffusersModels.map((model) => ( - - ))} - - + )} - {['images', 'checkpoint'].includes(modelFormatFilter) && + {/* Checkpoints List */} + {isLoadingCheckpointModels && ( + + )} + {['all', 'checkpoint'].includes(modelFormatFilter) && + !isLoadingCheckpointModels && filteredCheckpointModels.length > 0 && ( - - - - Checkpoints - - {filteredCheckpointModels.map((model) => ( - - ))} - - + )} - {['images', 'olive'].includes(modelFormatFilter) && - filteredOliveModels.length > 0 && ( - - - - Olives - - {filteredOliveModels.map((model) => ( - - ))} - - - )} - {['images', 'onnx'].includes(modelFormatFilter) && - filteredOnnxModels.length > 0 && ( - - - - Onnx - - {filteredOnnxModels.map((model) => ( - - ))} - - - )} - {['images', 'lora'].includes(modelFormatFilter) && + + {/* LoRAs List */} + {isLoadingLoraModels && ( + + )} + {['all', 'lora'].includes(modelFormatFilter) && + !isLoadingLoraModels && filteredLoraModels.length > 0 && ( - - - - LoRAs - - {filteredLoraModels.map((model) => ( - - ))} - - + + )} + {/* Olive List */} + {isLoadingOliveModels && ( + + )} + {['all', 'olive'].includes(modelFormatFilter) && + !isLoadingOliveModels && + filteredOliveModels.length > 0 && ( + + )} + {/* Onnx List */} + {isLoadingOnnxModels && ( + + )} + {['all', 'onnx'].includes(modelFormatFilter) && + !isLoadingOnnxModels && + filteredOnnxModels.length > 0 && ( + )} @@ -287,3 +284,52 @@ const StyledModelContainer = (props: PropsWithChildren) => { ); }; + +type ModelListWrapperProps = { + title: string; + modelList: + | MainModelConfigEntity[] + | LoRAModelConfigEntity[] + | OnnxModelConfigEntity[]; + selected: ModelListProps; +}; + +function ModelListWrapper(props: ModelListWrapperProps) { + const { title, modelList, selected } = props; + return ( + + + + {title} + + {modelList.map((model) => ( + + ))} + + + ); +} + +function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) { + return ( + + + + + {loadingMessage ? loadingMessage : 'Fetching...'} + + + + ); +} diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index e994d3fd3a..a7b1323f36 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({ }, providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ - { id: 'OnnxModel', type: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ]; if (result) { @@ -266,6 +266,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), importMainModels: build.mutation< @@ -282,6 +283,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), addMainModels: build.mutation({ @@ -295,6 +297,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), deleteMainModels: build.mutation< @@ -310,6 +313,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), convertMainModels: build.mutation< @@ -326,6 +330,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), mergeMainModels: build.mutation({ @@ -339,6 +344,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), syncModels: build.mutation({ @@ -351,6 +357,7 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [ { type: 'MainModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'OnnxModel', id: LIST_TAG }, ], }), getLoRAModels: build.query, void>({