feat: Make ModelListWrapper instead of rendering conditionally

This commit is contained in:
blessedcoolant 2023-07-31 23:09:00 +12:00 committed by psychedelicious
parent f404669831
commit dcc274a2b9
2 changed files with 123 additions and 135 deletions

View File

@ -71,56 +71,30 @@ const ModelList = (props: ModelListProps) => {
} }
); );
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
selectFromResult: ({ data }) => ({ ALL_BASE_MODELS,
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter), {
}), selectFromResult: ({ data, isLoading }) => ({
}); filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
isLoadingOnnxModels: isLoading,
}),
}
);
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, { const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
selectFromResult: ({ data }) => ({ ALL_BASE_MODELS,
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter), {
}), selectFromResult: ({ data, isLoading }) => ({
}); filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
isLoadingOliveModels: isLoading,
}),
}
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => { const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value); setNameFilter(e.target.value);
}, []); }, []);
const renderModelList = (
filterArray: Partial<CombinedModelFormat>[],
isLoading: boolean,
loadingMessage: string,
title: string,
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]
) => {
if (!filterArray.includes(modelFormatFilter)) return;
if (isLoading) {
return <FetchingModelsLoader loadingMessage={loadingMessage} />;
}
if (modelList.length === 0) return;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
};
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}> <Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
@ -181,95 +155,76 @@ const ModelList = (props: ModelListProps) => {
maxHeight={window.innerHeight - 280} maxHeight={window.innerHeight - 280}
overflow="scroll" overflow="scroll"
> >
{['images', 'diffusers'].includes(modelFormatFilter) && {/* Diffusers List */}
{isLoadingDiffusersModels && (
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
)}
{['all', 'diffusers'].includes(modelFormatFilter) &&
!isLoadingDiffusersModels &&
filteredDiffusersModels.length > 0 && ( filteredDiffusersModels.length > 0 && (
<StyledModelContainer> <ModelListWrapper
<Flex sx={{ gap: 2, flexDir: 'column' }}> title="Diffusers"
<Text variant="subtext" fontSize="sm"> modelList={filteredDiffusersModels}
Diffusers selected={{ selectedModelId, setSelectedModelId }}
</Text> key="diffusers"
{filteredDiffusersModels.map((model) => ( />
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{['images', 'checkpoint'].includes(modelFormatFilter) && {/* Checkpoints List */}
{isLoadingCheckpointModels && (
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
!isLoadingCheckpointModels &&
filteredCheckpointModels.length > 0 && ( filteredCheckpointModels.length > 0 && (
<StyledModelContainer> <ModelListWrapper
<Flex sx={{ gap: 2, flexDir: 'column' }}> title="Checkpoints"
<Text variant="subtext" fontSize="sm"> modelList={filteredCheckpointModels}
Checkpoints selected={{ selectedModelId, setSelectedModelId }}
</Text> key="checkpoints"
{filteredCheckpointModels.map((model) => ( />
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{['images', 'olive'].includes(modelFormatFilter) &&
filteredOliveModels.length > 0 && ( {/* LoRAs List */}
<StyledModelContainer> {isLoadingLoraModels && (
<Flex sx={{ gap: 2, flexDir: 'column' }}> <FetchingModelsLoader loadingMessage="Loading LoRAs..." />
<Text variant="subtext" fontSize="sm"> )}
Olives {['all', 'lora'].includes(modelFormatFilter) &&
</Text> !isLoadingLoraModels &&
{filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'onnx'].includes(modelFormatFilter) &&
filteredOnnxModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Onnx
</Text>
{filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && ( filteredLoraModels.length > 0 && (
<StyledModelContainer> <ModelListWrapper
<Flex sx={{ gap: 2, flexDir: 'column' }}> title="LoRAs"
<Text variant="subtext" fontSize="sm"> modelList={filteredLoraModels}
LoRAs selected={{ selectedModelId, setSelectedModelId }}
</Text> key="loras"
{filteredLoraModels.map((model) => ( />
<ModelListItem )}
key={model.id} {/* Olive List */}
model={model} {isLoadingOliveModels && (
isSelected={selectedModelId === model.id} <FetchingModelsLoader loadingMessage="Loading Olives..." />
setSelectedModelId={setSelectedModelId} )}
/> {['all', 'olive'].includes(modelFormatFilter) &&
))} !isLoadingOliveModels &&
</Flex> filteredOliveModels.length > 0 && (
</StyledModelContainer> <ModelListWrapper
title="Olives"
modelList={filteredOliveModels}
selected={{ selectedModelId, setSelectedModelId }}
key="olive"
/>
)}
{/* Onnx List */}
{isLoadingOnnxModels && (
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
)}
{['all', 'onnx'].includes(modelFormatFilter) &&
!isLoadingOnnxModels &&
filteredOnnxModels.length > 0 && (
<ModelListWrapper
title="ONNX"
modelList={filteredOnnxModels}
selected={{ selectedModelId, setSelectedModelId }}
key="onnx"
/>
)} )}
</Flex> </Flex>
</Flex> </Flex>
@ -330,11 +285,37 @@ const StyledModelContainer = (props: PropsWithChildren) => {
); );
}; };
const FetchingModelsLoader = ({ type ModelListWrapperProps = {
loadingMessage, title: string;
}: { modelList:
loadingMessage?: string; | MainModelConfigEntity[]
}) => { | LoRAModelConfigEntity[]
| OnnxModelConfigEntity[];
selected: ModelListProps;
};
function ModelListWrapper(props: ModelListWrapperProps) {
const { title, modelList, selected } = props;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selected.selectedModelId === model.id}
setSelectedModelId={selected.setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
}
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
return ( return (
<StyledModelContainer> <StyledModelContainer>
<Flex <Flex
@ -351,4 +332,4 @@ const FetchingModelsLoader = ({
</Flex> </Flex>
</StyledModelContainer> </StyledModelContainer>
); );
}; }

View File

@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
}, },
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'OnnxModel', type: LIST_TAG }, { type: 'OnnxModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -266,6 +266,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
importMainModels: build.mutation< importMainModels: build.mutation<
@ -282,6 +283,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({ addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
@ -295,6 +297,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
deleteMainModels: build.mutation< deleteMainModels: build.mutation<
@ -310,6 +313,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
convertMainModels: build.mutation< convertMainModels: build.mutation<
@ -326,6 +330,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({ mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
@ -339,6 +344,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<SyncModelsResponse, void>({
@ -351,6 +357,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({