mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): display refiner models in mm
This commit is contained in:
parent
5d4a571778
commit
3006285d13
@ -684,6 +684,7 @@
|
|||||||
"noModelsInstalled": "No Models Installed",
|
"noModelsInstalled": "No Models Installed",
|
||||||
"noModelsInstalledDesc1": "Install models with the",
|
"noModelsInstalledDesc1": "Install models with the",
|
||||||
"noModelSelected": "No Model Selected",
|
"noModelSelected": "No Model Selected",
|
||||||
|
"noMatchingModels": "No matching Models",
|
||||||
"none": "none",
|
"none": "none",
|
||||||
"path": "Path",
|
"path": "Path",
|
||||||
"pathToConfig": "Path To Config",
|
"pathToConfig": "Path To Config",
|
||||||
|
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
|||||||
import type { PersistConfig } from 'app/store/store';
|
import type { PersistConfig } from 'app/store/store';
|
||||||
import type { ModelType } from 'services/api/types';
|
import type { ModelType } from 'services/api/types';
|
||||||
|
|
||||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||||
|
|
||||||
type ModelManagerState = {
|
type ModelManagerState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
|
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -9,10 +10,11 @@ import {
|
|||||||
useIPAdapterModels,
|
useIPAdapterModels,
|
||||||
useLoRAModels,
|
useLoRAModels,
|
||||||
useMainModels,
|
useMainModels,
|
||||||
|
useRefinerModels,
|
||||||
useT2IAdapterModels,
|
useT2IAdapterModels,
|
||||||
useVAEModels,
|
useVAEModels,
|
||||||
} from 'services/api/hooks/modelsByType';
|
} from 'services/api/hooks/modelsByType';
|
||||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { FetchingModelsLoader } from './FetchingModelsLoader';
|
import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||||
import { ModelListWrapper } from './ModelListWrapper';
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
@ -27,6 +29,12 @@ const ModelList = () => {
|
|||||||
[mainModels, searchTerm, filteredModelType]
|
[mainModels, searchTerm, filteredModelType]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
|
||||||
|
const filteredRefinerModels = useMemo(
|
||||||
|
() => modelsFilter(refinerModels, searchTerm, filteredModelType),
|
||||||
|
[refinerModels, searchTerm, filteredModelType]
|
||||||
|
);
|
||||||
|
|
||||||
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||||
const filteredLoRAModels = useMemo(
|
const filteredLoRAModels = useMemo(
|
||||||
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||||
@ -63,6 +71,28 @@ const ModelList = () => {
|
|||||||
[vaeModels, searchTerm, filteredModelType]
|
[vaeModels, searchTerm, filteredModelType]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const totalFilteredModels = useMemo(() => {
|
||||||
|
return (
|
||||||
|
filteredMainModels.length +
|
||||||
|
filteredRefinerModels.length +
|
||||||
|
filteredLoRAModels.length +
|
||||||
|
filteredEmbeddingModels.length +
|
||||||
|
filteredControlNetModels.length +
|
||||||
|
filteredT2IAdapterModels.length +
|
||||||
|
filteredIPAdapterModels.length +
|
||||||
|
filteredVAEModels.length
|
||||||
|
);
|
||||||
|
}, [
|
||||||
|
filteredControlNetModels.length,
|
||||||
|
filteredEmbeddingModels.length,
|
||||||
|
filteredIPAdapterModels.length,
|
||||||
|
filteredLoRAModels.length,
|
||||||
|
filteredMainModels.length,
|
||||||
|
filteredRefinerModels.length,
|
||||||
|
filteredT2IAdapterModels.length,
|
||||||
|
filteredVAEModels.length,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||||
@ -71,6 +101,11 @@ const ModelList = () => {
|
|||||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||||
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
|
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
|
||||||
)}
|
)}
|
||||||
|
{/* Refiner Model List */}
|
||||||
|
{isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
|
||||||
|
{!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
|
||||||
|
<ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
|
||||||
|
)}
|
||||||
{/* LoRAs List */}
|
{/* LoRAs List */}
|
||||||
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||||
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||||
@ -108,6 +143,11 @@ const ModelList = () => {
|
|||||||
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||||
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||||
)}
|
)}
|
||||||
|
{totalFilteredModels === 0 && (
|
||||||
|
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||||
|
<Text>{t('modelManager.noMatchingModels')}</Text>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
);
|
);
|
||||||
@ -118,12 +158,24 @@ export default memo(ModelList);
|
|||||||
const modelsFilter = <T extends AnyModelConfig>(
|
const modelsFilter = <T extends AnyModelConfig>(
|
||||||
data: T[],
|
data: T[],
|
||||||
nameFilter: string,
|
nameFilter: string,
|
||||||
filteredModelType: ModelType | null
|
filteredModelType: FilterableModelType | null
|
||||||
): T[] => {
|
): T[] => {
|
||||||
return data.filter((model) => {
|
return data.filter((model) => {
|
||||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
const matchesType = getMatchesType(model, filteredModelType);
|
||||||
|
|
||||||
return matchesFilter && matchesType;
|
return matchesFilter && matchesType;
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
|
||||||
|
if (filteredModelType === 'refiner') {
|
||||||
|
return modelConfig.base === 'sdxl-refiner';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredModelType ? modelConfig.type === filteredModelType : true;
|
||||||
|
};
|
||||||
|
@ -13,6 +13,7 @@ export const ModelTypeFilter = () => {
|
|||||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||||
() => ({
|
() => ({
|
||||||
main: t('modelManager.main'),
|
main: t('modelManager.main'),
|
||||||
|
refiner: t('sdxl.refiner'),
|
||||||
lora: 'LoRA',
|
lora: 'LoRA',
|
||||||
embedding: t('modelManager.textualInversions'),
|
embedding: t('modelManager.textualInversions'),
|
||||||
controlnet: 'ControlNet',
|
controlnet: 'ControlNet',
|
||||||
|
Loading…
Reference in New Issue
Block a user