fix: Model Manager Tab Issues

This commit is contained in:
blessedcoolant 2023-07-31 12:51:30 +12:00 committed by psychedelicious
parent 41d6a38690
commit ce687b28ef

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
LoRAModelConfigEntity,
MainModelConfigEntity,
OnnxModelConfigEntity,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora' | 'onnx';
@ -33,35 +33,43 @@ const ModelList = (props: ModelListProps) => {
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<CombinedModelFormat>('images');
useState<CombinedModelFormat>('all');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
}),
});
const { filteredDiffusersModels, isDiffusersModelLoading } =
useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
isDiffusersModelLoading: isLoading,
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
}),
});
const { filteredCheckpointModels, isCheckpointModelLoading } =
useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
isCheckpointModelLoading: isLoading,
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
isLoadingLoraModels: isLoading,
}),
}
);
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
@ -79,13 +87,47 @@ const ModelList = (props: ModelListProps) => {
setNameFilter(e.target.value);
}, []);
const renderModelList = (
filterArray: Partial<CombinedModelFormat>[],
isLoading: boolean,
loadingMessage: string,
title: string,
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]
) => {
if (!filterArray.includes(modelFormatFilter)) return;
if (isLoading) {
return <FetchingModelsLoader loadingMessage={loadingMessage} />;
}
if (modelList.length === 0) return;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
};
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('images')}
isChecked={modelFormatFilter === 'images'}
onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'all'}
size="sm"
>
{t('modelManager.allModels')}
@ -287,3 +329,26 @@ const StyledModelContainer = (props: PropsWithChildren) => {
</Flex>
);
};
const FetchingModelsLoader = ({
loadingMessage,
}: {
loadingMessage?: string;
}) => {
return (
<StyledModelContainer>
<Flex
justifyContent="center"
alignItems="center"
flexDirection="column"
p={4}
gap={8}
>
<Spinner />
<Text variant="subtext">
{loadingMessage ? loadingMessage : 'Fetching...'}
</Text>
</Flex>
</StyledModelContainer>
);
};