fix: Model Manager Tab Issues (#4087)

## What type of PR is this? (check all applicable)

- [x] Refactor
- [x] Feature
- [x] Bug Fix
- [?] Optimization


## Have you discussed this change with the InvokeAI team?
- [x] No

     
## Description

- Fixed filter type select using `images` instead of `all` -- Probably
some merge conflict.
- Added loading state for model lists. Handy when the model list takes
longer than a second for any reason to fetch. Better to show this than
an empty screen.
- Refactored the render model list function so we modify the display
component in one area rather than have repeated code.

### Other Issues

- The list can get a bit laggy on initial load when you have a hundreds
of models / loras. This needs to be fixed. Will make another PR for
this.
This commit is contained in:
blessedcoolant 2023-08-02 01:06:53 +12:00 committed by GitHub
commit c9d452b9d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 181 additions and 128 deletions

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Text } from '@chakra-ui/react'; import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit'; import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react'; import type { ChangeEvent, PropsWithChildren } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { import {
LoRAModelConfigEntity,
MainModelConfigEntity, MainModelConfigEntity,
OnnxModelConfigEntity, OnnxModelConfigEntity,
useGetLoRAModelsQuery,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = { type ModelListProps = {
selectedModelId: string | undefined; selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void; setSelectedModelId: (name: string | undefined) => void;
}; };
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora' | 'onnx'; type ModelType = 'main' | 'lora' | 'onnx';
@ -33,47 +33,63 @@ const ModelList = (props: ModelListProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>(''); const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] = const [modelFormatFilter, setModelFormatFilter] =
useState<CombinedModelFormat>('images'); useState<CombinedModelFormat>('all');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredDiffusersModels, isLoadingDiffusersModels } =
selectFromResult: ({ data }) => ({ useGetMainModelsQuery(ALL_BASE_MODELS, {
filteredDiffusersModels: modelsFilter( selectFromResult: ({ data, isLoading }) => ({
data, filteredDiffusersModels: modelsFilter(
'main', data,
'diffusers', 'main',
nameFilter 'diffusers',
), nameFilter
}), ),
}); isLoadingDiffusersModels: isLoading,
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredCheckpointModels, isLoadingCheckpointModels } =
selectFromResult: ({ data }) => ({ useGetMainModelsQuery(ALL_BASE_MODELS, {
filteredCheckpointModels: modelsFilter( selectFromResult: ({ data, isLoading }) => ({
data, filteredCheckpointModels: modelsFilter(
'main', data,
'checkpoint', 'main',
nameFilter 'checkpoint',
), nameFilter
}), ),
}); isLoadingCheckpointModels: isLoading,
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, { const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
selectFromResult: ({ data }) => ({ undefined,
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), {
}), selectFromResult: ({ data, isLoading }) => ({
}); filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
isLoadingLoraModels: isLoading,
}),
}
);
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
selectFromResult: ({ data }) => ({ ALL_BASE_MODELS,
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter), {
}), selectFromResult: ({ data, isLoading }) => ({
}); filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
isLoadingOnnxModels: isLoading,
}),
}
);
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
selectFromResult: ({ data }) => ({ ALL_BASE_MODELS,
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter), {
}), selectFromResult: ({ data, isLoading }) => ({
}); filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
isLoadingOliveModels: isLoading,
}),
}
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => { const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value); setNameFilter(e.target.value);
@ -84,8 +100,8 @@ const ModelList = (props: ModelListProps) => {
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}> <Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached> <ButtonGroup isAttached>
<IAIButton <IAIButton
onClick={() => setModelFormatFilter('images')} onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'images'} isChecked={modelFormatFilter === 'all'}
size="sm" size="sm"
> >
{t('modelManager.allModels')} {t('modelManager.allModels')}
@ -139,95 +155,76 @@ const ModelList = (props: ModelListProps) => {
maxHeight={window.innerHeight - 280} maxHeight={window.innerHeight - 280}
overflow="scroll" overflow="scroll"
> >
{['images', 'diffusers'].includes(modelFormatFilter) && {/* Diffusers List */}
{isLoadingDiffusersModels && (
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
)}
{['all', 'diffusers'].includes(modelFormatFilter) &&
!isLoadingDiffusersModels &&
filteredDiffusersModels.length > 0 && ( filteredDiffusersModels.length > 0 && (
<StyledModelContainer> <ModelListWrapper
<Flex sx={{ gap: 2, flexDir: 'column' }}> title="Diffusers"
<Text variant="subtext" fontSize="sm"> modelList={filteredDiffusersModels}
Diffusers selected={{ selectedModelId, setSelectedModelId }}
</Text> key="diffusers"
{filteredDiffusersModels.map((model) => ( />
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{['images', 'checkpoint'].includes(modelFormatFilter) && {/* Checkpoints List */}
{isLoadingCheckpointModels && (
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
!isLoadingCheckpointModels &&
filteredCheckpointModels.length > 0 && ( filteredCheckpointModels.length > 0 && (
<StyledModelContainer> <ModelListWrapper
<Flex sx={{ gap: 2, flexDir: 'column' }}> title="Checkpoints"
<Text variant="subtext" fontSize="sm"> modelList={filteredCheckpointModels}
Checkpoints selected={{ selectedModelId, setSelectedModelId }}
</Text> key="checkpoints"
{filteredCheckpointModels.map((model) => ( />
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{['images', 'olive'].includes(modelFormatFilter) &&
filteredOliveModels.length > 0 && ( {/* LoRAs List */}
<StyledModelContainer> {isLoadingLoraModels && (
<Flex sx={{ gap: 2, flexDir: 'column' }}> <FetchingModelsLoader loadingMessage="Loading LoRAs..." />
<Text variant="subtext" fontSize="sm"> )}
Olives {['all', 'lora'].includes(modelFormatFilter) &&
</Text> !isLoadingLoraModels &&
{filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'onnx'].includes(modelFormatFilter) &&
filteredOnnxModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Onnx
</Text>
{filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && ( filteredLoraModels.length > 0 && (
<StyledModelContainer> <ModelListWrapper
<Flex sx={{ gap: 2, flexDir: 'column' }}> title="LoRAs"
<Text variant="subtext" fontSize="sm"> modelList={filteredLoraModels}
LoRAs selected={{ selectedModelId, setSelectedModelId }}
</Text> key="loras"
{filteredLoraModels.map((model) => ( />
<ModelListItem )}
key={model.id} {/* Olive List */}
model={model} {isLoadingOliveModels && (
isSelected={selectedModelId === model.id} <FetchingModelsLoader loadingMessage="Loading Olives..." />
setSelectedModelId={setSelectedModelId} )}
/> {['all', 'olive'].includes(modelFormatFilter) &&
))} !isLoadingOliveModels &&
</Flex> filteredOliveModels.length > 0 && (
</StyledModelContainer> <ModelListWrapper
title="Olives"
modelList={filteredOliveModels}
selected={{ selectedModelId, setSelectedModelId }}
key="olive"
/>
)}
{/* Onnx List */}
{isLoadingOnnxModels && (
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
)}
{['all', 'onnx'].includes(modelFormatFilter) &&
!isLoadingOnnxModels &&
filteredOnnxModels.length > 0 && (
<ModelListWrapper
title="ONNX"
modelList={filteredOnnxModels}
selected={{ selectedModelId, setSelectedModelId }}
key="onnx"
/>
)} )}
</Flex> </Flex>
</Flex> </Flex>
@ -287,3 +284,52 @@ const StyledModelContainer = (props: PropsWithChildren) => {
</Flex> </Flex>
); );
}; };
type ModelListWrapperProps = {
title: string;
modelList:
| MainModelConfigEntity[]
| LoRAModelConfigEntity[]
| OnnxModelConfigEntity[];
selected: ModelListProps;
};
function ModelListWrapper(props: ModelListWrapperProps) {
const { title, modelList, selected } = props;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selected.selectedModelId === model.id}
setSelectedModelId={selected.setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
}
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
return (
<StyledModelContainer>
<Flex
justifyContent="center"
alignItems="center"
flexDirection="column"
p={4}
gap={8}
>
<Spinner />
<Text variant="subtext">
{loadingMessage ? loadingMessage : 'Fetching...'}
</Text>
</Flex>
</StyledModelContainer>
);
}

View File

@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
}, },
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'OnnxModel', type: LIST_TAG }, { type: 'OnnxModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -266,6 +266,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
importMainModels: build.mutation< importMainModels: build.mutation<
@ -282,6 +283,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({ addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
@ -295,6 +297,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
deleteMainModels: build.mutation< deleteMainModels: build.mutation<
@ -310,6 +313,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
convertMainModels: build.mutation< convertMainModels: build.mutation<
@ -326,6 +330,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({ mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
@ -339,6 +344,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<SyncModelsResponse, void>({
@ -351,6 +357,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({