fix logic to see if scanned models are already installed, style tweaks

This commit is contained in:
Mary Hipp 2024-02-23 14:58:36 -05:00 committed by psychedelicious
parent 26a209a00d
commit c3f4e87a6e
10 changed files with 277 additions and 142 deletions

View File

@ -6,6 +6,7 @@ import {
} from 'services/events/actions';
import { startAppListening } from '../..';
import { api } from '../../../../../../services/api';
export const addModelInstallEventListener = () => {
startAppListening({
@ -40,6 +41,7 @@ export const addModelInstallEventListener = () => {
return draft;
})
);
dispatch(api.util.invalidateTags([{ type: "ModelConfig" }]))
},
});

View File

@ -0,0 +1,54 @@
import { useCallback } from "react";
import { ALL_BASE_MODELS } from "../../../services/api/constants";
import { useGetMainModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetControlNetModelsQuery, useGetT2IAdapterModelsQuery, useGetIPAdapterModelsQuery, useGetVaeModelsQuery, } from "../../../services/api/endpoints/models";
import { EntityState } from "@reduxjs/toolkit";
import { forEach } from "lodash-es";
import { AnyModelConfig } from "../../../services/api/types";
export const useIsImported = () => {
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const { data: loras } = useGetLoRAModelsQuery();
const { data: embeddings } = useGetTextualInversionModelsQuery();
const { data: controlnets } = useGetControlNetModelsQuery();
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
const { data: t2is } = useGetT2IAdapterModelsQuery();
const { data: vaes } = useGetVaeModelsQuery();
const isImported = useCallback(({ name }: { name: string }) => {
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
let isMatch = false;
for (let index = 0; index < data.length; index++) {
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
const match = modelsFilter(modelType, name)
if (!!match.length) {
isMatch = true
break;
}
}
return isMatch
}, [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes])
return { isImported }
}
const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined,
nameFilter: string,
): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
if (matchesFilter) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -77,7 +77,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
case 'url':
return source.url;
case 'local':
return source.path.substring(source.path.lastIndexOf('/') + 1);
return source.path.split('\\').slice(-1)[0];
default:
return '';
}
@ -99,13 +99,13 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [model.bytes, model.total_bytes, model.status]);
return (
<Flex gap="2" w="full" alignItems="center" textAlign="center">
<Flex gap="2" w="full" alignItems="center">
<Tooltip label={modelName}>
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
<Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
</Tooltip>
<Flex flexDir="column" w="50%">
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue}
@ -115,11 +115,11 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
/>
</Tooltip>
</Flex>
<Box w="15%">
<Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
</Box>
<Box w="10%">
<Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}

View File

@ -1,17 +1,31 @@
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library';
import { Flex, Text, Box, Button, IconButton, Tooltip, Badge } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { useAppDispatch } from '../../../../../app/store/storeHooks';
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
import { addToast } from '../../../../system/store/systemSlice';
import { makeToast } from '../../../../system/util/makeToast';
import { useIsImported } from '../../../hooks/useIsImported';
import { useMemo } from 'react';
export const ScanModelResultItem = ({ result }: { result: string }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { isImported } = useIsImported();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const isAlreadyImported = useMemo(() => {
const prettyName = result.split('\\').slice(-1)[0];
console.log({ prettyName });
if (prettyName) {
return isImported({ name: prettyName });
} else {
return false;
}
}, [result, isImported]);
const handleQuickAdd = () => {
importMainModel({ source: result, config: undefined })
.unwrap()
@ -46,9 +60,13 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
<Text variant="subtext">{result}</Text>
</Flex>
<Box>
<Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip>
{isAlreadyImported ? (
<Badge>{t('common.installed')}</Badge>
) : (
<Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip>
)}
</Box>
</Flex>
);

View File

@ -1,5 +1,5 @@
import { Button,Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import type { ChangeEventHandler} from 'react';
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
@ -9,22 +9,18 @@ import { ScanModelsResults } from './ScanModelsResults';
export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState('');
const [results, setResults] = useState<string[] | undefined>();
const { t } = useTranslation();
const [_scanModels, { isLoading }] = useLazyScanModelsQuery();
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const handleSubmitScan = useCallback(async () => {
_scanModels({ scan_path: scanPath })
.unwrap()
.then((result) => {
setResults(result);
})
.catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
try {
await _scanModels({ scan_path: scanPath }).unwrap();
} catch (error: any) {
if (error) {
setErrorMessage(error.data.detail);
}
}
}, [_scanModels, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
@ -49,7 +45,7 @@ export const ScanModelsForm = () => {
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl>
{results && <ScanModelsResults results={results} />}
{data && <ScanModelsResults results={data} />}
</Flex>
);
};

View File

@ -2,23 +2,23 @@ import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElemen
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { PiXBold } from 'react-icons/pi';
import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => {
const [searchTerm, setSearchTerm] = useState('');
const [filteredResults, setFilteredResults] = useState(results);
const filteredResults = useMemo(() => {
return results.filter((result) => {
const modelName = result.split('\\').slice(-1)[0];
return modelName?.includes(searchTerm);
});
}, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
(e) => {
setSearchTerm(e.target.value);
setFilteredResults(
results.filter((result) => {
const modelName = result.split('\\').slice(-1)[0];
return modelName?.includes(e.target.value);
})
);
},
[results]
);

View File

@ -14,7 +14,7 @@ export const ModelManager = () => {
}, [dispatch]);
return (
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Box layerStyle="first" p={3} borderRadius="base" w="50%" h="full">
<Flex w="full" p={3} justifyContent="space-between" alignItems="center">
<Flex gap={2}>
<Heading fontSize="xl">Model Manager</Heading>

View File

@ -7,7 +7,7 @@ import { Model } from './ModelPanel/Model';
export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />}
</Box>
);

View File

@ -55,13 +55,13 @@ type MergeMainModelArg = {
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>;
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
};
type ImportMainModelResponse =
paths['/api/v2/models/heuristic_import']['post']['responses']['201']['content']['application/json'];
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
@ -191,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
query: ({ source, config, access_token }) => {
return {
url: buildModelsUrl('heuristic_import'),
url: buildModelsUrl('heuristic_install'),
params: { source, access_token },
method: 'POST',
body: config,

File diff suppressed because one or more lines are too long