fix(ui): use new scan_folder response instead of hook to determine if models are installed already

This commit is contained in:
psychedelicious 2024-02-24 18:48:21 +11:00 committed by Brandon Rising
parent 8848443eff
commit 3130b3db64
3 changed files with 22 additions and 88 deletions

View File

@ -1,62 +0,0 @@
import type { EntityState } from '@reduxjs/toolkit';
import { forEach } from 'lodash-es';
import { useCallback } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const useIsImported = () => {
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const { data: loras } = useGetLoRAModelsQuery();
const { data: embeddings } = useGetTextualInversionModelsQuery();
const { data: controlnets } = useGetControlNetModelsQuery();
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
const { data: t2is } = useGetT2IAdapterModelsQuery();
const { data: vaes } = useGetVaeModelsQuery();
const isImported = useCallback(
({ name }: { name: string }) => {
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes];
let isMatch = false;
for (let index = 0; index < data.length; index++) {
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
const match = modelsFilter(modelType, name);
if (match.length) {
isMatch = true;
break;
}
}
return isMatch;
},
[mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
);
return { isImported };
};
const modelsFilter = <T extends AnyModelConfig>(data: EntityState<T, string> | undefined, nameFilter: string): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
if (matchesFilter) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,33 +1,24 @@
import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useIsImported } from 'features/modelManagerV2/hooks/useIsImported';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5'; import { IoAdd } from 'react-icons/io5';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { useImportMainModelsMutation } from 'services/api/endpoints/models'; import { useImportMainModelsMutation } from 'services/api/endpoints/models';
export const ScanModelResultItem = ({ result }: { result: string }) => { type Props = {
result: ScanFolderResponse[number];
};
export const ScanModelResultItem = ({ result }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isImported } = useIsImported();
const [importMainModel] = useImportMainModelsMutation(); const [importMainModel] = useImportMainModelsMutation();
const isAlreadyImported = useMemo(() => {
const prettyName = result.split('\\').slice(-1)[0];
if (prettyName) {
return isImported({ name: prettyName });
} else {
return false;
}
}, [result, isImported]);
const handleQuickAdd = useCallback(() => { const handleQuickAdd = useCallback(() => {
importMainModel({ source: result, config: undefined }) importMainModel({ source: result.path, config: undefined })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -56,11 +47,11 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
return ( return (
<Flex justifyContent="space-between"> <Flex justifyContent="space-between">
<Flex fontSize="sm" flexDir="column"> <Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text> <Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text> <Text variant="subtext">{result.path}</Text>
</Flex> </Flex>
<Box> <Box>
{isAlreadyImported ? ( {result.is_installed ? (
<Badge>{t('common.installed')}</Badge> <Badge>{t('common.installed')}</Badge>
) : ( ) : (
<Tooltip label={t('modelManager.quickAdd')}> <Tooltip label={t('modelManager.quickAdd')}>

View File

@ -4,21 +4,26 @@ import { t } from 'i18next';
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react'; import { useCallback, useMemo, useState } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanModelResultItem'; import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => { type ScanModelResultsProps = {
results: ScanFolderResponse;
};
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');
const filteredResults = useMemo(() => { const filteredResults = useMemo(() => {
return results.filter((result) => { return results.filter((result) => {
const modelName = result.split('\\').slice(-1)[0]; const modelName = result.path.split('\\').slice(-1)[0];
return modelName?.includes(searchTerm); return modelName?.toLowerCase().includes(searchTerm.toLowerCase());
}); });
}, [results, searchTerm]); }, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setSearchTerm(e.target.value); setSearchTerm(e.target.value.trim());
}, []); }, []);
const clearSearch = useCallback(() => { const clearSearch = useCallback(() => {
@ -36,13 +41,13 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
<InputGroup maxW="300px" size="xs"> <InputGroup maxW="300px" size="xs">
<Input <Input
placeholder={t('modelManager.search')} placeholder={t('modelManager.search')}
value={searchTerm || ''} value={searchTerm}
data-testid="board-search-input" data-testid="board-search-input"
onChange={handleSearch} onChange={handleSearch}
size="xs" size="xs"
/> />
{!!searchTerm?.length && ( {searchTerm && (
<InputRightElement h="full" pe={2}> <InputRightElement h="full" pe={2}>
<IconButton <IconButton
size="sm" size="sm"
@ -59,7 +64,7 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
<ScrollableContent> <ScrollableContent>
<Flex flexDir="column" gap={3}> <Flex flexDir="column" gap={3}>
{filteredResults.map((result) => ( {filteredResults.map((result) => (
<ScanModelResultItem key={result} result={result} /> <ScanModelResultItem key={result.path} result={result} />
))} ))}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>