mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: Model Manager Tab Issues
This commit is contained in:
committed by
psychedelicious
parent
41d6a38690
commit
ce687b28ef
@ -1,4 +1,4 @@
|
|||||||
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
|
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
|
||||||
import { EntityState } from '@reduxjs/toolkit';
|
import { EntityState } from '@reduxjs/toolkit';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
|
|||||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
import {
|
import {
|
||||||
|
LoRAModelConfigEntity,
|
||||||
MainModelConfigEntity,
|
MainModelConfigEntity,
|
||||||
OnnxModelConfigEntity,
|
OnnxModelConfigEntity,
|
||||||
|
useGetLoRAModelsQuery,
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetOnnxModelsQuery,
|
useGetOnnxModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
LoRAModelConfigEntity,
|
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
|
|
||||||
type ModelListProps = {
|
type ModelListProps = {
|
||||||
selectedModelId: string | undefined;
|
selectedModelId: string | undefined;
|
||||||
setSelectedModelId: (name: string | undefined) => void;
|
setSelectedModelId: (name: string | undefined) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||||
|
|
||||||
type ModelType = 'main' | 'lora' | 'onnx';
|
type ModelType = 'main' | 'lora' | 'onnx';
|
||||||
|
|
||||||
@ -33,35 +33,43 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [nameFilter, setNameFilter] = useState<string>('');
|
const [nameFilter, setNameFilter] = useState<string>('');
|
||||||
const [modelFormatFilter, setModelFormatFilter] =
|
const [modelFormatFilter, setModelFormatFilter] =
|
||||||
useState<CombinedModelFormat>('images');
|
useState<CombinedModelFormat>('all');
|
||||||
|
|
||||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
const { filteredDiffusersModels, isDiffusersModelLoading } =
|
||||||
selectFromResult: ({ data }) => ({
|
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
filteredDiffusersModels: modelsFilter(
|
filteredDiffusersModels: modelsFilter(
|
||||||
data,
|
data,
|
||||||
'main',
|
'main',
|
||||||
'diffusers',
|
'diffusers',
|
||||||
nameFilter
|
nameFilter
|
||||||
),
|
),
|
||||||
|
isDiffusersModelLoading: isLoading,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
const { filteredCheckpointModels, isCheckpointModelLoading } =
|
||||||
selectFromResult: ({ data }) => ({
|
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
filteredCheckpointModels: modelsFilter(
|
filteredCheckpointModels: modelsFilter(
|
||||||
data,
|
data,
|
||||||
'main',
|
'main',
|
||||||
'checkpoint',
|
'checkpoint',
|
||||||
nameFilter
|
nameFilter
|
||||||
),
|
),
|
||||||
|
isCheckpointModelLoading: isLoading,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
|
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
|
||||||
selectFromResult: ({ data }) => ({
|
undefined,
|
||||||
|
{
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||||
|
isLoadingLoraModels: isLoading,
|
||||||
}),
|
}),
|
||||||
});
|
}
|
||||||
|
);
|
||||||
|
|
||||||
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
@ -79,13 +87,47 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
setNameFilter(e.target.value);
|
setNameFilter(e.target.value);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const renderModelList = (
|
||||||
|
filterArray: Partial<CombinedModelFormat>[],
|
||||||
|
isLoading: boolean,
|
||||||
|
loadingMessage: string,
|
||||||
|
title: string,
|
||||||
|
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]
|
||||||
|
) => {
|
||||||
|
if (!filterArray.includes(modelFormatFilter)) return;
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return <FetchingModelsLoader loadingMessage={loadingMessage} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modelList.length === 0) return;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<StyledModelContainer>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
{title}
|
||||||
|
</Text>
|
||||||
|
{modelList.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||||
<ButtonGroup isAttached>
|
<ButtonGroup isAttached>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={() => setModelFormatFilter('images')}
|
onClick={() => setModelFormatFilter('all')}
|
||||||
isChecked={modelFormatFilter === 'images'}
|
isChecked={modelFormatFilter === 'all'}
|
||||||
size="sm"
|
size="sm"
|
||||||
>
|
>
|
||||||
{t('modelManager.allModels')}
|
{t('modelManager.allModels')}
|
||||||
@ -287,3 +329,26 @@ const StyledModelContainer = (props: PropsWithChildren) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const FetchingModelsLoader = ({
|
||||||
|
loadingMessage,
|
||||||
|
}: {
|
||||||
|
loadingMessage?: string;
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
<StyledModelContainer>
|
||||||
|
<Flex
|
||||||
|
justifyContent="center"
|
||||||
|
alignItems="center"
|
||||||
|
flexDirection="column"
|
||||||
|
p={4}
|
||||||
|
gap={8}
|
||||||
|
>
|
||||||
|
<Spinner />
|
||||||
|
<Text variant="subtext">
|
||||||
|
{loadingMessage ? loadingMessage : 'Fetching...'}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
Reference in New Issue
Block a user