mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix logic to see if scanned models are already installed, style tweaks
This commit is contained in:
parent
26a209a00d
commit
c3f4e87a6e
@ -6,6 +6,7 @@ import {
|
||||
} from 'services/events/actions';
|
||||
|
||||
import { startAppListening } from '../..';
|
||||
import { api } from '../../../../../../services/api';
|
||||
|
||||
export const addModelInstallEventListener = () => {
|
||||
startAppListening({
|
||||
@ -40,6 +41,7 @@ export const addModelInstallEventListener = () => {
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
dispatch(api.util.invalidateTags([{ type: "ModelConfig" }]))
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -0,0 +1,54 @@
|
||||
import { useCallback } from "react";
|
||||
import { ALL_BASE_MODELS } from "../../../services/api/constants";
|
||||
import { useGetMainModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetControlNetModelsQuery, useGetT2IAdapterModelsQuery, useGetIPAdapterModelsQuery, useGetVaeModelsQuery, } from "../../../services/api/endpoints/models";
|
||||
import { EntityState } from "@reduxjs/toolkit";
|
||||
import { forEach } from "lodash-es";
|
||||
import { AnyModelConfig } from "../../../services/api/types";
|
||||
|
||||
export const useIsImported = () => {
|
||||
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
const { data: loras } = useGetLoRAModelsQuery();
|
||||
const { data: embeddings } = useGetTextualInversionModelsQuery();
|
||||
const { data: controlnets } = useGetControlNetModelsQuery();
|
||||
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
|
||||
const { data: t2is } = useGetT2IAdapterModelsQuery();
|
||||
const { data: vaes } = useGetVaeModelsQuery();
|
||||
|
||||
const isImported = useCallback(({ name }: { name: string }) => {
|
||||
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
|
||||
let isMatch = false;
|
||||
for (let index = 0; index < data.length; index++) {
|
||||
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
|
||||
|
||||
const match = modelsFilter(modelType, name)
|
||||
|
||||
if (!!match.length) {
|
||||
isMatch = true
|
||||
break;
|
||||
}
|
||||
}
|
||||
return isMatch
|
||||
}, [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes])
|
||||
|
||||
return { isImported }
|
||||
}
|
||||
|
||||
const modelsFilter = <T extends AnyModelConfig>(
|
||||
data: EntityState<T, string> | undefined,
|
||||
nameFilter: string,
|
||||
): T[] => {
|
||||
const filteredModels: T[] = [];
|
||||
|
||||
forEach(data?.entities, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
|
||||
|
||||
if (matchesFilter) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
@ -77,7 +77,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
case 'url':
|
||||
return source.url;
|
||||
case 'local':
|
||||
return source.path.substring(source.path.lastIndexOf('/') + 1);
|
||||
return source.path.split('\\').slice(-1)[0];
|
||||
default:
|
||||
return '';
|
||||
}
|
||||
@ -99,13 +99,13 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
}, [model.bytes, model.total_bytes, model.status]);
|
||||
|
||||
return (
|
||||
<Flex gap="2" w="full" alignItems="center" textAlign="center">
|
||||
<Flex gap="2" w="full" alignItems="center">
|
||||
<Tooltip label={modelName}>
|
||||
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
||||
<Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
||||
{modelName}
|
||||
</Text>
|
||||
</Tooltip>
|
||||
<Flex flexDir="column" w="50%">
|
||||
<Flex flexDir="column" flex={1}>
|
||||
<Tooltip label={progressString}>
|
||||
<Progress
|
||||
value={progressValue}
|
||||
@ -115,11 +115,11 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<Box w="15%">
|
||||
<Box minW="100px" textAlign="center">
|
||||
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
|
||||
</Box>
|
||||
|
||||
<Box w="10%">
|
||||
<Box minW="20px">
|
||||
{(model.status === 'downloading' || model.status === 'waiting') && (
|
||||
<IconButton
|
||||
isRound={true}
|
||||
|
@ -1,17 +1,31 @@
|
||||
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { Flex, Text, Box, Button, IconButton, Tooltip, Badge } from '@invoke-ai/ui-library';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoAdd } from 'react-icons/io5';
|
||||
import { useAppDispatch } from '../../../../../app/store/storeHooks';
|
||||
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
|
||||
import { addToast } from '../../../../system/store/systemSlice';
|
||||
import { makeToast } from '../../../../system/util/makeToast';
|
||||
import { useIsImported } from '../../../hooks/useIsImported';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const ScanModelResultItem = ({ result }: { result: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { isImported } = useIsImported();
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
|
||||
const isAlreadyImported = useMemo(() => {
|
||||
const prettyName = result.split('\\').slice(-1)[0];
|
||||
console.log({ prettyName });
|
||||
if (prettyName) {
|
||||
return isImported({ name: prettyName });
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}, [result, isImported]);
|
||||
|
||||
const handleQuickAdd = () => {
|
||||
importMainModel({ source: result, config: undefined })
|
||||
.unwrap()
|
||||
@ -46,9 +60,13 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
|
||||
<Text variant="subtext">{result}</Text>
|
||||
</Flex>
|
||||
<Box>
|
||||
{isAlreadyImported ? (
|
||||
<Badge>{t('common.installed')}</Badge>
|
||||
) : (
|
||||
<Tooltip label={t('modelManager.quickAdd')}>
|
||||
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
|
||||
</Tooltip>
|
||||
)}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -9,22 +9,18 @@ import { ScanModelsResults } from './ScanModelsResults';
|
||||
export const ScanModelsForm = () => {
|
||||
const [scanPath, setScanPath] = useState('');
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
const [results, setResults] = useState<string[] | undefined>();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [_scanModels, { isLoading }] = useLazyScanModelsQuery();
|
||||
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
|
||||
|
||||
const handleSubmitScan = useCallback(async () => {
|
||||
_scanModels({ scan_path: scanPath })
|
||||
.unwrap()
|
||||
.then((result) => {
|
||||
setResults(result);
|
||||
})
|
||||
.catch((error) => {
|
||||
try {
|
||||
await _scanModels({ scan_path: scanPath }).unwrap();
|
||||
} catch (error: any) {
|
||||
if (error) {
|
||||
setErrorMessage(error.data.detail);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [_scanModels, scanPath]);
|
||||
|
||||
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||
@ -49,7 +45,7 @@ export const ScanModelsForm = () => {
|
||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
{results && <ScanModelsResults results={results} />}
|
||||
{data && <ScanModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -2,23 +2,23 @@ import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElemen
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { t } from 'i18next';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { ScanModelResultItem } from './ScanModelResultItem';
|
||||
|
||||
export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [filteredResults, setFilteredResults] = useState(results);
|
||||
|
||||
const filteredResults = useMemo(() => {
|
||||
return results.filter((result) => {
|
||||
const modelName = result.split('\\').slice(-1)[0];
|
||||
return modelName?.includes(searchTerm);
|
||||
});
|
||||
}, [results, searchTerm]);
|
||||
|
||||
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
|
||||
(e) => {
|
||||
setSearchTerm(e.target.value);
|
||||
setFilteredResults(
|
||||
results.filter((result) => {
|
||||
const modelName = result.split('\\').slice(-1)[0];
|
||||
return modelName?.includes(e.target.value);
|
||||
})
|
||||
);
|
||||
},
|
||||
[results]
|
||||
);
|
||||
|
@ -14,7 +14,7 @@ export const ModelManager = () => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
||||
<Box layerStyle="first" p={3} borderRadius="base" w="50%" h="full">
|
||||
<Flex w="full" p={3} justifyContent="space-between" alignItems="center">
|
||||
<Flex gap={2}>
|
||||
<Heading fontSize="xl">Model Manager</Heading>
|
||||
|
@ -7,7 +7,7 @@ import { Model } from './ModelPanel/Model';
|
||||
export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||
</Box>
|
||||
);
|
||||
|
@ -55,13 +55,13 @@ type MergeMainModelArg = {
|
||||
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportMainModelArg = {
|
||||
source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>;
|
||||
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token'];
|
||||
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>;
|
||||
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
||||
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
||||
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
|
||||
};
|
||||
|
||||
type ImportMainModelResponse =
|
||||
paths['/api/v2/models/heuristic_import']['post']['responses']['201']['content']['application/json'];
|
||||
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type ListImportModelsResponse =
|
||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||
@ -191,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
return {
|
||||
url: buildModelsUrl('heuristic_import'),
|
||||
url: buildModelsUrl('heuristic_install'),
|
||||
params: { source, access_token },
|
||||
method: 'POST',
|
||||
body: config,
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user