mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): use new scan_folder response instead of hook to determine if models are installed already
This commit is contained in:
parent
7bc454209c
commit
7e13224ec8
@ -1,62 +0,0 @@
|
|||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { forEach } from 'lodash-es';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import {
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetMainModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
useGetTextualInversionModelsQuery,
|
|
||||||
useGetVaeModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const useIsImported = () => {
|
|
||||||
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
|
||||||
const { data: loras } = useGetLoRAModelsQuery();
|
|
||||||
const { data: embeddings } = useGetTextualInversionModelsQuery();
|
|
||||||
const { data: controlnets } = useGetControlNetModelsQuery();
|
|
||||||
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
|
|
||||||
const { data: t2is } = useGetT2IAdapterModelsQuery();
|
|
||||||
const { data: vaes } = useGetVaeModelsQuery();
|
|
||||||
|
|
||||||
const isImported = useCallback(
|
|
||||||
({ name }: { name: string }) => {
|
|
||||||
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes];
|
|
||||||
let isMatch = false;
|
|
||||||
for (let index = 0; index < data.length; index++) {
|
|
||||||
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
|
|
||||||
|
|
||||||
const match = modelsFilter(modelType, name);
|
|
||||||
|
|
||||||
if (match.length) {
|
|
||||||
isMatch = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return isMatch;
|
|
||||||
},
|
|
||||||
[mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
|
|
||||||
);
|
|
||||||
|
|
||||||
return { isImported };
|
|
||||||
};
|
|
||||||
|
|
||||||
const modelsFilter = <T extends AnyModelConfig>(data: EntityState<T, string> | undefined, nameFilter: string): T[] => {
|
|
||||||
const filteredModels: T[] = [];
|
|
||||||
|
|
||||||
forEach(data?.entities, (model) => {
|
|
||||||
if (!model) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
|
|
||||||
|
|
||||||
if (matchesFilter) {
|
|
||||||
filteredModels.push(model);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return filteredModels;
|
|
||||||
};
|
|
@ -1,33 +1,24 @@
|
|||||||
import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useIsImported } from 'features/modelManagerV2/hooks/useIsImported';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoAdd } from 'react-icons/io5';
|
import { IoAdd } from 'react-icons/io5';
|
||||||
|
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export const ScanModelResultItem = ({ result }: { result: string }) => {
|
type Props = {
|
||||||
|
result: ScanFolderResponse[number];
|
||||||
|
};
|
||||||
|
export const ScanModelResultItem = ({ result }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { isImported } = useIsImported();
|
|
||||||
|
|
||||||
const [importMainModel] = useImportMainModelsMutation();
|
const [importMainModel] = useImportMainModelsMutation();
|
||||||
|
|
||||||
const isAlreadyImported = useMemo(() => {
|
|
||||||
const prettyName = result.split('\\').slice(-1)[0];
|
|
||||||
|
|
||||||
if (prettyName) {
|
|
||||||
return isImported({ name: prettyName });
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}, [result, isImported]);
|
|
||||||
|
|
||||||
const handleQuickAdd = useCallback(() => {
|
const handleQuickAdd = useCallback(() => {
|
||||||
importMainModel({ source: result, config: undefined })
|
importMainModel({ source: result.path, config: undefined })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -56,11 +47,11 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
|
|||||||
return (
|
return (
|
||||||
<Flex justifyContent="space-between">
|
<Flex justifyContent="space-between">
|
||||||
<Flex fontSize="sm" flexDir="column">
|
<Flex fontSize="sm" flexDir="column">
|
||||||
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
|
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
|
||||||
<Text variant="subtext">{result}</Text>
|
<Text variant="subtext">{result.path}</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box>
|
<Box>
|
||||||
{isAlreadyImported ? (
|
{result.is_installed ? (
|
||||||
<Badge>{t('common.installed')}</Badge>
|
<Badge>{t('common.installed')}</Badge>
|
||||||
) : (
|
) : (
|
||||||
<Tooltip label={t('modelManager.quickAdd')}>
|
<Tooltip label={t('modelManager.quickAdd')}>
|
||||||
|
@ -4,21 +4,26 @@ import { t } from 'i18next';
|
|||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import { useCallback, useMemo, useState } from 'react';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
|
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { ScanModelResultItem } from './ScanModelResultItem';
|
import { ScanModelResultItem } from './ScanModelResultItem';
|
||||||
|
|
||||||
export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
type ScanModelResultsProps = {
|
||||||
|
results: ScanFolderResponse;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
||||||
const [searchTerm, setSearchTerm] = useState('');
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
|
|
||||||
const filteredResults = useMemo(() => {
|
const filteredResults = useMemo(() => {
|
||||||
return results.filter((result) => {
|
return results.filter((result) => {
|
||||||
const modelName = result.split('\\').slice(-1)[0];
|
const modelName = result.path.split('\\').slice(-1)[0];
|
||||||
return modelName?.includes(searchTerm);
|
return modelName?.toLowerCase().includes(searchTerm.toLowerCase());
|
||||||
});
|
});
|
||||||
}, [results, searchTerm]);
|
}, [results, searchTerm]);
|
||||||
|
|
||||||
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||||
setSearchTerm(e.target.value);
|
setSearchTerm(e.target.value.trim());
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const clearSearch = useCallback(() => {
|
const clearSearch = useCallback(() => {
|
||||||
@ -36,13 +41,13 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
|||||||
<InputGroup maxW="300px" size="xs">
|
<InputGroup maxW="300px" size="xs">
|
||||||
<Input
|
<Input
|
||||||
placeholder={t('modelManager.search')}
|
placeholder={t('modelManager.search')}
|
||||||
value={searchTerm || ''}
|
value={searchTerm}
|
||||||
data-testid="board-search-input"
|
data-testid="board-search-input"
|
||||||
onChange={handleSearch}
|
onChange={handleSearch}
|
||||||
size="xs"
|
size="xs"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{!!searchTerm?.length && (
|
{searchTerm && (
|
||||||
<InputRightElement h="full" pe={2}>
|
<InputRightElement h="full" pe={2}>
|
||||||
<IconButton
|
<IconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
@ -59,7 +64,7 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
|||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDir="column" gap={3}>
|
<Flex flexDir="column" gap={3}>
|
||||||
{filteredResults.map((result) => (
|
{filteredResults.map((result) => (
|
||||||
<ScanModelResultItem key={result} result={result} />
|
<ScanModelResultItem key={result.path} result={result} />
|
||||||
))}
|
))}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
|
Loading…
Reference in New Issue
Block a user