mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix logic to see if scanned models are already installed, style tweaks
This commit is contained in:
parent
26a209a00d
commit
c3f4e87a6e
@ -6,6 +6,7 @@ import {
|
|||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
|
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
import { api } from '../../../../../../services/api';
|
||||||
|
|
||||||
export const addModelInstallEventListener = () => {
|
export const addModelInstallEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -40,6 +41,7 @@ export const addModelInstallEventListener = () => {
|
|||||||
return draft;
|
return draft;
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
dispatch(api.util.invalidateTags([{ type: "ModelConfig" }]))
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
import { useCallback } from "react";
|
||||||
|
import { ALL_BASE_MODELS } from "../../../services/api/constants";
|
||||||
|
import { useGetMainModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetControlNetModelsQuery, useGetT2IAdapterModelsQuery, useGetIPAdapterModelsQuery, useGetVaeModelsQuery, } from "../../../services/api/endpoints/models";
|
||||||
|
import { EntityState } from "@reduxjs/toolkit";
|
||||||
|
import { forEach } from "lodash-es";
|
||||||
|
import { AnyModelConfig } from "../../../services/api/types";
|
||||||
|
|
||||||
|
export const useIsImported = () => {
|
||||||
|
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||||
|
const { data: loras } = useGetLoRAModelsQuery();
|
||||||
|
const { data: embeddings } = useGetTextualInversionModelsQuery();
|
||||||
|
const { data: controlnets } = useGetControlNetModelsQuery();
|
||||||
|
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
|
||||||
|
const { data: t2is } = useGetT2IAdapterModelsQuery();
|
||||||
|
const { data: vaes } = useGetVaeModelsQuery();
|
||||||
|
|
||||||
|
const isImported = useCallback(({ name }: { name: string }) => {
|
||||||
|
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
|
||||||
|
let isMatch = false;
|
||||||
|
for (let index = 0; index < data.length; index++) {
|
||||||
|
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
|
||||||
|
|
||||||
|
const match = modelsFilter(modelType, name)
|
||||||
|
|
||||||
|
if (!!match.length) {
|
||||||
|
isMatch = true
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isMatch
|
||||||
|
}, [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes])
|
||||||
|
|
||||||
|
return { isImported }
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelsFilter = <T extends AnyModelConfig>(
|
||||||
|
data: EntityState<T, string> | undefined,
|
||||||
|
nameFilter: string,
|
||||||
|
): T[] => {
|
||||||
|
const filteredModels: T[] = [];
|
||||||
|
|
||||||
|
forEach(data?.entities, (model) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
|
||||||
|
|
||||||
|
if (matchesFilter) {
|
||||||
|
filteredModels.push(model);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return filteredModels;
|
||||||
|
};
|
@ -77,7 +77,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
|||||||
case 'url':
|
case 'url':
|
||||||
return source.url;
|
return source.url;
|
||||||
case 'local':
|
case 'local':
|
||||||
return source.path.substring(source.path.lastIndexOf('/') + 1);
|
return source.path.split('\\').slice(-1)[0];
|
||||||
default:
|
default:
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
@ -99,13 +99,13 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
|||||||
}, [model.bytes, model.total_bytes, model.status]);
|
}, [model.bytes, model.total_bytes, model.status]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap="2" w="full" alignItems="center" textAlign="center">
|
<Flex gap="2" w="full" alignItems="center">
|
||||||
<Tooltip label={modelName}>
|
<Tooltip label={modelName}>
|
||||||
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
<Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
||||||
{modelName}
|
{modelName}
|
||||||
</Text>
|
</Text>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
<Flex flexDir="column" w="50%">
|
<Flex flexDir="column" flex={1}>
|
||||||
<Tooltip label={progressString}>
|
<Tooltip label={progressString}>
|
||||||
<Progress
|
<Progress
|
||||||
value={progressValue}
|
value={progressValue}
|
||||||
@ -115,11 +115,11 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
|||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box w="15%">
|
<Box minW="100px" textAlign="center">
|
||||||
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
|
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
|
||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
<Box w="10%">
|
<Box minW="20px">
|
||||||
{(model.status === 'downloading' || model.status === 'waiting') && (
|
{(model.status === 'downloading' || model.status === 'waiting') && (
|
||||||
<IconButton
|
<IconButton
|
||||||
isRound={true}
|
isRound={true}
|
||||||
|
@ -1,17 +1,31 @@
|
|||||||
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library';
|
import { Flex, Text, Box, Button, IconButton, Tooltip, Badge } from '@invoke-ai/ui-library';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoAdd } from 'react-icons/io5';
|
import { IoAdd } from 'react-icons/io5';
|
||||||
import { useAppDispatch } from '../../../../../app/store/storeHooks';
|
import { useAppDispatch } from '../../../../../app/store/storeHooks';
|
||||||
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
|
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
|
||||||
import { addToast } from '../../../../system/store/systemSlice';
|
import { addToast } from '../../../../system/store/systemSlice';
|
||||||
import { makeToast } from '../../../../system/util/makeToast';
|
import { makeToast } from '../../../../system/util/makeToast';
|
||||||
|
import { useIsImported } from '../../../hooks/useIsImported';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const ScanModelResultItem = ({ result }: { result: string }) => {
|
export const ScanModelResultItem = ({ result }: { result: string }) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { isImported } = useIsImported();
|
||||||
|
|
||||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||||
|
|
||||||
|
const isAlreadyImported = useMemo(() => {
|
||||||
|
const prettyName = result.split('\\').slice(-1)[0];
|
||||||
|
console.log({ prettyName });
|
||||||
|
if (prettyName) {
|
||||||
|
return isImported({ name: prettyName });
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}, [result, isImported]);
|
||||||
|
|
||||||
const handleQuickAdd = () => {
|
const handleQuickAdd = () => {
|
||||||
importMainModel({ source: result, config: undefined })
|
importMainModel({ source: result, config: undefined })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -46,9 +60,13 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
|
|||||||
<Text variant="subtext">{result}</Text>
|
<Text variant="subtext">{result}</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box>
|
<Box>
|
||||||
|
{isAlreadyImported ? (
|
||||||
|
<Badge>{t('common.installed')}</Badge>
|
||||||
|
) : (
|
||||||
<Tooltip label={t('modelManager.quickAdd')}>
|
<Tooltip label={t('modelManager.quickAdd')}>
|
||||||
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
|
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -9,22 +9,18 @@ import { ScanModelsResults } from './ScanModelsResults';
|
|||||||
export const ScanModelsForm = () => {
|
export const ScanModelsForm = () => {
|
||||||
const [scanPath, setScanPath] = useState('');
|
const [scanPath, setScanPath] = useState('');
|
||||||
const [errorMessage, setErrorMessage] = useState('');
|
const [errorMessage, setErrorMessage] = useState('');
|
||||||
const [results, setResults] = useState<string[] | undefined>();
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [_scanModels, { isLoading }] = useLazyScanModelsQuery();
|
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
|
||||||
|
|
||||||
const handleSubmitScan = useCallback(async () => {
|
const handleSubmitScan = useCallback(async () => {
|
||||||
_scanModels({ scan_path: scanPath })
|
try {
|
||||||
.unwrap()
|
await _scanModels({ scan_path: scanPath }).unwrap();
|
||||||
.then((result) => {
|
} catch (error: any) {
|
||||||
setResults(result);
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
if (error) {
|
||||||
setErrorMessage(error.data.detail);
|
setErrorMessage(error.data.detail);
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
}, [_scanModels, scanPath]);
|
}, [_scanModels, scanPath]);
|
||||||
|
|
||||||
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||||
@ -49,7 +45,7 @@ export const ScanModelsForm = () => {
|
|||||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
{results && <ScanModelsResults results={results} />}
|
{data && <ScanModelsResults results={data} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -2,23 +2,23 @@ import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElemen
|
|||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useMemo, useState } from 'react';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import { ScanModelResultItem } from './ScanModelResultItem';
|
import { ScanModelResultItem } from './ScanModelResultItem';
|
||||||
|
|
||||||
export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
||||||
const [searchTerm, setSearchTerm] = useState('');
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
const [filteredResults, setFilteredResults] = useState(results);
|
|
||||||
|
const filteredResults = useMemo(() => {
|
||||||
|
return results.filter((result) => {
|
||||||
|
const modelName = result.split('\\').slice(-1)[0];
|
||||||
|
return modelName?.includes(searchTerm);
|
||||||
|
});
|
||||||
|
}, [results, searchTerm]);
|
||||||
|
|
||||||
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
|
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
|
||||||
(e) => {
|
(e) => {
|
||||||
setSearchTerm(e.target.value);
|
setSearchTerm(e.target.value);
|
||||||
setFilteredResults(
|
|
||||||
results.filter((result) => {
|
|
||||||
const modelName = result.split('\\').slice(-1)[0];
|
|
||||||
return modelName?.includes(e.target.value);
|
|
||||||
})
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
[results]
|
[results]
|
||||||
);
|
);
|
||||||
|
@ -14,7 +14,7 @@ export const ModelManager = () => {
|
|||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
<Box layerStyle="first" p={3} borderRadius="base" w="50%" h="full">
|
||||||
<Flex w="full" p={3} justifyContent="space-between" alignItems="center">
|
<Flex w="full" p={3} justifyContent="space-between" alignItems="center">
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<Heading fontSize="xl">Model Manager</Heading>
|
<Heading fontSize="xl">Model Manager</Heading>
|
||||||
|
@ -7,7 +7,7 @@ import { Model } from './ModelPanel/Model';
|
|||||||
export const ModelPane = () => {
|
export const ModelPane = () => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
|
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||||
{selectedModelKey ? <Model /> : <ImportModels />}
|
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
|
@ -55,13 +55,13 @@ type MergeMainModelArg = {
|
|||||||
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
|
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type ImportMainModelArg = {
|
type ImportMainModelArg = {
|
||||||
source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>;
|
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
||||||
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token'];
|
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
||||||
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>;
|
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
|
||||||
};
|
};
|
||||||
|
|
||||||
type ImportMainModelResponse =
|
type ImportMainModelResponse =
|
||||||
paths['/api/v2/models/heuristic_import']['post']['responses']['201']['content']['application/json'];
|
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
|
||||||
|
|
||||||
type ListImportModelsResponse =
|
type ListImportModelsResponse =
|
||||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||||
@ -191,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
||||||
query: ({ source, config, access_token }) => {
|
query: ({ source, config, access_token }) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl('heuristic_import'),
|
url: buildModelsUrl('heuristic_install'),
|
||||||
params: { source, access_token },
|
params: { source, access_token },
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: config,
|
body: config,
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user