added hf models import tab and route for getting available hf models

This commit is contained in:
Jennifer Player
2024-03-07 16:57:28 -05:00
committed by psychedelicious
parent efea1a8a7d
commit f7cd3cf1f4
10 changed files with 308 additions and 3 deletions

View File

@ -762,6 +762,8 @@
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above.",
"height": "Height",
"heightValidationMsg": "Default height of your model.",
"huggingFace": "Hugging Face",
"huggingFaceRepoID": "HuggingFace Repo ID",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"imageEncoderModelId": "Image Encoder Model ID",
"importModels": "Import Models",

View File

@ -0,0 +1,49 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults';
export const HuggingFaceForm = () => {
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const scanFolder = useCallback(async () => {
_getHuggingFaceModels(huggingFaceRepo).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_getHuggingFaceModels, huggingFaceRepo]);
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setHuggingFaceRepo(e.target.value);
setErrorMessage('');
}, []);
return (
<Flex flexDir="column" height="100%">
<FormControl isInvalid={!!errorMessage.length} w="full">
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex direction="column" w="full">
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Input value={huggingFaceRepo} onChange={handleSetHuggingFaceRepo} />
</Flex>
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}>
{t('modelManager.addModel')}
</Button>
</Flex>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl>
{data && <HuggingFaceResults results={data} />}
</Flex>
);
};

View File

@ -0,0 +1,59 @@
import { Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: string;
};
export const HuggingFaceResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const handleQuickAdd = useCallback(() => {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, result, dispatch, t]);
return (
<Flex justifyContent="space-between" w="100%">
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('/').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text>
</Flex>
<Box>
<Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip>
</Box>
</Flex>
);
};

View File

@ -0,0 +1,119 @@
import {
Button,
Divider,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
type HuggingFaceResultsProps = {
// results: HuggingFaceFolderResponse;
results: string[];
};
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setSearchTerm(e.target.value.trim());
}, []);
const clearSearch = useCallback(() => {
setSearchTerm('');
}, []);
const handleAddAll = useCallback(() => {
for (const result of results) {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}
}, [results, installModel, dispatch, t]);
return (
<>
<Divider mt={4} />
<Flex flexDir="column" gap={2} mt={4} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Heading fontSize="md" as="h4">
{t('modelManager.availableModels')}
</Heading>
<Flex alignItems="center" gap="4">
<Button onClick={handleAddAll} isDisabled={results.length === 0}>
{t('modelManager.addAll')}
</Button>
<InputGroup maxW="300px" size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
data-testid="board-search-input"
onChange={handleSearch}
size="xs"
/>
{searchTerm && (
<InputRightElement h="full" pe={2}>
<IconButton
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
onClick={clearSearch}
/>
</InputRightElement>
)}
</InputGroup>
</Flex>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{results.map((result) => (
<HuggingFaceResultItem key={result} result={result} />
))}
</Flex>
</ScrollableContent>
</Flex>
</Flex>
</>
);
};

View File

@ -1,6 +1,7 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
@ -16,12 +17,16 @@ export const InstallModels = () => {
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.huggingFace')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<HuggingFaceForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
</TabPanel>

View File

@ -258,6 +258,13 @@ export const modelsApi = api.injectEndpoints({
};
},
}),
getHuggingFaceModels: build.query<string[], string>({
query: (hugging_face_repo) => {
return {
url: buildModelsUrl(`hugging_face?hugging_face_repo=${hugging_face_repo}`),
};
},
}),
listModelInstalls: build.query<ListModelInstallsResponse, void>({
query: () => {
return {
@ -381,6 +388,7 @@ export const {
useConvertModelMutation,
useSyncModelsMutation,
useLazyScanFolderQuery,
useLazyGetHuggingFaceModelsQuery,
useListModelInstallsQuery,
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,