mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Make ModelListWrapper instead of rendering conditionally
This commit is contained in:
parent
f404669831
commit
dcc274a2b9
@ -71,56 +71,30 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
|
||||||
selectFromResult: ({ data }) => ({
|
ALL_BASE_MODELS,
|
||||||
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
{
|
||||||
}),
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
});
|
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||||
|
isLoadingOnnxModels: isLoading,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
|
||||||
selectFromResult: ({ data }) => ({
|
ALL_BASE_MODELS,
|
||||||
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
{
|
||||||
}),
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
});
|
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||||
|
isLoadingOliveModels: isLoading,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||||
setNameFilter(e.target.value);
|
setNameFilter(e.target.value);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const renderModelList = (
|
|
||||||
filterArray: Partial<CombinedModelFormat>[],
|
|
||||||
isLoading: boolean,
|
|
||||||
loadingMessage: string,
|
|
||||||
title: string,
|
|
||||||
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]
|
|
||||||
) => {
|
|
||||||
if (!filterArray.includes(modelFormatFilter)) return;
|
|
||||||
|
|
||||||
if (isLoading) {
|
|
||||||
return <FetchingModelsLoader loadingMessage={loadingMessage} />;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelList.length === 0) return;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<StyledModelContainer>
|
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
|
||||||
<Text variant="subtext" fontSize="sm">
|
|
||||||
{title}
|
|
||||||
</Text>
|
|
||||||
{modelList.map((model) => (
|
|
||||||
<ModelListItem
|
|
||||||
key={model.id}
|
|
||||||
model={model}
|
|
||||||
isSelected={selectedModelId === model.id}
|
|
||||||
setSelectedModelId={setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||||
@ -181,95 +155,76 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
maxHeight={window.innerHeight - 280}
|
maxHeight={window.innerHeight - 280}
|
||||||
overflow="scroll"
|
overflow="scroll"
|
||||||
>
|
>
|
||||||
{['images', 'diffusers'].includes(modelFormatFilter) &&
|
{/* Diffusers List */}
|
||||||
|
{isLoadingDiffusersModels && (
|
||||||
|
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
|
||||||
|
)}
|
||||||
|
{['all', 'diffusers'].includes(modelFormatFilter) &&
|
||||||
|
!isLoadingDiffusersModels &&
|
||||||
filteredDiffusersModels.length > 0 && (
|
filteredDiffusersModels.length > 0 && (
|
||||||
<StyledModelContainer>
|
<ModelListWrapper
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
title="Diffusers"
|
||||||
<Text variant="subtext" fontSize="sm">
|
modelList={filteredDiffusersModels}
|
||||||
Diffusers
|
selected={{ selectedModelId, setSelectedModelId }}
|
||||||
</Text>
|
key="diffusers"
|
||||||
{filteredDiffusersModels.map((model) => (
|
/>
|
||||||
<ModelListItem
|
|
||||||
key={model.id}
|
|
||||||
model={model}
|
|
||||||
isSelected={selectedModelId === model.id}
|
|
||||||
setSelectedModelId={setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
)}
|
)}
|
||||||
{['images', 'checkpoint'].includes(modelFormatFilter) &&
|
{/* Checkpoints List */}
|
||||||
|
{isLoadingCheckpointModels && (
|
||||||
|
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
|
||||||
|
)}
|
||||||
|
{['all', 'checkpoint'].includes(modelFormatFilter) &&
|
||||||
|
!isLoadingCheckpointModels &&
|
||||||
filteredCheckpointModels.length > 0 && (
|
filteredCheckpointModels.length > 0 && (
|
||||||
<StyledModelContainer>
|
<ModelListWrapper
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
title="Checkpoints"
|
||||||
<Text variant="subtext" fontSize="sm">
|
modelList={filteredCheckpointModels}
|
||||||
Checkpoints
|
selected={{ selectedModelId, setSelectedModelId }}
|
||||||
</Text>
|
key="checkpoints"
|
||||||
{filteredCheckpointModels.map((model) => (
|
/>
|
||||||
<ModelListItem
|
|
||||||
key={model.id}
|
|
||||||
model={model}
|
|
||||||
isSelected={selectedModelId === model.id}
|
|
||||||
setSelectedModelId={setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
)}
|
)}
|
||||||
{['images', 'olive'].includes(modelFormatFilter) &&
|
|
||||||
filteredOliveModels.length > 0 && (
|
{/* LoRAs List */}
|
||||||
<StyledModelContainer>
|
{isLoadingLoraModels && (
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
|
||||||
<Text variant="subtext" fontSize="sm">
|
)}
|
||||||
Olives
|
{['all', 'lora'].includes(modelFormatFilter) &&
|
||||||
</Text>
|
!isLoadingLoraModels &&
|
||||||
{filteredOliveModels.map((model) => (
|
|
||||||
<ModelListItem
|
|
||||||
key={model.id}
|
|
||||||
model={model}
|
|
||||||
isSelected={selectedModelId === model.id}
|
|
||||||
setSelectedModelId={setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
)}
|
|
||||||
{['images', 'onnx'].includes(modelFormatFilter) &&
|
|
||||||
filteredOnnxModels.length > 0 && (
|
|
||||||
<StyledModelContainer>
|
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
|
||||||
<Text variant="subtext" fontSize="sm">
|
|
||||||
Onnx
|
|
||||||
</Text>
|
|
||||||
{filteredOnnxModels.map((model) => (
|
|
||||||
<ModelListItem
|
|
||||||
key={model.id}
|
|
||||||
model={model}
|
|
||||||
isSelected={selectedModelId === model.id}
|
|
||||||
setSelectedModelId={setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
)}
|
|
||||||
{['images', 'lora'].includes(modelFormatFilter) &&
|
|
||||||
filteredLoraModels.length > 0 && (
|
filteredLoraModels.length > 0 && (
|
||||||
<StyledModelContainer>
|
<ModelListWrapper
|
||||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
title="LoRAs"
|
||||||
<Text variant="subtext" fontSize="sm">
|
modelList={filteredLoraModels}
|
||||||
LoRAs
|
selected={{ selectedModelId, setSelectedModelId }}
|
||||||
</Text>
|
key="loras"
|
||||||
{filteredLoraModels.map((model) => (
|
/>
|
||||||
<ModelListItem
|
)}
|
||||||
key={model.id}
|
{/* Olive List */}
|
||||||
model={model}
|
{isLoadingOliveModels && (
|
||||||
isSelected={selectedModelId === model.id}
|
<FetchingModelsLoader loadingMessage="Loading Olives..." />
|
||||||
setSelectedModelId={setSelectedModelId}
|
)}
|
||||||
/>
|
{['all', 'olive'].includes(modelFormatFilter) &&
|
||||||
))}
|
!isLoadingOliveModels &&
|
||||||
</Flex>
|
filteredOliveModels.length > 0 && (
|
||||||
</StyledModelContainer>
|
<ModelListWrapper
|
||||||
|
title="Olives"
|
||||||
|
modelList={filteredOliveModels}
|
||||||
|
selected={{ selectedModelId, setSelectedModelId }}
|
||||||
|
key="olive"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{/* Onnx List */}
|
||||||
|
{isLoadingOnnxModels && (
|
||||||
|
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
|
||||||
|
)}
|
||||||
|
{['all', 'onnx'].includes(modelFormatFilter) &&
|
||||||
|
!isLoadingOnnxModels &&
|
||||||
|
filteredOnnxModels.length > 0 && (
|
||||||
|
<ModelListWrapper
|
||||||
|
title="ONNX"
|
||||||
|
modelList={filteredOnnxModels}
|
||||||
|
selected={{ selectedModelId, setSelectedModelId }}
|
||||||
|
key="onnx"
|
||||||
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
@ -330,11 +285,37 @@ const StyledModelContainer = (props: PropsWithChildren) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const FetchingModelsLoader = ({
|
type ModelListWrapperProps = {
|
||||||
loadingMessage,
|
title: string;
|
||||||
}: {
|
modelList:
|
||||||
loadingMessage?: string;
|
| MainModelConfigEntity[]
|
||||||
}) => {
|
| LoRAModelConfigEntity[]
|
||||||
|
| OnnxModelConfigEntity[];
|
||||||
|
selected: ModelListProps;
|
||||||
|
};
|
||||||
|
|
||||||
|
function ModelListWrapper(props: ModelListWrapperProps) {
|
||||||
|
const { title, modelList, selected } = props;
|
||||||
|
return (
|
||||||
|
<StyledModelContainer>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
{title}
|
||||||
|
</Text>
|
||||||
|
{modelList.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selected.selectedModelId === model.id}
|
||||||
|
setSelectedModelId={selected.setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
|
||||||
return (
|
return (
|
||||||
<StyledModelContainer>
|
<StyledModelContainer>
|
||||||
<Flex
|
<Flex
|
||||||
@ -351,4 +332,4 @@ const FetchingModelsLoader = ({
|
|||||||
</Flex>
|
</Flex>
|
||||||
</StyledModelContainer>
|
</StyledModelContainer>
|
||||||
);
|
);
|
||||||
};
|
}
|
||||||
|
@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ id: 'OnnxModel', type: LIST_TAG },
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
|
|
||||||
if (result) {
|
if (result) {
|
||||||
@ -266,6 +266,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
importMainModels: build.mutation<
|
importMainModels: build.mutation<
|
||||||
@ -282,6 +283,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||||
@ -295,6 +297,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
deleteMainModels: build.mutation<
|
deleteMainModels: build.mutation<
|
||||||
@ -310,6 +313,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
convertMainModels: build.mutation<
|
convertMainModels: build.mutation<
|
||||||
@ -326,6 +330,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||||
@ -339,6 +344,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||||
@ -351,6 +357,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
invalidatesTags: [
|
invalidatesTags: [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
|
Loading…
Reference in New Issue
Block a user