fix logic to see if scanned models are already installed, style tweaks

This commit is contained in:
Mary Hipp 2024-02-23 14:58:36 -05:00 committed by psychedelicious
parent 26a209a00d
commit c3f4e87a6e
10 changed files with 277 additions and 142 deletions

View File

@ -6,6 +6,7 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { api } from '../../../../../../services/api';
export const addModelInstallEventListener = () => { export const addModelInstallEventListener = () => {
startAppListening({ startAppListening({
@ -40,6 +41,7 @@ export const addModelInstallEventListener = () => {
return draft; return draft;
}) })
); );
dispatch(api.util.invalidateTags([{ type: "ModelConfig" }]))
}, },
}); });

View File

@ -0,0 +1,54 @@
import { useCallback } from "react";
import { ALL_BASE_MODELS } from "../../../services/api/constants";
import { useGetMainModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetControlNetModelsQuery, useGetT2IAdapterModelsQuery, useGetIPAdapterModelsQuery, useGetVaeModelsQuery, } from "../../../services/api/endpoints/models";
import { EntityState } from "@reduxjs/toolkit";
import { forEach } from "lodash-es";
import { AnyModelConfig } from "../../../services/api/types";
export const useIsImported = () => {
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const { data: loras } = useGetLoRAModelsQuery();
const { data: embeddings } = useGetTextualInversionModelsQuery();
const { data: controlnets } = useGetControlNetModelsQuery();
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
const { data: t2is } = useGetT2IAdapterModelsQuery();
const { data: vaes } = useGetVaeModelsQuery();
const isImported = useCallback(({ name }: { name: string }) => {
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
let isMatch = false;
for (let index = 0; index < data.length; index++) {
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
const match = modelsFilter(modelType, name)
if (!!match.length) {
isMatch = true
break;
}
}
return isMatch
}, [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes])
return { isImported }
}
const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined,
nameFilter: string,
): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
if (matchesFilter) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -77,7 +77,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
case 'url': case 'url':
return source.url; return source.url;
case 'local': case 'local':
return source.path.substring(source.path.lastIndexOf('/') + 1); return source.path.split('\\').slice(-1)[0];
default: default:
return ''; return '';
} }
@ -99,13 +99,13 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [model.bytes, model.total_bytes, model.status]); }, [model.bytes, model.total_bytes, model.status]);
return ( return (
<Flex gap="2" w="full" alignItems="center" textAlign="center"> <Flex gap="2" w="full" alignItems="center">
<Tooltip label={modelName}> <Tooltip label={modelName}>
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis"> <Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName} {modelName}
</Text> </Text>
</Tooltip> </Tooltip>
<Flex flexDir="column" w="50%"> <Flex flexDir="column" flex={1}>
<Tooltip label={progressString}> <Tooltip label={progressString}>
<Progress <Progress
value={progressValue} value={progressValue}
@ -115,11 +115,11 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
/> />
</Tooltip> </Tooltip>
</Flex> </Flex>
<Box w="15%"> <Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} /> <ImportQueueBadge status={model.status} errorReason={model.error_reason} />
</Box> </Box>
<Box w="10%"> <Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting') && ( {(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton <IconButton
isRound={true} isRound={true}

View File

@ -1,17 +1,31 @@
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library'; import { Flex, Text, Box, Button, IconButton, Tooltip, Badge } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5'; import { IoAdd } from 'react-icons/io5';
import { useAppDispatch } from '../../../../../app/store/storeHooks'; import { useAppDispatch } from '../../../../../app/store/storeHooks';
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models'; import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
import { addToast } from '../../../../system/store/systemSlice'; import { addToast } from '../../../../system/store/systemSlice';
import { makeToast } from '../../../../system/util/makeToast'; import { makeToast } from '../../../../system/util/makeToast';
import { useIsImported } from '../../../hooks/useIsImported';
import { useMemo } from 'react';
export const ScanModelResultItem = ({ result }: { result: string }) => { export const ScanModelResultItem = ({ result }: { result: string }) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isImported } = useIsImported();
const [importMainModel, { isLoading }] = useImportMainModelsMutation(); const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const isAlreadyImported = useMemo(() => {
const prettyName = result.split('\\').slice(-1)[0];
console.log({ prettyName });
if (prettyName) {
return isImported({ name: prettyName });
} else {
return false;
}
}, [result, isImported]);
const handleQuickAdd = () => { const handleQuickAdd = () => {
importMainModel({ source: result, config: undefined }) importMainModel({ source: result, config: undefined })
.unwrap() .unwrap()
@ -46,9 +60,13 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
<Text variant="subtext">{result}</Text> <Text variant="subtext">{result}</Text>
</Flex> </Flex>
<Box> <Box>
{isAlreadyImported ? (
<Badge>{t('common.installed')}</Badge>
) : (
<Tooltip label={t('modelManager.quickAdd')}> <Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} /> <IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip> </Tooltip>
)}
</Box> </Box>
</Flex> </Flex>
); );

View File

@ -9,22 +9,18 @@ import { ScanModelsResults } from './ScanModelsResults';
export const ScanModelsForm = () => { export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState(''); const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState(''); const [errorMessage, setErrorMessage] = useState('');
const [results, setResults] = useState<string[] | undefined>();
const { t } = useTranslation(); const { t } = useTranslation();
const [_scanModels, { isLoading }] = useLazyScanModelsQuery(); const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const handleSubmitScan = useCallback(async () => { const handleSubmitScan = useCallback(async () => {
_scanModels({ scan_path: scanPath }) try {
.unwrap() await _scanModels({ scan_path: scanPath }).unwrap();
.then((result) => { } catch (error: any) {
setResults(result);
})
.catch((error) => {
if (error) { if (error) {
setErrorMessage(error.data.detail); setErrorMessage(error.data.detail);
} }
}); }
}, [_scanModels, scanPath]); }, [_scanModels, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
@ -49,7 +45,7 @@ export const ScanModelsForm = () => {
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>} {!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex> </Flex>
</FormControl> </FormControl>
{results && <ScanModelsResults results={results} />} {data && <ScanModelsResults results={data} />}
</Flex> </Flex>
); );
}; };

View File

@ -2,23 +2,23 @@ import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElemen
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next'; import { t } from 'i18next';
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useMemo, useState } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { ScanModelResultItem } from './ScanModelResultItem'; import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => { export const ScanModelsResults = ({ results }: { results: string[] }) => {
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');
const [filteredResults, setFilteredResults] = useState(results);
const filteredResults = useMemo(() => {
return results.filter((result) => {
const modelName = result.split('\\').slice(-1)[0];
return modelName?.includes(searchTerm);
});
}, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback( const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
(e) => { (e) => {
setSearchTerm(e.target.value); setSearchTerm(e.target.value);
setFilteredResults(
results.filter((result) => {
const modelName = result.split('\\').slice(-1)[0];
return modelName?.includes(e.target.value);
})
);
}, },
[results] [results]
); );

View File

@ -14,7 +14,7 @@ export const ModelManager = () => {
}, [dispatch]); }, [dispatch]);
return ( return (
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <Box layerStyle="first" p={3} borderRadius="base" w="50%" h="full">
<Flex w="full" p={3} justifyContent="space-between" alignItems="center"> <Flex w="full" p={3} justifyContent="space-between" alignItems="center">
<Flex gap={2}> <Flex gap={2}>
<Heading fontSize="xl">Model Manager</Heading> <Heading fontSize="xl">Model Manager</Heading>

View File

@ -7,7 +7,7 @@ import { Model } from './ModelPanel/Model';
export const ModelPane = () => { export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return ( return (
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full"> <Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />} {selectedModelKey ? <Model /> : <ImportModels />}
</Box> </Box>
); );

View File

@ -55,13 +55,13 @@ type MergeMainModelArg = {
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json']; type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = { type ImportMainModelArg = {
source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>; source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token']; access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>; config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
}; };
type ImportMainModelResponse = type ImportMainModelResponse =
paths['/api/v2/models/heuristic_import']['post']['responses']['201']['content']['application/json']; paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse = type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
@ -191,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({ importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
query: ({ source, config, access_token }) => { query: ({ source, config, access_token }) => {
return { return {
url: buildModelsUrl('heuristic_import'), url: buildModelsUrl('heuristic_install'),
params: { source, access_token }, params: { source, access_token },
method: 'POST', method: 'POST',
body: config, body: config,

File diff suppressed because one or more lines are too long