mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: Model Manager Tab Issues (#4087)
## What type of PR is this? (check all applicable) - [x] Refactor - [x] Feature - [x] Bug Fix - [?] Optimization ## Have you discussed this change with the InvokeAI team? - [x] No ## Description - Fixed filter type select using `images` instead of `all` -- Probably some merge conflict. - Added loading state for model lists. Handy when the model list takes longer than a second for any reason to fetch. Better to show this than an empty screen. - Refactored the render model list function so we modify the display component in one area rather than have repeated code. ### Other Issues - The list can get a bit laggy on initial load when you have a hundreds of models / loras. This needs to be fixed. Will make another PR for this.
This commit is contained in:
commit
c9d452b9d4
@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
|
||||
import { EntityState } from '@reduxjs/toolkit';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
|
||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
LoRAModelConfigEntity,
|
||||
MainModelConfigEntity,
|
||||
OnnxModelConfigEntity,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetOnnxModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
LoRAModelConfigEntity,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
setSelectedModelId: (name: string | undefined) => void;
|
||||
};
|
||||
|
||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||
|
||||
type ModelType = 'main' | 'lora' | 'onnx';
|
||||
|
||||
@ -33,47 +33,63 @@ const ModelList = (props: ModelListProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<CombinedModelFormat>('images');
|
||||
useState<CombinedModelFormat>('all');
|
||||
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'diffusers',
|
||||
nameFilter
|
||||
),
|
||||
}),
|
||||
});
|
||||
const { filteredDiffusersModels, isLoadingDiffusersModels } =
|
||||
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredDiffusersModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'diffusers',
|
||||
nameFilter
|
||||
),
|
||||
isLoadingDiffusersModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'checkpoint',
|
||||
nameFilter
|
||||
),
|
||||
}),
|
||||
});
|
||||
const { filteredCheckpointModels, isLoadingCheckpointModels } =
|
||||
useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredCheckpointModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'checkpoint',
|
||||
nameFilter
|
||||
),
|
||||
isLoadingCheckpointModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||
}),
|
||||
});
|
||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
|
||||
undefined,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||
isLoadingLoraModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||
}),
|
||||
});
|
||||
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
|
||||
ALL_BASE_MODELS,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||
isLoadingOnnxModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||
}),
|
||||
});
|
||||
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
|
||||
ALL_BASE_MODELS,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||
isLoadingOliveModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
@ -84,8 +100,8 @@ const ModelList = (props: ModelListProps) => {
|
||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIButton
|
||||
onClick={() => setModelFormatFilter('images')}
|
||||
isChecked={modelFormatFilter === 'images'}
|
||||
onClick={() => setModelFormatFilter('all')}
|
||||
isChecked={modelFormatFilter === 'all'}
|
||||
size="sm"
|
||||
>
|
||||
{t('modelManager.allModels')}
|
||||
@ -139,95 +155,76 @@ const ModelList = (props: ModelListProps) => {
|
||||
maxHeight={window.innerHeight - 280}
|
||||
overflow="scroll"
|
||||
>
|
||||
{['images', 'diffusers'].includes(modelFormatFilter) &&
|
||||
{/* Diffusers List */}
|
||||
{isLoadingDiffusersModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
|
||||
)}
|
||||
{['all', 'diffusers'].includes(modelFormatFilter) &&
|
||||
!isLoadingDiffusersModels &&
|
||||
filteredDiffusersModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Diffusers
|
||||
</Text>
|
||||
{filteredDiffusersModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
<ModelListWrapper
|
||||
title="Diffusers"
|
||||
modelList={filteredDiffusersModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="diffusers"
|
||||
/>
|
||||
)}
|
||||
{['images', 'checkpoint'].includes(modelFormatFilter) &&
|
||||
{/* Checkpoints List */}
|
||||
{isLoadingCheckpointModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
|
||||
)}
|
||||
{['all', 'checkpoint'].includes(modelFormatFilter) &&
|
||||
!isLoadingCheckpointModels &&
|
||||
filteredCheckpointModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Checkpoints
|
||||
</Text>
|
||||
{filteredCheckpointModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
<ModelListWrapper
|
||||
title="Checkpoints"
|
||||
modelList={filteredCheckpointModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="checkpoints"
|
||||
/>
|
||||
)}
|
||||
{['images', 'olive'].includes(modelFormatFilter) &&
|
||||
filteredOliveModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Olives
|
||||
</Text>
|
||||
{filteredOliveModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{['images', 'onnx'].includes(modelFormatFilter) &&
|
||||
filteredOnnxModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
Onnx
|
||||
</Text>
|
||||
{filteredOnnxModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{['images', 'lora'].includes(modelFormatFilter) &&
|
||||
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoraModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
|
||||
)}
|
||||
{['all', 'lora'].includes(modelFormatFilter) &&
|
||||
!isLoadingLoraModels &&
|
||||
filteredLoraModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
LoRAs
|
||||
</Text>
|
||||
{filteredLoraModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
<ModelListWrapper
|
||||
title="LoRAs"
|
||||
modelList={filteredLoraModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="loras"
|
||||
/>
|
||||
)}
|
||||
{/* Olive List */}
|
||||
{isLoadingOliveModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Olives..." />
|
||||
)}
|
||||
{['all', 'olive'].includes(modelFormatFilter) &&
|
||||
!isLoadingOliveModels &&
|
||||
filteredOliveModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Olives"
|
||||
modelList={filteredOliveModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="olive"
|
||||
/>
|
||||
)}
|
||||
{/* Onnx List */}
|
||||
{isLoadingOnnxModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
|
||||
)}
|
||||
{['all', 'onnx'].includes(modelFormatFilter) &&
|
||||
!isLoadingOnnxModels &&
|
||||
filteredOnnxModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="ONNX"
|
||||
modelList={filteredOnnxModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="onnx"
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
@ -287,3 +284,52 @@ const StyledModelContainer = (props: PropsWithChildren) => {
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type ModelListWrapperProps = {
|
||||
title: string;
|
||||
modelList:
|
||||
| MainModelConfigEntity[]
|
||||
| LoRAModelConfigEntity[]
|
||||
| OnnxModelConfigEntity[];
|
||||
selected: ModelListProps;
|
||||
};
|
||||
|
||||
function ModelListWrapper(props: ModelListWrapperProps) {
|
||||
const { title, modelList, selected } = props;
|
||||
return (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
{title}
|
||||
</Text>
|
||||
{modelList.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selected.selectedModelId === model.id}
|
||||
setSelectedModelId={selected.setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
);
|
||||
}
|
||||
|
||||
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
|
||||
return (
|
||||
<StyledModelContainer>
|
||||
<Flex
|
||||
justifyContent="center"
|
||||
alignItems="center"
|
||||
flexDirection="column"
|
||||
p={4}
|
||||
gap={8}
|
||||
>
|
||||
<Spinner />
|
||||
<Text variant="subtext">
|
||||
{loadingMessage ? loadingMessage : 'Fetching...'}
|
||||
</Text>
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
);
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'OnnxModel', type: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
@ -266,6 +266,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
importMainModels: build.mutation<
|
||||
@ -282,6 +283,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||
@ -295,6 +297,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
deleteMainModels: build.mutation<
|
||||
@ -310,6 +313,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
convertMainModels: build.mutation<
|
||||
@ -326,6 +330,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||
@ -339,6 +344,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||
@ -351,6 +357,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
{ type: 'OnnxModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
|
Loading…
Reference in New Issue
Block a user