fix: Model Manager Tab Issues

This commit is contained in:
blessedcoolant
2023-07-31 12:51:30 +12:00
committed by psychedelicious
parent 41d6a38690
commit ce687b28ef

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Text } from '@chakra-ui/react'; import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit'; import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react'; import type { ChangeEvent, PropsWithChildren } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { import {
LoRAModelConfigEntity,
MainModelConfigEntity, MainModelConfigEntity,
OnnxModelConfigEntity, OnnxModelConfigEntity,
useGetLoRAModelsQuery,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = { type ModelListProps = {
selectedModelId: string | undefined; selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void; setSelectedModelId: (name: string | undefined) => void;
}; };
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora' | 'onnx'; type ModelType = 'main' | 'lora' | 'onnx';
@ -33,35 +33,43 @@ const ModelList = (props: ModelListProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>(''); const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] = const [modelFormatFilter, setModelFormatFilter] =
useState<CombinedModelFormat>('images'); useState<CombinedModelFormat>('all');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredDiffusersModels, isDiffusersModelLoading } =
selectFromResult: ({ data }) => ({ useGetMainModelsQuery(ALL_BASE_MODELS, {
filteredDiffusersModels: modelsFilter( selectFromResult: ({ data, isLoading }) => ({
data, filteredDiffusersModels: modelsFilter(
'main', data,
'diffusers', 'main',
nameFilter 'diffusers',
), nameFilter
}), ),
}); isDiffusersModelLoading: isLoading,
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredCheckpointModels, isCheckpointModelLoading } =
selectFromResult: ({ data }) => ({ useGetMainModelsQuery(ALL_BASE_MODELS, {
filteredCheckpointModels: modelsFilter( selectFromResult: ({ data, isLoading }) => ({
data, filteredCheckpointModels: modelsFilter(
'main', data,
'checkpoint', 'main',
nameFilter 'checkpoint',
), nameFilter
}), ),
}); isCheckpointModelLoading: isLoading,
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, { const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
selectFromResult: ({ data }) => ({ undefined,
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), {
}), selectFromResult: ({ data, isLoading }) => ({
}); filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
isLoadingLoraModels: isLoading,
}),
}
);
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
@ -79,13 +87,47 @@ const ModelList = (props: ModelListProps) => {
setNameFilter(e.target.value); setNameFilter(e.target.value);
}, []); }, []);
const renderModelList = (
filterArray: Partial<CombinedModelFormat>[],
isLoading: boolean,
loadingMessage: string,
title: string,
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]
) => {
if (!filterArray.includes(modelFormatFilter)) return;
if (isLoading) {
return <FetchingModelsLoader loadingMessage={loadingMessage} />;
}
if (modelList.length === 0) return;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
};
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}> <Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached> <ButtonGroup isAttached>
<IAIButton <IAIButton
onClick={() => setModelFormatFilter('images')} onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'images'} isChecked={modelFormatFilter === 'all'}
size="sm" size="sm"
> >
{t('modelManager.allModels')} {t('modelManager.allModels')}
@ -287,3 +329,26 @@ const StyledModelContainer = (props: PropsWithChildren) => {
</Flex> </Flex>
); );
}; };
const FetchingModelsLoader = ({
loadingMessage,
}: {
loadingMessage?: string;
}) => {
return (
<StyledModelContainer>
<Flex
justifyContent="center"
alignItems="center"
flexDirection="column"
p={4}
gap={8}
>
<Spinner />
<Text variant="subtext">
{loadingMessage ? loadingMessage : 'Fetching...'}
</Text>
</Flex>
</StyledModelContainer>
);
};