fix(ui): display refiner models in mm

This commit is contained in:
psychedelicious 2024-04-05 09:21:13 +11:00
parent 5d4a571778
commit 3006285d13
4 changed files with 59 additions and 5 deletions

View File

@ -684,6 +684,7 @@
"noModelsInstalled": "No Models Installed", "noModelsInstalled": "No Models Installed",
"noModelsInstalledDesc1": "Install models with the", "noModelsInstalledDesc1": "Install models with the",
"noModelSelected": "No Model Selected", "noModelSelected": "No Model Selected",
"noMatchingModels": "No matching Models",
"none": "none", "none": "none",
"path": "Path", "path": "Path",
"pathToConfig": "Path To Config", "pathToConfig": "Path To Config",

View File

@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store'; import type { PersistConfig } from 'app/store/store';
import type { ModelType } from 'services/api/types'; import type { ModelType } from 'services/api/types';
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>; export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
type ModelManagerState = { type ModelManagerState = {
_version: 1; _version: 1;

View File

@ -1,6 +1,7 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import {
@ -9,10 +10,11 @@ import {
useIPAdapterModels, useIPAdapterModels,
useLoRAModels, useLoRAModels,
useMainModels, useMainModels,
useRefinerModels,
useT2IAdapterModels, useT2IAdapterModels,
useVAEModels, useVAEModels,
} from 'services/api/hooks/modelsByType'; } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ModelType } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
import { FetchingModelsLoader } from './FetchingModelsLoader'; import { FetchingModelsLoader } from './FetchingModelsLoader';
import { ModelListWrapper } from './ModelListWrapper'; import { ModelListWrapper } from './ModelListWrapper';
@ -27,6 +29,12 @@ const ModelList = () => {
[mainModels, searchTerm, filteredModelType] [mainModels, searchTerm, filteredModelType]
); );
const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
const filteredRefinerModels = useMemo(
() => modelsFilter(refinerModels, searchTerm, filteredModelType),
[refinerModels, searchTerm, filteredModelType]
);
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels(); const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
const filteredLoRAModels = useMemo( const filteredLoRAModels = useMemo(
() => modelsFilter(loraModels, searchTerm, filteredModelType), () => modelsFilter(loraModels, searchTerm, filteredModelType),
@ -63,6 +71,28 @@ const ModelList = () => {
[vaeModels, searchTerm, filteredModelType] [vaeModels, searchTerm, filteredModelType]
); );
const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
filteredRefinerModels.length +
filteredLoRAModels.length +
filteredEmbeddingModels.length +
filteredControlNetModels.length +
filteredT2IAdapterModels.length +
filteredIPAdapterModels.length +
filteredVAEModels.length
);
}, [
filteredControlNetModels.length,
filteredEmbeddingModels.length,
filteredIPAdapterModels.length,
filteredLoRAModels.length,
filteredMainModels.length,
filteredRefinerModels.length,
filteredT2IAdapterModels.length,
filteredVAEModels.length,
]);
return ( return (
<ScrollableContent> <ScrollableContent>
<Flex flexDirection="column" w="full" h="full" gap={4}> <Flex flexDirection="column" w="full" h="full" gap={4}>
@ -71,6 +101,11 @@ const ModelList = () => {
{!isLoadingMainModels && filteredMainModels.length > 0 && ( {!isLoadingMainModels && filteredMainModels.length > 0 && (
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" /> <ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
)} )}
{/* Refiner Model List */}
{isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
{!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
<ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
)}
{/* LoRAs List */} {/* LoRAs List */}
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />} {isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && ( {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
@ -108,6 +143,11 @@ const ModelList = () => {
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" /> <ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)} )}
{totalFilteredModels === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>{t('modelManager.noMatchingModels')}</Text>
</Flex>
)}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
); );
@ -118,12 +158,24 @@ export default memo(ModelList);
const modelsFilter = <T extends AnyModelConfig>( const modelsFilter = <T extends AnyModelConfig>(
data: T[], data: T[],
nameFilter: string, nameFilter: string,
filteredModelType: ModelType | null filteredModelType: FilterableModelType | null
): T[] => { ): T[] => {
return data.filter((model) => { return data.filter((model) => {
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesType = filteredModelType ? model.type === filteredModelType : true; const matchesType = getMatchesType(model, filteredModelType);
return matchesFilter && matchesType; return matchesFilter && matchesType;
}); });
}; };
const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
if (filteredModelType === 'refiner') {
return modelConfig.base === 'sdxl-refiner';
}
if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
return false;
}
return filteredModelType ? modelConfig.type === filteredModelType : true;
};

View File

@ -13,6 +13,7 @@ export const ModelTypeFilter = () => {
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo( const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
() => ({ () => ({
main: t('modelManager.main'), main: t('modelManager.main'),
refiner: t('sdxl.refiner'),
lora: 'LoRA', lora: 'LoRA',
embedding: t('modelManager.textualInversions'), embedding: t('modelManager.textualInversions'),
controlnet: 'ControlNet', controlnet: 'ControlNet',