From 3006285d1358bbb97325b7d76ffdd332484263c6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 5 Apr 2024 09:21:13 +1100 Subject: [PATCH] fix(ui): display refiner models in mm --- invokeai/frontend/web/public/locales/en.json | 1 + .../store/modelManagerV2Slice.ts | 2 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 60 +++++++++++++++++-- .../ModelManagerPanel/ModelTypeFilter.tsx | 1 + 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 96fb8e5748..0cf98289e4 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -684,6 +684,7 @@ "noModelsInstalled": "No Models Installed", "noModelsInstalledDesc1": "Install models with the", "noModelSelected": "No Model Selected", + "noMatchingModels": "No matching Models", "none": "none", "path": "Path", "pathToConfig": "Path To Config", diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 6bdd829bb1..c637d30fd8 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig } from 'app/store/store'; import type { ModelType } from 'services/api/types'; -export type FilterableModelType = Exclude; +export type FilterableModelType = Exclude | 'refiner'; type ModelManagerState = { _version: 1; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 033841ec79..67e65dbfb6 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,6 +1,7 @@ -import { Flex } from '@invoke-ai/ui-library'; +import { Flex, Text } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { @@ -9,10 +10,11 @@ import { useIPAdapterModels, useLoRAModels, useMainModels, + useRefinerModels, useT2IAdapterModels, useVAEModels, } from 'services/api/hooks/modelsByType'; -import type { AnyModelConfig, ModelType } from 'services/api/types'; +import type { AnyModelConfig } from 'services/api/types'; import { FetchingModelsLoader } from './FetchingModelsLoader'; import { ModelListWrapper } from './ModelListWrapper'; @@ -27,6 +29,12 @@ const ModelList = () => { [mainModels, searchTerm, filteredModelType] ); + const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels(); + const filteredRefinerModels = useMemo( + () => modelsFilter(refinerModels, searchTerm, filteredModelType), + [refinerModels, searchTerm, filteredModelType] + ); + const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels(); const filteredLoRAModels = useMemo( () => modelsFilter(loraModels, searchTerm, filteredModelType), @@ -63,6 +71,28 @@ const ModelList = () => { [vaeModels, searchTerm, filteredModelType] ); + const totalFilteredModels = useMemo(() => { + return ( + filteredMainModels.length + + filteredRefinerModels.length + + filteredLoRAModels.length + + filteredEmbeddingModels.length + + filteredControlNetModels.length + + filteredT2IAdapterModels.length + + filteredIPAdapterModels.length + + filteredVAEModels.length + ); + }, [ + filteredControlNetModels.length, + filteredEmbeddingModels.length, + filteredIPAdapterModels.length, + filteredLoRAModels.length, + filteredMainModels.length, + filteredRefinerModels.length, + filteredT2IAdapterModels.length, + filteredVAEModels.length, + ]); + return ( @@ -71,6 +101,11 @@ const ModelList = () => { {!isLoadingMainModels && filteredMainModels.length > 0 && ( )} + {/* Refiner Model List */} + {isLoadingRefinerModels && } + {!isLoadingRefinerModels && filteredRefinerModels.length > 0 && ( + + )} {/* LoRAs List */} {isLoadingLoRAModels && } {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && ( @@ -108,6 +143,11 @@ const ModelList = () => { {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( )} + {totalFilteredModels === 0 && ( + + {t('modelManager.noMatchingModels')} + + )} ); @@ -118,12 +158,24 @@ export default memo(ModelList); const modelsFilter = ( data: T[], nameFilter: string, - filteredModelType: ModelType | null + filteredModelType: FilterableModelType | null ): T[] => { return data.filter((model) => { const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); - const matchesType = filteredModelType ? model.type === filteredModelType : true; + const matchesType = getMatchesType(model, filteredModelType); return matchesFilter && matchesType; }); }; + +const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => { + if (filteredModelType === 'refiner') { + return modelConfig.base === 'sdxl-refiner'; + } + + if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') { + return false; + } + + return filteredModelType ? modelConfig.type === filteredModelType : true; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx index 0b8ad3f600..76802b36e7 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -13,6 +13,7 @@ export const ModelTypeFilter = () => { const MODEL_TYPE_LABELS: Record = useMemo( () => ({ main: t('modelManager.main'), + refiner: t('sdxl.refiner'), lora: 'LoRA', embedding: t('modelManager.textualInversions'), controlnet: 'ControlNet',